// Package auth implements username/password login and the issuing, parsing,
// rotation and revocation of JWT access and refresh tokens.
package auth

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"cockpit/internal/config"
	"cockpit/internal/domain"
	"github.com/golang-jwt/jwt/v5"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// Service bundles the configuration and database handle used by every
// authentication operation in this package.
type Service struct {
	cfg *config.Config
	db  *gorm.DB
}

// NewService returns a Service that uses cfg for token secrets/TTLs and db
// for user, permission and refresh-token storage.
func NewService(cfg *config.Config, db *gorm.DB) *Service {
	return &Service{cfg: cfg, db: db}
}

// AccessClaims is the JWT payload of an access token.
type AccessClaims struct {
	UserID uint64 `json:"uid"`
	jwt.RegisteredClaims
}

// RefreshClaims is the JWT payload of a refresh token.
type RefreshClaims struct {
	UserID uint64 `json:"uid"`
	jwt.RegisteredClaims
}

// Login verifies username/password and, on success, returns a fresh access
// token, a fresh refresh token (whose hash is stored in the database), the
// user record and the user's permission codes.
//
// An unknown username and a wrong password produce the same error message.
// On any error all other results are zero values.
func (s *Service) Login(ctx context.Context, username, password string) (accessToken, refreshToken string, user domain.User, permCodes []string, err error) {
	if err := s.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
		return "", "", domain.User{}, nil, errors.New("用户名或密码错误")
	}
	if !user.Enabled {
		return "", "", domain.User{}, nil, errors.New("账号已禁用")
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
		return "", "", domain.User{}, nil, errors.New("用户名或密码错误")
	}

	// Permissions are loaded before any token is persisted so that a failure
	// here does not leave an unused refresh-token row behind.
	perms, err := s.GetUserPermCodes(ctx, user.ID)
	if err != nil {
		return "", "", domain.User{}, nil, err
	}

	accessToken, err = s.NewAccessToken(user.ID)
	if err != nil {
		return "", "", domain.User{}, nil, err
	}
	refreshToken, err = s.NewRefreshToken(user.ID)
	if err != nil {
		return "", "", domain.User{}, nil, err
	}
	if err := s.saveRefreshToken(ctx, user.ID, refreshToken); err != nil {
		return "", "", domain.User{}, nil, err
	}
	return accessToken, refreshToken, user, perms, nil
}

// Refresh exchanges a valid refresh token for a new access/refresh pair and
// revokes the presented token (refresh-token rotation).
//
// The presented token must carry a valid signature, have a stored hash, be
// unrevoked and unexpired. Revoking the old row and inserting the new one
// happen in a single transaction, and the revoke only matches rows that are
// still unrevoked, so a token replayed concurrently can succeed at most once.
//
// NOTE(review): this does not re-check whether the user is still enabled;
// confirm that disabling a user revokes their refresh tokens elsewhere.
func (s *Service) Refresh(ctx context.Context, refreshToken string) (newAccess, newRefresh string, err error) {
	claims, err := s.ParseRefreshToken(refreshToken)
	if err != nil {
		return "", "", errors.New("refresh token 无效")
	}

	// Verify that the token is known to the database and has not been revoked.
	var rt domain.RefreshToken
	if err := s.db.WithContext(ctx).Where("token_hash = ?", hashToken(refreshToken)).First(&rt).Error; err != nil {
		return "", "", errors.New("refresh token 已失效")
	}
	now := time.Now()
	if rt.RevokedAt != nil || now.After(rt.ExpiresAt) {
		return "", "", errors.New("refresh token 已失效")
	}

	// Mint the replacement pair before touching the database so the
	// transaction below contains nothing but the two writes.
	newAccess, err = s.NewAccessToken(claims.UserID)
	if err != nil {
		return "", "", err
	}
	newRefresh, err = s.NewRefreshToken(claims.UserID)
	if err != nil {
		return "", "", err
	}
	record, err := s.refreshTokenRecord(claims.UserID, newRefresh)
	if err != nil {
		return "", "", err
	}

	// Rotation: revoke the old token and store the new one atomically.
	err = s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		res := tx.Model(&domain.RefreshToken{}).
			Where("id = ? AND revoked_at IS NULL", rt.ID).
			Update("revoked_at", now)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			// Another request revoked this token after the check above.
			return errors.New("refresh token 已失效")
		}
		return tx.Create(&record).Error
	})
	if err != nil {
		return "", "", err
	}
	return newAccess, newRefresh, nil
}

// Logout marks the stored row for refreshToken as revoked. Only rows that are
// not already revoked are updated; an unknown or already-revoked token is not
// reported as an error.
func (s *Service) Logout(ctx context.Context, refreshToken string) error {
	hash := hashToken(refreshToken)
	now := time.Now()
	return s.db.WithContext(ctx).Model(&domain.RefreshToken{}).
		Where("token_hash = ? AND revoked_at IS NULL", hash).
		Update("revoked_at", now).Error
}

