87 lines
1.5 KiB
Go
87 lines
1.5 KiB
Go
package middleware
|
|
|
|
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/calendarapi/internal/models"
	"github.com/calendarapi/internal/utils"
)
|
|
|
|
// visitor holds the token-bucket state tracked for a single rate-limit key.
type visitor struct {
	// tokens is how many requests the key may currently make. It is a
	// float64 because refill is proportional to elapsed time, so partial
	// tokens accumulate between requests.
	tokens float64

	// lastSeen is the time the bucket was last refilled, i.e. the key's most
	// recent request; it is also what idle-key eviction looks at.
	lastSeen time.Time
}
|
|
|
|
type RateLimiter struct {
|
|
mu sync.Mutex
|
|
visitors map[string]*visitor
|
|
rate float64
|
|
burst float64
|
|
}
|
|
|
|
func NewRateLimiter(ratePerSecond float64, burst int) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
visitors: make(map[string]*visitor),
|
|
rate: ratePerSecond,
|
|
burst: float64(burst),
|
|
}
|
|
go rl.cleanup()
|
|
return rl
|
|
}
|
|
|
|
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := r.RemoteAddr
|
|
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
|
ip = fwd
|
|
}
|
|
|
|
if !rl.allow(ip) {
|
|
utils.WriteError(w, models.ErrRateLimited)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (rl *RateLimiter) allow(key string) bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
v, exists := rl.visitors[key]
|
|
now := time.Now()
|
|
|
|
if !exists {
|
|
rl.visitors[key] = &visitor{tokens: rl.burst - 1, lastSeen: now}
|
|
return true
|
|
}
|
|
|
|
elapsed := now.Sub(v.lastSeen).Seconds()
|
|
v.tokens += elapsed * rl.rate
|
|
if v.tokens > rl.burst {
|
|
v.tokens = rl.burst
|
|
}
|
|
v.lastSeen = now
|
|
|
|
if v.tokens < 1 {
|
|
return false
|
|
}
|
|
v.tokens--
|
|
return true
|
|
}
|
|
|
|
func (rl *RateLimiter) cleanup() {
|
|
for {
|
|
time.Sleep(5 * time.Minute)
|
|
rl.mu.Lock()
|
|
for key, v := range rl.visitors {
|
|
if time.Since(v.lastSeen) > 10*time.Minute {
|
|
delete(rl.visitors, key)
|
|
}
|
|
}
|
|
rl.mu.Unlock()
|
|
}
|
|
}
|