120 lines
3.0 KiB
Go
120 lines
3.0 KiB
Go
package httpx
|
|
|
|
import (
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/prometheus/client_golang/prometheus"
)
|
|
|
|
type Middleware struct {
|
|
jwtSecret []byte
|
|
limit int
|
|
mu sync.Mutex
|
|
buckets map[string]*bucket
|
|
requests *prometheus.CounterVec
|
|
duration *prometheus.HistogramVec
|
|
}
|
|
|
|
type bucket struct {
|
|
count int
|
|
reset time.Time
|
|
}
|
|
|
|
func NewMiddleware(jwtSecret string, rpm int, requests *prometheus.CounterVec, duration *prometheus.HistogramVec) *Middleware {
|
|
return &Middleware{jwtSecret: []byte(jwtSecret), limit: rpm, buckets: map[string]*bucket{}, requests: requests, duration: duration}
|
|
}
|
|
|
|
// Chain wraps next in the standard stack. RateLimit sits directly around
// next and Metrics is outermost, so requests the limiter rejects with 429
// are still counted and timed. Auth is not part of this stack.
//
// NOTE(review): RateLimit keys by the user ID that Auth stores in the
// context. Because Auth does not run before RateLimit here, that ID is
// absent (unless Auth is applied outside Chain), and all traffic is limited
// per client address. Confirm this is intended.
func (m *Middleware) Chain(next http.Handler) http.Handler {
	limited := m.RateLimit(next)
	return m.Metrics(limited)
}
|
|
|
|
func (m *Middleware) Auth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
header := r.Header.Get("Authorization")
|
|
if !strings.HasPrefix(header, "Bearer ") {
|
|
WriteError(w, http.StatusUnauthorized, "missing token")
|
|
return
|
|
}
|
|
token, err := jwt.Parse(strings.TrimPrefix(header, "Bearer "), func(t *jwt.Token) (any, error) {
|
|
return m.jwtSecret, nil
|
|
})
|
|
if err != nil || !token.Valid {
|
|
WriteError(w, http.StatusUnauthorized, "invalid token")
|
|
return
|
|
}
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
WriteError(w, http.StatusUnauthorized, "invalid claims")
|
|
return
|
|
}
|
|
sub, err := claims.GetSubject()
|
|
if err != nil {
|
|
if f, ok := claims["sub"].(float64); ok {
|
|
sub = strconv.FormatInt(int64(f), 10)
|
|
}
|
|
}
|
|
id, err := strconv.ParseInt(sub, 10, 64)
|
|
if err != nil {
|
|
WriteError(w, http.StatusUnauthorized, "invalid subject")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(WithUserID(r.Context(), id)))
|
|
})
|
|
}
|
|
|
|
func (m *Middleware) RateLimit(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
key := r.RemoteAddr
|
|
if id, ok := UserID(r.Context()); ok {
|
|
key = strconv.FormatInt(id, 10)
|
|
}
|
|
if !m.allow(key) {
|
|
WriteError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// Metrics is middleware that instruments every request passing through it.
//
// After next returns, it increments m.requests with (method, URL path,
// status code) and records the elapsed time in seconds into m.duration with
// (method, URL path).
//
// NOTE(review): labelling by the raw r.URL.Path makes series cardinality
// unbounded, because every distinct ID or random path creates a new time
// series. Consider labelling with a route template instead.
func (m *Middleware) Metrics(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()

		// Assume 200: net/http sends 200 OK when a handler never calls WriteHeader.
		sw := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(sw, r)

		method, path := r.Method, r.URL.Path
		m.requests.WithLabelValues(method, path, strconv.Itoa(sw.status)).Inc()
		m.duration.WithLabelValues(method, path).Observe(time.Since(began).Seconds())
	})
}
|
|
|
|
func (m *Middleware) allow(key string) bool {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
now := time.Now()
|
|
b := m.buckets[key]
|
|
if b == nil || now.After(b.reset) {
|
|
m.buckets[key] = &bucket{count: 1, reset: now.Add(time.Minute)}
|
|
return true
|
|
}
|
|
if b.count >= m.limit {
|
|
return false
|
|
}
|
|
b.count++
|
|
return true
|
|
}
|
|
|
|
type statusRecorder struct {
|
|
http.ResponseWriter
|
|
status int
|
|
}
|
|
|
|
func (r *statusRecorder) WriteHeader(code int) {
|
|
r.status = code
|
|
r.ResponseWriter.WriteHeader(code)
|
|
}
|