// NewAccessToken returns an HS256-signed access token for userID that expires
// after the configured access-token TTL.
func (s *Service) NewAccessToken(userID uint64) (string, error) {
	now := time.Now()
	claims := AccessClaims{
		UserID: userID,
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    "cockpit",
			Subject:   "access",
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(s.cfg.Auth.AccessTokenTTL)),
		},
	}
	t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return t.SignedString([]byte(s.cfg.Auth.AccessTokenSecret))
}

// NewRefreshToken returns an HS256-signed refresh token for userID that
// expires after the configured refresh-token TTL. Each token carries a random
// 16-byte hex ID so that two tokens minted in the same second still differ.
// It fails if the random ID cannot be generated or the token cannot be signed.
func (s *Service) NewRefreshToken(userID uint64) (string, error) {
	jti, err := randomHex(16)
	if err != nil {
		return "", fmt.Errorf("generating refresh token id: %w", err)
	}
	now := time.Now()
	claims := RefreshClaims{
		UserID: userID,
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    "cockpit",
			Subject:   "refresh",
			ID:        jti,
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(s.cfg.Auth.RefreshTokenTTL)),
		},
	}
	t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return t.SignedString([]byte(s.cfg.Auth.RefreshTokenSecret))
}

// ParseAccessToken verifies token against the access-token secret and returns
// its claims. Only HS256-signed tokens are accepted.
func (s *Service) ParseAccessToken(token string) (*AccessClaims, error) {
	t, err := jwt.ParseWithClaims(
		token,
		&AccessClaims{},
		func(*jwt.Token) (any, error) {
			return []byte(s.cfg.Auth.AccessTokenSecret), nil
		},
		// Pin the algorithm to the one used when signing.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil {
		return nil, err
	}
	if claims, ok := t.Claims.(*AccessClaims); ok && t.Valid {
		return claims, nil
	}
	return nil, errors.New("invalid token")
}

// ParseRefreshToken verifies token against the refresh-token secret and
// returns its claims. Only HS256-signed tokens are accepted. It does not
// consult the database; revocation is checked by Refresh.
func (s *Service) ParseRefreshToken(token string) (*RefreshClaims, error) {
	t, err := jwt.ParseWithClaims(
		token,
		&RefreshClaims{},
		func(*jwt.Token) (any, error) {
			return []byte(s.cfg.Auth.RefreshTokenSecret), nil
		},
		// Pin the algorithm to the one used when signing.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil {
		return nil, err
	}
	if claims, ok := t.Claims.(*RefreshClaims); ok && t.Valid {
		return claims, nil
	}
	return nil, errors.New("invalid token")
}

// GetUserPermCodes returns the distinct permission codes granted to userID
// through its roles (user_roles -> role_permissions -> permissions). The
// result is a non-nil, possibly empty slice; its order is whatever the
// database returns.
func (s *Service) GetUserPermCodes(ctx context.Context, userID uint64) ([]string, error) {
	type row struct{ Code string }

	var rows []row
	err := s.db.WithContext(ctx).
		Table("permissions").
		Select("permissions.code as code").
		Joins("JOIN role_permissions ON role_permissions.permission_id = permissions.id").
		Joins("JOIN user_roles ON user_roles.role_id = role_permissions.role_id").
		Where("user_roles.user_id = ?", userID).
		Group("permissions.code"). // collapses codes reachable via several roles
		Scan(&rows).Error
	if err != nil {
		return nil, err
	}

	out := make([]string, 0, len(rows))
	for _, r := range rows {
		out = append(out, r.Code)
	}
	return out, nil
}

// saveRefreshToken stores the hash and expiry of refreshToken for userID.
func (s *Service) saveRefreshToken(ctx context.Context, userID uint64, refreshToken string) error {
	record, err := s.refreshTokenRecord(userID, refreshToken)
	if err != nil {
		return err
	}
	return s.db.WithContext(ctx).Create(&record).Error
}

// refreshTokenRecord builds the database row for refreshToken: the SHA-256
// hash of the token plus the expiry read from its own claims. It fails if the
// token does not parse or carries no expiry.
func (s *Service) refreshTokenRecord(userID uint64, refreshToken string) (domain.RefreshToken, error) {
	claims, err := s.ParseRefreshToken(refreshToken)
	if err != nil {
		return domain.RefreshToken{}, err
	}
	if claims.ExpiresAt == nil {
		return domain.RefreshToken{}, errors.New("invalid token")
	}
	return domain.RefreshToken{
		UserID:    userID,
		TokenHash: hashToken(refreshToken),
		ExpiresAt: claims.ExpiresAt.Time,
	}, nil
}

// hashToken returns the lowercase hex SHA-256 digest of token. Only this
// digest, never the token itself, is written to the database.
func hashToken(token string) string {
	sum := sha256.Sum256([]byte(token))
	return hex.EncodeToString(sum[:])
}

// randomHex returns nBytes of crypto/rand output encoded as a hex string of
// length 2*nBytes.
func randomHex(nBytes int) (string, error) {
	b := make([]byte, nBytes)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return hex.EncodeToString(b), nil
}