257 lines
6.8 KiB
Go
257 lines
6.8 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"time"
|
|
|
|
"github.com/calendarapi/internal/auth"
|
|
"github.com/calendarapi/internal/models"
|
|
"github.com/calendarapi/internal/repository"
|
|
"github.com/calendarapi/internal/utils"
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// AuthService handles user registration, password login, and refresh-token
// rotation and revocation on top of the repository layer.
type AuthService struct {
	// pool is the connection pool used to begin database transactions.
	pool *pgxpool.Pool
	// queries is the repository query set; it is bound to a transaction via
	// WithTx where several writes must succeed or fail together.
	queries *repository.Queries
	// jwt mints access and refresh tokens for a user ID.
	jwt *auth.JWTManager
	// audit is stored by NewAuthService. NOTE(review): none of the methods
	// visible here read it — confirm whether audit logging is missing.
	audit *AuditService
}
|
|
|
|
func NewAuthService(pool *pgxpool.Pool, queries *repository.Queries, jwt *auth.JWTManager, audit *AuditService) *AuthService {
|
|
return &AuthService{pool: pool, queries: queries, jwt: jwt, audit: audit}
|
|
}
|
|
|
|
func (s *AuthService) Register(ctx context.Context, email, password, timezone string) (*models.AuthTokens, error) {
|
|
email = utils.NormalizeEmail(email)
|
|
if err := utils.ValidateEmail(email); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := utils.ValidatePassword(password); err != nil {
|
|
return nil, err
|
|
}
|
|
if timezone == "" {
|
|
timezone = "UTC"
|
|
}
|
|
if err := utils.ValidateTimezone(timezone); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err := s.queries.GetUserByEmail(ctx, email)
|
|
if err == nil {
|
|
return nil, models.NewConflictError("email already registered")
|
|
}
|
|
if err != pgx.ErrNoRows {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
qtx := s.queries.WithTx(tx)
|
|
|
|
userID := uuid.New()
|
|
dbUser, err := qtx.CreateUser(ctx, repository.CreateUserParams{
|
|
ID: utils.ToPgUUID(userID),
|
|
Email: email,
|
|
PasswordHash: string(hash),
|
|
Timezone: timezone,
|
|
})
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
calID := uuid.New()
|
|
_, err = qtx.CreateCalendar(ctx, repository.CreateCalendarParams{
|
|
ID: utils.ToPgUUID(calID),
|
|
OwnerID: utils.ToPgUUID(userID),
|
|
Name: "My Calendar",
|
|
Color: "#3B82F6",
|
|
IsPublic: false,
|
|
})
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
err = qtx.UpsertCalendarMember(ctx, repository.UpsertCalendarMemberParams{
|
|
CalendarID: utils.ToPgUUID(calID),
|
|
UserID: utils.ToPgUUID(userID),
|
|
Role: "owner",
|
|
})
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
accessToken, err := s.jwt.GenerateAccessToken(userID)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
refreshToken, err := s.jwt.GenerateRefreshToken(userID)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
rtHash := hashToken(refreshToken)
|
|
_, err = s.queries.CreateRefreshToken(ctx, repository.CreateRefreshTokenParams{
|
|
ID: utils.ToPgUUID(uuid.New()),
|
|
UserID: utils.ToPgUUID(userID),
|
|
TokenHash: rtHash,
|
|
ExpiresAt: utils.ToPgTimestamptz(time.Now().Add(auth.RefreshTokenDuration)),
|
|
})
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
user := userFromCreateRow(dbUser)
|
|
return &models.AuthTokens{User: user, AccessToken: accessToken, RefreshToken: refreshToken}, nil
|
|
}
|
|
|
|
func (s *AuthService) Login(ctx context.Context, email, password string) (*models.AuthTokens, error) {
|
|
email = utils.NormalizeEmail(email)
|
|
|
|
dbUser, err := s.queries.GetUserByEmail(ctx, email)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return nil, models.ErrAuthInvalid
|
|
}
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(dbUser.PasswordHash), []byte(password)); err != nil {
|
|
return nil, models.ErrAuthInvalid
|
|
}
|
|
|
|
userID := utils.FromPgUUID(dbUser.ID)
|
|
accessToken, err := s.jwt.GenerateAccessToken(userID)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
refreshToken, err := s.jwt.GenerateRefreshToken(userID)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
rtHash := hashToken(refreshToken)
|
|
_, err = s.queries.CreateRefreshToken(ctx, repository.CreateRefreshTokenParams{
|
|
ID: utils.ToPgUUID(uuid.New()),
|
|
UserID: utils.ToPgUUID(userID),
|
|
TokenHash: rtHash,
|
|
ExpiresAt: utils.ToPgTimestamptz(time.Now().Add(auth.RefreshTokenDuration)),
|
|
})
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
user := userFromEmailRow(dbUser)
|
|
return &models.AuthTokens{User: user, AccessToken: accessToken, RefreshToken: refreshToken}, nil
|
|
}
|
|
|
|
func (s *AuthService) Refresh(ctx context.Context, refreshTokenStr string) (*models.TokenPair, error) {
|
|
rtHash := hashToken(refreshTokenStr)
|
|
rt, err := s.queries.GetRefreshTokenByHash(ctx, rtHash)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return nil, models.ErrAuthInvalid
|
|
}
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
if utils.FromPgTimestamptz(rt.ExpiresAt).Before(time.Now()) {
|
|
return nil, models.ErrAuthInvalid
|
|
}
|
|
|
|
_ = s.queries.RevokeRefreshToken(ctx, rtHash)
|
|
|
|
userID := utils.FromPgUUID(rt.UserID)
|
|
accessToken, err := s.jwt.GenerateAccessToken(userID)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
newRefresh, err := s.jwt.GenerateRefreshToken(userID)
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
newHash := hashToken(newRefresh)
|
|
_, err = s.queries.CreateRefreshToken(ctx, repository.CreateRefreshTokenParams{
|
|
ID: utils.ToPgUUID(uuid.New()),
|
|
UserID: utils.ToPgUUID(userID),
|
|
TokenHash: newHash,
|
|
ExpiresAt: utils.ToPgTimestamptz(time.Now().Add(auth.RefreshTokenDuration)),
|
|
})
|
|
if err != nil {
|
|
return nil, models.ErrInternal
|
|
}
|
|
|
|
return &models.TokenPair{AccessToken: accessToken, RefreshToken: newRefresh}, nil
|
|
}
|
|
|
|
func (s *AuthService) Logout(ctx context.Context, refreshTokenStr string) error {
|
|
rtHash := hashToken(refreshTokenStr)
|
|
return s.queries.RevokeRefreshToken(ctx, rtHash)
|
|
}
|
|
|
|
// hashToken returns the lowercase hex encoding of the SHA-256 digest of token
// (64 characters). Refresh tokens are stored and looked up by this digest
// rather than in plain form.
func hashToken(token string) string {
	hasher := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	hasher.Write([]byte(token))
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
func userFromCreateRow(u repository.CreateUserRow) models.User {
|
|
return models.User{
|
|
ID: utils.FromPgUUID(u.ID),
|
|
Email: u.Email,
|
|
Timezone: u.Timezone,
|
|
CreatedAt: utils.FromPgTimestamptz(u.CreatedAt),
|
|
UpdatedAt: utils.FromPgTimestamptz(u.UpdatedAt),
|
|
}
|
|
}
|
|
|
|
func userFromEmailRow(u repository.GetUserByEmailRow) models.User {
|
|
return models.User{
|
|
ID: utils.FromPgUUID(u.ID),
|
|
Email: u.Email,
|
|
Timezone: u.Timezone,
|
|
CreatedAt: utils.FromPgTimestamptz(u.CreatedAt),
|
|
UpdatedAt: utils.FromPgTimestamptz(u.UpdatedAt),
|
|
}
|
|
}
|
|
|
|
func userFromIDRow(u repository.GetUserByIDRow) models.User {
|
|
return models.User{
|
|
ID: utils.FromPgUUID(u.ID),
|
|
Email: u.Email,
|
|
Timezone: u.Timezone,
|
|
CreatedAt: utils.FromPgTimestamptz(u.CreatedAt),
|
|
UpdatedAt: utils.FromPgTimestamptz(u.UpdatedAt),
|
|
}
|
|
}
|
|
|
|
func userFromUpdateRow(u repository.UpdateUserRow) models.User {
|
|
return models.User{
|
|
ID: utils.FromPgUUID(u.ID),
|
|
Email: u.Email,
|
|
Timezone: u.Timezone,
|
|
CreatedAt: utils.FromPgTimestamptz(u.CreatedAt),
|
|
UpdatedAt: utils.FromPgTimestamptz(u.UpdatedAt),
|
|
}
|
|
}
|