first commit
This commit is contained in:
@@ -0,0 +1,119 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type Middleware struct {
|
||||
jwtSecret []byte
|
||||
limit int
|
||||
mu sync.Mutex
|
||||
buckets map[string]*bucket
|
||||
requests *prometheus.CounterVec
|
||||
duration *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
count int
|
||||
reset time.Time
|
||||
}
|
||||
|
||||
func NewMiddleware(jwtSecret string, rpm int, requests *prometheus.CounterVec, duration *prometheus.HistogramVec) *Middleware {
|
||||
return &Middleware{jwtSecret: []byte(jwtSecret), limit: rpm, buckets: map[string]*bucket{}, requests: requests, duration: duration}
|
||||
}
|
||||
|
||||
func (m *Middleware) Chain(next http.Handler) http.Handler {
|
||||
return m.Metrics(m.RateLimit(next))
|
||||
}
|
||||
|
||||
func (m *Middleware) Auth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
WriteError(w, http.StatusUnauthorized, "missing token")
|
||||
return
|
||||
}
|
||||
token, err := jwt.Parse(strings.TrimPrefix(header, "Bearer "), func(t *jwt.Token) (any, error) {
|
||||
return m.jwtSecret, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
WriteError(w, http.StatusUnauthorized, "invalid token")
|
||||
return
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
WriteError(w, http.StatusUnauthorized, "invalid claims")
|
||||
return
|
||||
}
|
||||
sub, err := claims.GetSubject()
|
||||
if err != nil {
|
||||
if f, ok := claims["sub"].(float64); ok {
|
||||
sub = strconv.FormatInt(int64(f), 10)
|
||||
}
|
||||
}
|
||||
id, err := strconv.ParseInt(sub, 10, 64)
|
||||
if err != nil {
|
||||
WriteError(w, http.StatusUnauthorized, "invalid subject")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(WithUserID(r.Context(), id)))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) RateLimit(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := r.RemoteAddr
|
||||
if id, ok := UserID(r.Context()); ok {
|
||||
key = strconv.FormatInt(id, 10)
|
||||
}
|
||||
if !m.allow(key) {
|
||||
WriteError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) Metrics(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rec := &statusRecorder{ResponseWriter: w, status: 200}
|
||||
next.ServeHTTP(rec, r)
|
||||
code := strconv.Itoa(rec.status)
|
||||
m.requests.WithLabelValues(r.Method, r.URL.Path, code).Inc()
|
||||
m.duration.WithLabelValues(r.Method, r.URL.Path).Observe(time.Since(start).Seconds())
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) allow(key string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
now := time.Now()
|
||||
b := m.buckets[key]
|
||||
if b == nil || now.After(b.reset) {
|
||||
m.buckets[key] = &bucket{count: 1, reset: now.Add(time.Minute)}
|
||||
return true
|
||||
}
|
||||
if b.count >= m.limit {
|
||||
return false
|
||||
}
|
||||
b.count++
|
||||
return true
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (r *statusRecorder) WriteHeader(code int) {
|
||||
r.status = code
|
||||
r.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
Reference in New Issue
Block a user