219 lines
6.0 KiB
Go
219 lines
6.0 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"time"
|
|
|
|
"cockpit/internal/config"
|
|
"cockpit/internal/domain"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type Service struct {
|
|
cfg *config.Config
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewService(cfg *config.Config, db *gorm.DB) *Service {
|
|
return &Service{cfg: cfg, db: db}
|
|
}
|
|
|
|
type AccessClaims struct {
|
|
UserID uint64 `json:"uid"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
type RefreshClaims struct {
|
|
UserID uint64 `json:"uid"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
func (s *Service) Login(ctx context.Context, username, password string) (accessToken, refreshToken string, user domain.User, permCodes []string, err error) {
|
|
if err := s.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
|
return "", "", domain.User{}, nil, errors.New("用户名或密码错误")
|
|
}
|
|
if !user.Enabled {
|
|
return "", "", domain.User{}, nil, errors.New("账号已禁用")
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
|
|
return "", "", domain.User{}, nil, errors.New("用户名或密码错误")
|
|
}
|
|
|
|
accessToken, err = s.NewAccessToken(user.ID)
|
|
if err != nil {
|
|
return "", "", domain.User{}, nil, err
|
|
}
|
|
refreshToken, err = s.NewRefreshToken(user.ID)
|
|
if err != nil {
|
|
return "", "", domain.User{}, nil, err
|
|
}
|
|
if err := s.saveRefreshToken(ctx, user.ID, refreshToken); err != nil {
|
|
return "", "", domain.User{}, nil, err
|
|
}
|
|
|
|
perms, err := s.GetUserPermCodes(ctx, user.ID)
|
|
if err != nil {
|
|
return "", "", domain.User{}, nil, err
|
|
}
|
|
return accessToken, refreshToken, user, perms, nil
|
|
}
|
|
|
|
func (s *Service) Refresh(ctx context.Context, refreshToken string) (newAccess, newRefresh string, err error) {
|
|
claims, err := s.ParseRefreshToken(refreshToken)
|
|
if err != nil {
|
|
return "", "", errors.New("refresh token 无效")
|
|
}
|
|
|
|
// 校验 DB 中是否存在且未撤销
|
|
hash := hashToken(refreshToken)
|
|
var rt domain.RefreshToken
|
|
if err := s.db.WithContext(ctx).Where("token_hash = ?", hash).First(&rt).Error; err != nil {
|
|
return "", "", errors.New("refresh token 已失效")
|
|
}
|
|
if rt.RevokedAt != nil || time.Now().After(rt.ExpiresAt) {
|
|
return "", "", errors.New("refresh token 已失效")
|
|
}
|
|
|
|
// 旋转刷新:撤销旧 token
|
|
now := time.Now()
|
|
_ = s.db.WithContext(ctx).Model(&domain.RefreshToken{}).
|
|
Where("id = ?", rt.ID).
|
|
Update("revoked_at", now).Error
|
|
|
|
newAccess, err = s.NewAccessToken(claims.UserID)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
newRefresh, err = s.NewRefreshToken(claims.UserID)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
if err := s.saveRefreshToken(ctx, claims.UserID, newRefresh); err != nil {
|
|
return "", "", err
|
|
}
|
|
return newAccess, newRefresh, nil
|
|
}
|
|
|
|
func (s *Service) Logout(ctx context.Context, refreshToken string) error {
|
|
hash := hashToken(refreshToken)
|
|
now := time.Now()
|
|
return s.db.WithContext(ctx).Model(&domain.RefreshToken{}).
|
|
Where("token_hash = ? AND revoked_at IS NULL", hash).
|
|
Update("revoked_at", now).Error
|
|
}
|
|
|
|
func (s *Service) NewAccessToken(userID uint64) (string, error) {
|
|
now := time.Now()
|
|
claims := AccessClaims{
|
|
UserID: userID,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: "cockpit",
|
|
Subject: "access",
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.cfg.Auth.AccessTokenTTL)),
|
|
},
|
|
}
|
|
t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return t.SignedString([]byte(s.cfg.Auth.AccessTokenSecret))
|
|
}
|
|
|
|
func (s *Service) NewRefreshToken(userID uint64) (string, error) {
|
|
now := time.Now()
|
|
jti, _ := randomHex(16)
|
|
claims := RefreshClaims{
|
|
UserID: userID,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: "cockpit",
|
|
Subject: "refresh",
|
|
ID: jti,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.cfg.Auth.RefreshTokenTTL)),
|
|
},
|
|
}
|
|
t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return t.SignedString([]byte(s.cfg.Auth.RefreshTokenSecret))
|
|
}
|
|
|
|
func (s *Service) ParseAccessToken(token string) (*AccessClaims, error) {
|
|
t, err := jwt.ParseWithClaims(token, &AccessClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(s.cfg.Auth.AccessTokenSecret), nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if claims, ok := t.Claims.(*AccessClaims); ok && t.Valid {
|
|
return claims, nil
|
|
}
|
|
return nil, errors.New("invalid token")
|
|
}
|
|
|
|
func (s *Service) ParseRefreshToken(token string) (*RefreshClaims, error) {
|
|
t, err := jwt.ParseWithClaims(token, &RefreshClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(s.cfg.Auth.RefreshTokenSecret), nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if claims, ok := t.Claims.(*RefreshClaims); ok && t.Valid {
|
|
return claims, nil
|
|
}
|
|
return nil, errors.New("invalid token")
|
|
}
|
|
|
|
func (s *Service) GetUserPermCodes(ctx context.Context, userID uint64) ([]string, error) {
|
|
// join user_roles -> role_permissions -> permissions
|
|
type row struct{ Code string }
|
|
var rows []row
|
|
err := s.db.WithContext(ctx).
|
|
Table("permissions").
|
|
Select("permissions.code as code").
|
|
Joins("JOIN role_permissions ON role_permissions.permission_id = permissions.id").
|
|
Joins("JOIN user_roles ON user_roles.role_id = role_permissions.role_id").
|
|
Where("user_roles.user_id = ?", userID).
|
|
Group("permissions.code").
|
|
Scan(&rows).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out := make([]string, 0, len(rows))
|
|
for _, r := range rows {
|
|
out = append(out, r.Code)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Service) saveRefreshToken(ctx context.Context, userID uint64, refreshToken string) error {
|
|
hash := hashToken(refreshToken)
|
|
claims, err := s.ParseRefreshToken(refreshToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rt := domain.RefreshToken{
|
|
UserID: userID,
|
|
TokenHash: hash,
|
|
ExpiresAt: claims.ExpiresAt.Time,
|
|
}
|
|
return s.db.WithContext(ctx).Create(&rt).Error
|
|
}
|
|
|
|
// hashToken returns the lowercase hex-encoded SHA-256 digest of token.
// Refresh tokens are stored and looked up by this digest rather than in
// plaintext.
func hashToken(token string) string {
	h := sha256.New()
	h.Write([]byte(token)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(h.Sum(nil))
}
|
|
|
|
// randomHex reads nBytes bytes from crypto/rand and returns them hex-encoded,
// so the result has 2*nBytes lowercase hex characters. Any error from the
// random source is returned with an empty string.
func randomHex(nBytes int) (string, error) {
	raw := make([]byte, nBytes)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
|
|
|