cockpit-source/backend/internal/auth/service.go
2026-04-02 14:12:43 +08:00

219 lines
6.0 KiB
Go

package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"time"
"cockpit/internal/config"
"cockpit/internal/domain"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// Service implements authentication for cockpit: password login, signing and
// parsing of JWT access/refresh tokens, refresh-token rotation and revocation,
// and permission-code lookup.
type Service struct {
	cfg *config.Config // token secrets and TTLs are read from cfg.Auth
	db  *gorm.DB       // users, refresh tokens and role/permission tables
}

// NewService returns a Service that uses cfg and db. Neither argument is
// copied or validated.
func NewService(cfg *config.Config, db *gorm.DB) *Service {
	svc := new(Service)
	svc.cfg = cfg
	svc.db = db
	return svc
}
type (
	// AccessClaims is the payload of an access token: the user ID under the
	// "uid" JSON key plus the standard JWT registered claims.
	AccessClaims struct {
		UserID uint64 `json:"uid"`
		jwt.RegisteredClaims
	}

	// RefreshClaims is the payload of a refresh token. It has the same shape
	// as AccessClaims; refresh tokens are signed with the refresh-token secret.
	RefreshClaims struct {
		UserID uint64 `json:"uid"`
		jwt.RegisteredClaims
	}
)
func (s *Service) Login(ctx context.Context, username, password string) (accessToken, refreshToken string, user domain.User, permCodes []string, err error) {
if err := s.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
return "", "", domain.User{}, nil, errors.New("用户名或密码错误")
}
if !user.Enabled {
return "", "", domain.User{}, nil, errors.New("账号已禁用")
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return "", "", domain.User{}, nil, errors.New("用户名或密码错误")
}
accessToken, err = s.NewAccessToken(user.ID)
if err != nil {
return "", "", domain.User{}, nil, err
}
refreshToken, err = s.NewRefreshToken(user.ID)
if err != nil {
return "", "", domain.User{}, nil, err
}
if err := s.saveRefreshToken(ctx, user.ID, refreshToken); err != nil {
return "", "", domain.User{}, nil, err
}
perms, err := s.GetUserPermCodes(ctx, user.ID)
if err != nil {
return "", "", domain.User{}, nil, err
}
return accessToken, refreshToken, user, perms, nil
}
// Refresh exchanges a refresh token for a new access/refresh token pair. It
// revokes the presented token as part of the exchange (rotation), so each
// refresh token can be redeemed at most once.
//
// The presented token must:
//   - verify via ParseRefreshToken;
//   - exist among the stored refresh tokens, looked up by hash;
//   - be unrevoked and unexpired;
//   - belong to a user that still exists and is enabled.
//
// Database failures other than "record not found" are returned as-is instead
// of being reported as an invalid token.
//
// Revocation is a conditional UPDATE guarded by revoked_at IS NULL. If two
// requests redeem the same token concurrently, only one affects a row; the
// other fails with "refresh token 已失效" instead of minting a second pair.
// The revoke and the insert of the new token are not in one transaction. If
// the insert fails, the old token is already revoked and the client has to
// log in again.
func (s *Service) Refresh(ctx context.Context, refreshToken string) (newAccess, newRefresh string, err error) {
	claims, err := s.ParseRefreshToken(refreshToken)
	if err != nil {
		return "", "", errors.New("refresh token 无效")
	}

	// The token must be on record and neither revoked nor expired.
	var rt domain.RefreshToken
	if err := s.db.WithContext(ctx).Where("token_hash = ?", hashToken(refreshToken)).First(&rt).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return "", "", errors.New("refresh token 已失效")
		}
		return "", "", err
	}
	now := time.Now()
	if rt.RevokedAt != nil || now.After(rt.ExpiresAt) {
		return "", "", errors.New("refresh token 已失效")
	}

	// Mirror Login: a disabled or deleted account must not keep minting tokens.
	var user domain.User
	if err := s.db.WithContext(ctx).First(&user, claims.UserID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return "", "", errors.New("refresh token 已失效")
		}
		return "", "", err
	}
	if !user.Enabled {
		return "", "", errors.New("账号已禁用")
	}

	// Sign the new pair before touching the DB, so a signing failure leaves
	// the presented token usable.
	newAccess, err = s.NewAccessToken(claims.UserID)
	if err != nil {
		return "", "", err
	}
	newRefresh, err = s.NewRefreshToken(claims.UserID)
	if err != nil {
		return "", "", err
	}

	// Rotation: revoke the presented token with compare-and-set semantics.
	res := s.db.WithContext(ctx).Model(&domain.RefreshToken{}).
		Where("id = ? AND revoked_at IS NULL", rt.ID).
		Update("revoked_at", now)
	if res.Error != nil {
		return "", "", res.Error
	}
	if res.RowsAffected == 0 {
		// Another request redeemed or revoked this token after our read.
		return "", "", errors.New("refresh token 已失效")
	}

	if err := s.saveRefreshToken(ctx, claims.UserID, newRefresh); err != nil {
		return "", "", err
	}
	return newAccess, newRefresh, nil
}
// Logout revokes the given refresh token by setting revoked_at on its stored
// record, matched by token hash. Unknown or already-revoked tokens match no
// row and the call returns nil; only a database error is reported.
func (s *Service) Logout(ctx context.Context, refreshToken string) error {
	res := s.db.WithContext(ctx).
		Model(&domain.RefreshToken{}).
		Where("token_hash = ? AND revoked_at IS NULL", hashToken(refreshToken)).
		Update("revoked_at", time.Now())
	return res.Error
}
// NewAccessToken signs an HS256 access token for userID with the access-token
// secret. The token carries:
//   - uid = userID;
//   - iss = "cockpit";
//   - sub = "access";
//   - iat = now;
//   - exp = now + AccessTokenTTL.
func (s *Service) NewAccessToken(userID uint64) (string, error) {
	issued := time.Now()
	expires := issued.Add(s.cfg.Auth.AccessTokenTTL)

	var claims AccessClaims
	claims.UserID = userID
	claims.Issuer = "cockpit"
	claims.Subject = "access"
	claims.IssuedAt = jwt.NewNumericDate(issued)
	claims.ExpiresAt = jwt.NewNumericDate(expires)

	secret := []byte(s.cfg.Auth.AccessTokenSecret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secret)
}
// NewRefreshToken signs an HS256 refresh token for userID with the
// refresh-token secret. The token carries:
//   - uid = userID;
//   - iss = "cockpit";
//   - sub = "refresh";
//   - jti = a random 128-bit hex string;
//   - iat = now;
//   - exp = now + RefreshTokenTTL.
//
// The jti keeps two tokens minted for the same user in the same second
// distinct, and therefore their stored token hashes distinct. This matters
// because JWT NumericDate has one-second precision by default. A crypto/rand
// failure is returned instead of silently issuing a token with an empty jti.
func (s *Service) NewRefreshToken(userID uint64) (string, error) {
	jti, err := randomHex(16)
	if err != nil {
		return "", err
	}
	now := time.Now()
	claims := RefreshClaims{
		UserID: userID,
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    "cockpit",
			Subject:   "refresh",
			ID:        jti,
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(s.cfg.Auth.RefreshTokenTTL)),
		},
	}
	t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return t.SignedString([]byte(s.cfg.Auth.RefreshTokenSecret))
}
// ParseAccessToken verifies an access token and returns its claims.
//
// Besides the HMAC signature (access-token secret) and the time-based claims,
// it requires:
//   - alg = HS256;
//   - iss = "cockpit";
//   - sub = "access";
//   - a present exp.
//
// The sub check matters. NewAccessToken stamps "access" and NewRefreshToken
// stamps "refresh". Without the check, a refresh token would be accepted here
// whenever the two secrets are configured to the same value.
func (s *Service) ParseAccessToken(token string) (*AccessClaims, error) {
	claims := &AccessClaims{}
	keyFunc := func(*jwt.Token) (interface{}, error) {
		return []byte(s.cfg.Auth.AccessTokenSecret), nil
	}
	t, err := jwt.ParseWithClaims(token, claims, keyFunc,
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		jwt.WithIssuer("cockpit"),
		jwt.WithSubject("access"),
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}
	if !t.Valid {
		return nil, errors.New("invalid token")
	}
	return claims, nil
}
// ParseRefreshToken verifies a refresh token and returns its claims.
//
// Besides the HMAC signature (refresh-token secret) and the time-based claims,
// it requires:
//   - alg = HS256;
//   - iss = "cockpit";
//   - sub = "refresh", so an access token can never be redeemed here even if
//     both secrets are configured identically;
//   - a present exp.
//
// Requiring exp also guarantees claims.ExpiresAt is non-nil for callers that
// dereference it, such as saveRefreshToken.
func (s *Service) ParseRefreshToken(token string) (*RefreshClaims, error) {
	claims := &RefreshClaims{}
	keyFunc := func(*jwt.Token) (interface{}, error) {
		return []byte(s.cfg.Auth.RefreshTokenSecret), nil
	}
	t, err := jwt.ParseWithClaims(token, claims, keyFunc,
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		jwt.WithIssuer("cockpit"),
		jwt.WithSubject("refresh"),
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}
	if !t.Valid {
		return nil, errors.New("invalid token")
	}
	return claims, nil
}
// GetUserPermCodes returns the permission codes granted to userID through its
// roles. It joins user_roles -> role_permissions -> permissions and groups by
// code, so each code appears once. The result is never nil: a user without
// permissions yields an empty slice.
func (s *Service) GetUserPermCodes(ctx context.Context, userID uint64) ([]string, error) {
	type permRow struct{ Code string }

	query := s.db.WithContext(ctx).
		Table("permissions").
		Select("permissions.code as code").
		Joins("JOIN role_permissions ON role_permissions.permission_id = permissions.id").
		Joins("JOIN user_roles ON user_roles.role_id = role_permissions.role_id").
		Where("user_roles.user_id = ?", userID).
		Group("permissions.code")

	var found []permRow
	if err := query.Scan(&found).Error; err != nil {
		return nil, err
	}

	codes := make([]string, len(found))
	for i := range found {
		codes[i] = found[i].Code
	}
	return codes, nil
}
// saveRefreshToken persists a record for refreshToken owned by userID. Only
// the SHA-256 hash of the token is stored, never the token itself. The
// record's expiry is taken from the token's own exp claim, so the token is
// parsed (and thus verified) first; a parse failure is returned and nothing
// is written.
func (s *Service) saveRefreshToken(ctx context.Context, userID uint64, refreshToken string) error {
	parsed, err := s.ParseRefreshToken(refreshToken)
	if err != nil {
		return err
	}
	record := &domain.RefreshToken{
		UserID:    userID,
		TokenHash: hashToken(refreshToken),
		ExpiresAt: parsed.ExpiresAt.Time,
	}
	return s.db.WithContext(ctx).Create(record).Error
}
// hashToken returns the lowercase hex SHA-256 digest of token. Refresh tokens
// are stored and looked up by this digest rather than in plain text.
func hashToken(token string) string {
	h := sha256.New()
	h.Write([]byte(token)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(h.Sum(nil))
}
// randomHex reads nBytes from crypto/rand and returns them hex-encoded, so the
// result is 2*nBytes lowercase hex characters. If the randomness source
// fails, it returns "" and the error.
func randomHex(nBytes int) (string, error) {
	buf := make([]byte, nBytes)
	_, err := rand.Read(buf)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(buf), nil
}