first commit
Made-with: Cursor
This commit is contained in:
102
internal/middleware/auth.go
Normal file
102
internal/middleware/auth.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/calendarapi/internal/auth"
|
||||
"github.com/calendarapi/internal/models"
|
||||
"github.com/calendarapi/internal/repository"
|
||||
"github.com/calendarapi/internal/utils"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type AuthMiddleware struct {
|
||||
jwt *auth.JWTManager
|
||||
queries *repository.Queries
|
||||
}
|
||||
|
||||
func NewAuthMiddleware(jwt *auth.JWTManager, queries *repository.Queries) *AuthMiddleware {
|
||||
return &AuthMiddleware{jwt: jwt, queries: queries}
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Authenticate(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if token := extractBearerToken(r); token != "" {
|
||||
claims, err := m.jwt.ValidateToken(token)
|
||||
if err != nil {
|
||||
utils.WriteError(w, models.ErrAuthInvalid)
|
||||
return
|
||||
}
|
||||
if _, err := m.queries.GetUserByID(ctx, utils.ToPgUUID(claims.UserID)); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
utils.WriteError(w, models.ErrAuthInvalid)
|
||||
return
|
||||
}
|
||||
utils.WriteError(w, models.ErrInternal)
|
||||
return
|
||||
}
|
||||
ctx = SetUserID(ctx, claims.UserID)
|
||||
ctx = SetAuthMethod(ctx, "jwt")
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
|
||||
hash := SHA256Hash(apiKey)
|
||||
key, err := m.queries.GetAPIKeyByHash(ctx, hash)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
utils.WriteError(w, models.ErrAuthInvalid)
|
||||
return
|
||||
}
|
||||
utils.WriteError(w, models.ErrInternal)
|
||||
return
|
||||
}
|
||||
|
||||
var scopes Scopes
|
||||
if err := json.Unmarshal(key.Scopes, &scopes); err != nil {
|
||||
utils.WriteError(w, models.ErrInternal)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = SetUserID(ctx, utils.FromPgUUID(key.UserID))
|
||||
ctx = SetAuthMethod(ctx, "api_key")
|
||||
ctx = SetScopes(ctx, scopes)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
utils.WriteError(w, models.ErrAuthRequired)
|
||||
})
|
||||
}
|
||||
|
||||
func RequireScope(resource, action string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !HasScope(r.Context(), resource, action) {
|
||||
utils.WriteError(w, models.ErrForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(r.Context()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func extractBearerToken(r *http.Request) string {
|
||||
h := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(h, "Bearer ") {
|
||||
return strings.TrimPrefix(h, "Bearer ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SHA256Hash(s string) string {
|
||||
h := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
64
internal/middleware/context.go
Normal file
64
internal/middleware/context.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
userIDKey contextKey = "user_id"
|
||||
authMethodKey contextKey = "auth_method"
|
||||
scopesKey contextKey = "scopes"
|
||||
)
|
||||
|
||||
type Scopes map[string][]string
|
||||
|
||||
func SetUserID(ctx context.Context, id uuid.UUID) context.Context {
|
||||
return context.WithValue(ctx, userIDKey, id)
|
||||
}
|
||||
|
||||
func GetUserID(ctx context.Context) (uuid.UUID, bool) {
|
||||
id, ok := ctx.Value(userIDKey).(uuid.UUID)
|
||||
return id, ok
|
||||
}
|
||||
|
||||
func SetAuthMethod(ctx context.Context, method string) context.Context {
|
||||
return context.WithValue(ctx, authMethodKey, method)
|
||||
}
|
||||
|
||||
func GetAuthMethod(ctx context.Context) string {
|
||||
m, _ := ctx.Value(authMethodKey).(string)
|
||||
return m
|
||||
}
|
||||
|
||||
func SetScopes(ctx context.Context, scopes Scopes) context.Context {
|
||||
return context.WithValue(ctx, scopesKey, scopes)
|
||||
}
|
||||
|
||||
func GetScopes(ctx context.Context) Scopes {
|
||||
s, _ := ctx.Value(scopesKey).(Scopes)
|
||||
return s
|
||||
}
|
||||
|
||||
func HasScope(ctx context.Context, resource, action string) bool {
|
||||
if GetAuthMethod(ctx) == "jwt" {
|
||||
return true
|
||||
}
|
||||
scopes := GetScopes(ctx)
|
||||
if scopes == nil {
|
||||
return false
|
||||
}
|
||||
actions, ok := scopes[resource]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, a := range actions {
|
||||
if a == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
86
internal/middleware/ratelimit.go
Normal file
86
internal/middleware/ratelimit.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/calendarapi/internal/models"
|
||||
"github.com/calendarapi/internal/utils"
|
||||
)
|
||||
|
||||
type visitor struct {
|
||||
tokens float64
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
visitors map[string]*visitor
|
||||
rate float64
|
||||
burst float64
|
||||
}
|
||||
|
||||
func NewRateLimiter(ratePerSecond float64, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
visitors: make(map[string]*visitor),
|
||||
rate: ratePerSecond,
|
||||
burst: float64(burst),
|
||||
}
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := r.RemoteAddr
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
ip = fwd
|
||||
}
|
||||
|
||||
if !rl.allow(ip) {
|
||||
utils.WriteError(w, models.ErrRateLimited)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
v, exists := rl.visitors[key]
|
||||
now := time.Now()
|
||||
|
||||
if !exists {
|
||||
rl.visitors[key] = &visitor{tokens: rl.burst - 1, lastSeen: now}
|
||||
return true
|
||||
}
|
||||
|
||||
elapsed := now.Sub(v.lastSeen).Seconds()
|
||||
v.tokens += elapsed * rl.rate
|
||||
if v.tokens > rl.burst {
|
||||
v.tokens = rl.burst
|
||||
}
|
||||
v.lastSeen = now
|
||||
|
||||
if v.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
v.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
for {
|
||||
time.Sleep(5 * time.Minute)
|
||||
rl.mu.Lock()
|
||||
for key, v := range rl.visitors {
|
||||
if time.Since(v.lastSeen) > 10*time.Minute {
|
||||
delete(rl.visitors, key)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user