103 lines
2.5 KiB
Go
103 lines
2.5 KiB
Go
package middleware
|
|
|
|
import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"net/http"
	"strings"

	"github.com/calendarapi/internal/auth"
	"github.com/calendarapi/internal/models"
	"github.com/calendarapi/internal/repository"
	"github.com/calendarapi/internal/utils"
	"github.com/jackc/pgx/v5"
)
|
|
|
|
type AuthMiddleware struct {
|
|
jwt *auth.JWTManager
|
|
queries *repository.Queries
|
|
}
|
|
|
|
func NewAuthMiddleware(jwt *auth.JWTManager, queries *repository.Queries) *AuthMiddleware {
|
|
return &AuthMiddleware{jwt: jwt, queries: queries}
|
|
}
|
|
|
|
func (m *AuthMiddleware) Authenticate(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
if token := extractBearerToken(r); token != "" {
|
|
claims, err := m.jwt.ValidateToken(token)
|
|
if err != nil {
|
|
utils.WriteError(w, models.ErrAuthInvalid)
|
|
return
|
|
}
|
|
if _, err := m.queries.GetUserByID(ctx, utils.ToPgUUID(claims.UserID)); err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
utils.WriteError(w, models.ErrAuthInvalid)
|
|
return
|
|
}
|
|
utils.WriteError(w, models.ErrInternal)
|
|
return
|
|
}
|
|
ctx = SetUserID(ctx, claims.UserID)
|
|
ctx = SetAuthMethod(ctx, "jwt")
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
|
|
hash := SHA256Hash(apiKey)
|
|
key, err := m.queries.GetAPIKeyByHash(ctx, hash)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
utils.WriteError(w, models.ErrAuthInvalid)
|
|
return
|
|
}
|
|
utils.WriteError(w, models.ErrInternal)
|
|
return
|
|
}
|
|
|
|
var scopes Scopes
|
|
if err := json.Unmarshal(key.Scopes, &scopes); err != nil {
|
|
utils.WriteError(w, models.ErrInternal)
|
|
return
|
|
}
|
|
|
|
ctx = SetUserID(ctx, utils.FromPgUUID(key.UserID))
|
|
ctx = SetAuthMethod(ctx, "api_key")
|
|
ctx = SetScopes(ctx, scopes)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
utils.WriteError(w, models.ErrAuthRequired)
|
|
})
|
|
}
|
|
|
|
func RequireScope(resource, action string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !HasScope(r.Context(), resource, action) {
|
|
utils.WriteError(w, models.ErrForbidden)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(r.Context()))
|
|
})
|
|
}
|
|
}
|
|
|
|
// extractBearerToken returns the token carried by an
// "Authorization: Bearer <token>" header, or "" when the header is absent,
// uses a different scheme, or carries no token.
//
// The scheme name is matched case-insensitively, because HTTP authentication
// scheme names are case-insensitive, and whitespace around the header value
// and the token is trimmed.
func extractBearerToken(r *http.Request) string {
	const prefix = "Bearer "

	h := strings.TrimSpace(r.Header.Get("Authorization"))
	// The length guard keeps the slice below in range and rejects a bare
	// "Bearer" with no token.
	if len(h) < len(prefix) || !strings.EqualFold(h[:len(prefix)], prefix) {
		return ""
	}
	return strings.TrimSpace(h[len(prefix):])
}
|
|
|
|
// SHA256Hash returns the SHA-256 digest of s as a lowercase hexadecimal
// string (64 characters).
func SHA256Hash(s string) string {
	hasher := sha256.New()
	hasher.Write([]byte(s)) // hash.Hash's Write never returns an error.
	return hex.EncodeToString(hasher.Sum(nil))
}
|