package httpx import ( "net/http" "strconv" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/prometheus/client_golang/prometheus" ) type Middleware struct { jwtSecret []byte limit int mu sync.Mutex buckets map[string]*bucket requests *prometheus.CounterVec duration *prometheus.HistogramVec } type bucket struct { count int reset time.Time } func NewMiddleware(jwtSecret string, rpm int, requests *prometheus.CounterVec, duration *prometheus.HistogramVec) *Middleware { return &Middleware{jwtSecret: []byte(jwtSecret), limit: rpm, buckets: map[string]*bucket{}, requests: requests, duration: duration} } func (m *Middleware) Chain(next http.Handler) http.Handler { return m.Metrics(m.RateLimit(next)) } func (m *Middleware) Auth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { header := r.Header.Get("Authorization") if !strings.HasPrefix(header, "Bearer ") { WriteError(w, http.StatusUnauthorized, "missing token") return } token, err := jwt.Parse(strings.TrimPrefix(header, "Bearer "), func(t *jwt.Token) (any, error) { return m.jwtSecret, nil }) if err != nil || !token.Valid { WriteError(w, http.StatusUnauthorized, "invalid token") return } claims, ok := token.Claims.(jwt.MapClaims) if !ok { WriteError(w, http.StatusUnauthorized, "invalid claims") return } sub, err := claims.GetSubject() if err != nil { if f, ok := claims["sub"].(float64); ok { sub = strconv.FormatInt(int64(f), 10) } } id, err := strconv.ParseInt(sub, 10, 64) if err != nil { WriteError(w, http.StatusUnauthorized, "invalid subject") return } next.ServeHTTP(w, r.WithContext(WithUserID(r.Context(), id))) }) } func (m *Middleware) RateLimit(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key := r.RemoteAddr if id, ok := UserID(r.Context()); ok { key = strconv.FormatInt(id, 10) } if !m.allow(key) { WriteError(w, http.StatusTooManyRequests, "rate limit exceeded") return } next.ServeHTTP(w, r) }) } func (m *Middleware) Metrics(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rec := &statusRecorder{ResponseWriter: w, status: 200} next.ServeHTTP(rec, r) code := strconv.Itoa(rec.status) m.requests.WithLabelValues(r.Method, r.URL.Path, code).Inc() m.duration.WithLabelValues(r.Method, r.URL.Path).Observe(time.Since(start).Seconds()) }) } func (m *Middleware) allow(key string) bool { m.mu.Lock() defer m.mu.Unlock() now := time.Now() b := m.buckets[key] if b == nil || now.After(b.reset) { m.buckets[key] = &bucket{count: 1, reset: now.Add(time.Minute)} return true } if b.count >= m.limit { return false } b.count++ return true } type statusRecorder struct { http.ResponseWriter status int } func (r *statusRecorder) WriteHeader(code int) { r.status = code r.ResponseWriter.WriteHeader(code) }