124 lines
3.5 KiB
Go
124 lines
3.5 KiB
Go
// services/auth_service_impl.go
|
||
package services
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"time"
|
||
|
||
"go-todo-api/dto"
|
||
"go-todo-api/models"
|
||
"go-todo-api/repositories"
|
||
|
||
"github.com/golang-jwt/jwt/v5"
|
||
"golang.org/x/crypto/bcrypt"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// jwtSecret is a hard-coded HMAC secret for signing and verifying JWTs.
//
// WARNING: a secret committed to source control is not secret. Prefer loading
// the real value from a config file or environment variable, and make sure
// every part of the application that signs or verifies tokens uses the same key.
const (
	jwtSecret = "aabbccdd66778899.."
)
|
||
|
||
// AuthServiceImpl 实现了 AuthService 接口
|
||
type AuthServiceImpl struct {
|
||
UserRepo repositories.UserRepository
|
||
// ❗ 核心修正:添加 JWTSecret 字段
|
||
JWTSecret string
|
||
}
|
||
|
||
// NewAuthService 创建 AuthServiceImpl 的新实例
|
||
func NewAuthService(userRepo repositories.UserRepository, secret string) *AuthServiceImpl {
|
||
return &AuthServiceImpl{UserRepo: userRepo, JWTSecret: secret}
|
||
}
|
||
|
||
// Register creates a new user account from input.
//
// It returns a "username already taken" error when FindByUsername finds an
// existing user. It returns any lookup error other than
// gorm.ErrRecordNotFound unchanged. Otherwise it stores the user with a bcrypt
// hash of the supplied password (never the plaintext) and returns the
// persisted user.
func (s *AuthServiceImpl) Register(input *dto.RegisterInput) (*models.User, error) {
	// Uniqueness check: only "record not found" means the username is free.
	switch _, lookupErr := s.UserRepo.FindByUsername(input.Username); {
	case lookupErr == nil:
		return nil, errors.New("username already taken")
	case !errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, lookupErr
	}

	digest, hashErr := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
	if hashErr != nil {
		return nil, fmt.Errorf("failed to hash password: %w", hashErr)
	}

	user := &models.User{
		Username: input.Username,
		Password: string(digest),
	}
	if createErr := s.UserRepo.Create(user); createErr != nil {
		return nil, fmt.Errorf("failed to create user: %w", createErr)
	}
	return user, nil
}
|
||
|
||
// jwtSigningKey returns the HMAC key used both to sign (Login) and to verify
// (AuthenticateToken) tokens, so the two can never disagree. It prefers the
// secret injected through NewAuthService and falls back to the package-level
// jwtSecret when none was supplied.
func (s *AuthServiceImpl) jwtSigningKey() []byte {
	if s.JWTSecret != "" {
		return []byte(s.JWTSecret)
	}
	return []byte(jwtSecret)
}

// Login authenticates the user described by input. On success it returns a
// signed HS256 JWT whose "sub" claim is the user's ID and which expires 24
// hours after it was issued.
//
// A failed lookup and a password mismatch both produce the same
// "invalid username or password" error, so callers cannot tell whether the
// username exists.
func (s *AuthServiceImpl) Login(input *dto.LoginInput) (string, error) {
	// 1. Look up the user; keep the error generic to avoid user enumeration.
	foundUser, err := s.UserRepo.FindByUsername(input.Username)
	if err != nil {
		return "", errors.New("invalid username or password")
	}

	// 2. Compare the supplied password with the stored bcrypt hash.
	if err := bcrypt.CompareHashAndPassword([]byte(foundUser.Password), []byte(input.Password)); err != nil {
		return "", errors.New("invalid username or password")
	}

	// 3. Build the token. A single timestamp keeps iat and exp consistent.
	now := time.Now()
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
		Subject:   fmt.Sprintf("%d", foundUser.ID),
		ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)),
		IssuedAt:  jwt.NewNumericDate(now),
	})

	// 4. Sign with the same key AuthenticateToken verifies with.
	tokenString, err := token.SignedString(s.jwtSigningKey())
	if err != nil {
		return "", fmt.Errorf("failed to generate token: %w", err)
	}
	return tokenString, nil
}

// AuthenticateToken verifies tokenString and returns the user ID stored in its
// "sub" claim.
//
// The token must be signed with HS256 using the service's key, must pass the
// library's standard claim validation (e.g. not expired), and must carry a
// non-empty subject. Otherwise an error is returned.
func (s *AuthServiceImpl) AuthenticateToken(tokenString string) (string, error) {
	key := s.jwtSigningKey()
	token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
		// Reject non-HMAC algorithms (e.g. RSA public-key confusion attacks).
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return key, nil
	}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) // Login only issues HS256.

	if err != nil || !token.Valid {
		return "", errors.New("invalid or expired token")
	}

	// jwt.Parse decodes claims into jwt.MapClaims.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return "", errors.New("invalid token claims")
	}

	// GetSubject returns ("", nil) when "sub" is absent, so emptiness must be
	// checked explicitly; otherwise a subject-less token would yield user "".
	sub, err := claims.GetSubject()
	if err != nil || sub == "" {
		return "", errors.New("missing subject in token claims")
	}
	return sub, nil // user ID as a string
}
|