first commit
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
package httpx
|
||||
|
||||
import "context"
|
||||
|
||||
type userIDKey struct{}
|
||||
|
||||
func WithUserID(ctx context.Context, userID int64) context.Context {
|
||||
return context.WithValue(ctx, userIDKey{}, userID)
|
||||
}
|
||||
|
||||
func UserID(ctx context.Context) (int64, bool) {
|
||||
v, ok := ctx.Value(userIDKey{}).(int64)
|
||||
return v, ok
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type Middleware struct {
|
||||
jwtSecret []byte
|
||||
limit int
|
||||
mu sync.Mutex
|
||||
buckets map[string]*bucket
|
||||
requests *prometheus.CounterVec
|
||||
duration *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
count int
|
||||
reset time.Time
|
||||
}
|
||||
|
||||
func NewMiddleware(jwtSecret string, rpm int, requests *prometheus.CounterVec, duration *prometheus.HistogramVec) *Middleware {
|
||||
return &Middleware{jwtSecret: []byte(jwtSecret), limit: rpm, buckets: map[string]*bucket{}, requests: requests, duration: duration}
|
||||
}
|
||||
|
||||
func (m *Middleware) Chain(next http.Handler) http.Handler {
|
||||
return m.Metrics(m.RateLimit(next))
|
||||
}
|
||||
|
||||
func (m *Middleware) Auth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
WriteError(w, http.StatusUnauthorized, "missing token")
|
||||
return
|
||||
}
|
||||
token, err := jwt.Parse(strings.TrimPrefix(header, "Bearer "), func(t *jwt.Token) (any, error) {
|
||||
return m.jwtSecret, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
WriteError(w, http.StatusUnauthorized, "invalid token")
|
||||
return
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
WriteError(w, http.StatusUnauthorized, "invalid claims")
|
||||
return
|
||||
}
|
||||
sub, err := claims.GetSubject()
|
||||
if err != nil {
|
||||
if f, ok := claims["sub"].(float64); ok {
|
||||
sub = strconv.FormatInt(int64(f), 10)
|
||||
}
|
||||
}
|
||||
id, err := strconv.ParseInt(sub, 10, 64)
|
||||
if err != nil {
|
||||
WriteError(w, http.StatusUnauthorized, "invalid subject")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(WithUserID(r.Context(), id)))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) RateLimit(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := r.RemoteAddr
|
||||
if id, ok := UserID(r.Context()); ok {
|
||||
key = strconv.FormatInt(id, 10)
|
||||
}
|
||||
if !m.allow(key) {
|
||||
WriteError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) Metrics(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rec := &statusRecorder{ResponseWriter: w, status: 200}
|
||||
next.ServeHTTP(rec, r)
|
||||
code := strconv.Itoa(rec.status)
|
||||
m.requests.WithLabelValues(r.Method, r.URL.Path, code).Inc()
|
||||
m.duration.WithLabelValues(r.Method, r.URL.Path).Observe(time.Since(start).Seconds())
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) allow(key string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
now := time.Now()
|
||||
b := m.buckets[key]
|
||||
if b == nil || now.After(b.reset) {
|
||||
m.buckets[key] = &bucket{count: 1, reset: now.Add(time.Minute)}
|
||||
return true
|
||||
}
|
||||
if b.count >= m.limit {
|
||||
return false
|
||||
}
|
||||
b.count++
|
||||
return true
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (r *statusRecorder) WriteHeader(code int) {
|
||||
r.status = code
|
||||
r.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"secunda-test/internal/service"
|
||||
)
|
||||
|
||||
func WriteJSON(w http.ResponseWriter, status int, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
func WriteError(w http.ResponseWriter, status int, message string) {
|
||||
WriteJSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
func StatusFromError(err error) int {
|
||||
switch {
|
||||
case errors.Is(err, service.ErrUnauthorized):
|
||||
return http.StatusUnauthorized
|
||||
case errors.Is(err, service.ErrForbidden):
|
||||
return http.StatusForbidden
|
||||
case errors.Is(err, service.ErrNotFound):
|
||||
return http.StatusNotFound
|
||||
case errors.Is(err, service.ErrBadRequest):
|
||||
return http.StatusBadRequest
|
||||
case errors.Is(err, service.ErrConflict):
|
||||
return http.StatusConflict
|
||||
default:
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user