package middleware

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"net/http"
	"strings"

	"github.com/calendarapi/internal/auth"
	"github.com/calendarapi/internal/models"
	"github.com/calendarapi/internal/repository"
	"github.com/calendarapi/internal/utils"
	"github.com/jackc/pgx/v5"
)

// AuthMiddleware authenticates incoming requests with either a JWT bearer
// token or an API key and records the resulting identity in the request
// context.
type AuthMiddleware struct {
	jwt     *auth.JWTManager
	queries *repository.Queries
}

// NewAuthMiddleware returns an AuthMiddleware that validates JWTs with jwt and
// looks up users and API keys through queries.
func NewAuthMiddleware(jwt *auth.JWTManager, queries *repository.Queries) *AuthMiddleware {
	return &AuthMiddleware{jwt: jwt, queries: queries}
}

// Authenticate is HTTP middleware that requires credentials on every request.
// It checks, in order:
//
//   - an "Authorization: Bearer <token>" header. The token must pass
//     JWTManager.ValidateToken and its user must be found by GetUserByID. On
//     success the user ID and auth method "jwt" are stored in the context.
//   - an "X-API-Key" header. The key's SHA-256 hex digest must be found by
//     GetAPIKeyByHash. On success the key's user ID, auth method "api_key"
//     and its JSON-decoded scopes are stored in the context.
//
// When a bearer token is present it is used exclusively: an invalid token is
// rejected even if an X-API-Key header is also supplied.
//
// Errors written to the client:
//   - models.ErrAuthRequired when neither credential is present.
//   - models.ErrAuthInvalid when a token fails validation or its user or key
//     is not found.
//   - models.ErrInternal for any other lookup error or a scope-decoding
//     failure.
func (m *AuthMiddleware) Authenticate(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		if token := extractBearerToken(r); token != "" {
			claims, err := m.jwt.ValidateToken(token)
			if err != nil {
				utils.WriteError(w, models.ErrAuthInvalid)
				return
			}

			// A token that validates may still reference a user that no
			// longer exists, so confirm the user before trusting the claims.
			_, err = m.queries.GetUserByID(ctx, utils.ToPgUUID(claims.UserID))
			switch {
			case errors.Is(err, pgx.ErrNoRows):
				utils.WriteError(w, models.ErrAuthInvalid)
				return
			case err != nil:
				utils.WriteError(w, models.ErrInternal)
				return
			}

			ctx = SetUserID(ctx, claims.UserID)
			ctx = SetAuthMethod(ctx, "jwt")
			next.ServeHTTP(w, r.WithContext(ctx))
			return
		}

		if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
			// Keys are looked up by their SHA-256 digest only; the raw key is
			// never passed to the query layer.
			key, err := m.queries.GetAPIKeyByHash(ctx, SHA256Hash(apiKey))
			switch {
			case errors.Is(err, pgx.ErrNoRows):
				utils.WriteError(w, models.ErrAuthInvalid)
				return
			case err != nil:
				utils.WriteError(w, models.ErrInternal)
				return
			}

			var scopes Scopes
			if err := json.Unmarshal(key.Scopes, &scopes); err != nil {
				utils.WriteError(w, models.ErrInternal)
				return
			}

			ctx = SetUserID(ctx, utils.FromPgUUID(key.UserID))
			ctx = SetAuthMethod(ctx, "api_key")
			ctx = SetScopes(ctx, scopes)
			next.ServeHTTP(w, r.WithContext(ctx))
			return
		}

		utils.WriteError(w, models.ErrAuthRequired)
	})
}

// RequireScope returns middleware that writes models.ErrForbidden unless
// HasScope reports that the request context grants action on resource.
//
// NOTE(review): presumably mounted after Authenticate so the context carries
// the caller's identity and scopes. Authenticate sets scopes only for API-key
// callers; confirm how HasScope treats JWT-authenticated requests.
func RequireScope(resource, action string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !HasScope(r.Context(), resource, action) {
				utils.WriteError(w, models.ErrForbidden)
				return
			}
			// The context is unchanged, so hand the request on as-is instead
			// of shallow-copying it via r.WithContext(r.Context()).
			next.ServeHTTP(w, r)
		})
	}
}

// extractBearerToken returns the token from an "Authorization: Bearer <token>"
// header. It returns "" when the header is absent, uses a different scheme, or
// carries no token.
//
// The scheme name is matched case-insensitively (RFC 7235 §2.1). Whitespace
// surrounding the token is trimmed, because RFC 6750 allows one or more
// spaces after the scheme.
func extractBearerToken(r *http.Request) string {
	const scheme = "Bearer "
	h := r.Header.Get("Authorization")
	if len(h) < len(scheme) || !strings.EqualFold(h[:len(scheme)], scheme) {
		return ""
	}
	return strings.TrimSpace(h[len(scheme):])
}

// SHA256Hash returns the lowercase hexadecimal encoding of the SHA-256 digest
// of s.
func SHA256Hash(s string) string {
	sum := sha256.Sum256([]byte(s))
	return hex.EncodeToString(sum[:])
}