267 lines
8.2 KiB
Go
267 lines
8.2 KiB
Go
package main
|
||
|
||
import (
|
||
"archive/zip" // 保持,因为 handleUpload 使用
|
||
"fmt"
|
||
"go-todo-api/constants"
|
||
"go-todo-api/handlers"
|
||
"go-todo-api/models"
|
||
"go-todo-api/repositories"
|
||
"go-todo-api/services"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/joho/godotenv" // 保持,用于加载配置
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/gorm"
|
||
// Go 标准库
|
||
)
|
||
|
||
// DB 是全局数据库连接变量
|
||
var DB *gorm.DB
|
||
|
||
// ❗ 新增:在 init() 中加载 .env 文件
|
||
func init() {
|
||
if err := godotenv.Load(); err != nil {
|
||
log.Println("Note: No .env file found, relying on environment variables.")
|
||
}
|
||
}
|
||
|
||
// 修正 connectDatabase 函数,使其读取环境变量
|
||
func connectDatabase() {
|
||
// ❗ 核心修正:使用 os.Getenv 获取配置
|
||
dbHost := os.Getenv("DB_HOST")
|
||
dbPort := os.Getenv("DB_PORT")
|
||
dbUser := os.Getenv("DB_USER")
|
||
dbPassword := os.Getenv("DB_PASSWORD")
|
||
dbName := os.Getenv("DB_NAME")
|
||
|
||
// 检查关键配置(如果未设置,则使用 os.Getenv 的默认空字符串,这会导致连接失败,但这是期望的安全行为)
|
||
if dbHost == "" || dbUser == "" || dbPassword == "" || dbName == "" {
|
||
log.Fatal("FATAL: Database environment variables (DB_HOST, DB_USER, etc.) are not fully set.")
|
||
}
|
||
|
||
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||
dbHost, dbUser, dbPassword, dbName, dbPort)
|
||
|
||
database, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
|
||
if err != nil {
|
||
log.Fatal("Failed to connect to PostgreSQL database! \n", err)
|
||
}
|
||
// ⭐ 核心步骤:自动迁移
|
||
err = database.AutoMigrate(&models.Todo{}, &models.User{})
|
||
if err != nil {
|
||
log.Fatal("Failed to auto-migrate database schema! \n", err)
|
||
}
|
||
|
||
DB = database
|
||
log.Println("PostgreSQL connection successful!") // 保持成功提示
|
||
}
|
||
|
||
// ❗ 修改 AuthRequired 函数签名,使其接受 AuthService 接口
|
||
func AuthRequired(authService services.AuthService) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 1. 从 Authorization 请求头中提取令牌
|
||
tokenString := c.GetHeader("Authorization")
|
||
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
|
||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing or invalid token format"})
|
||
c.Abort() // 阻止请求继续执行
|
||
return
|
||
}
|
||
|
||
// 移除 "Bearer " 前缀
|
||
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
|
||
|
||
// 2. ⭐ 调用 Service 层验证令牌
|
||
// Service 层会处理解析、验证签名和过期时间
|
||
subject, err := authService.AuthenticateToken(tokenString) // ❗ 必须在这里声明 subject
|
||
// 3. 处理解析错误 (如签名不匹配、令牌过期等)
|
||
if err != nil {
|
||
// Service 层的错误信息已经很通用了
|
||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
// 4. ⭐ 核心修改:将用户 ID 存储到 Gin Context 中
|
||
c.Set(constants.UserIDKey, subject) // ❗ 使用 constants.UserIDKey
|
||
|
||
// 5. 令牌有效,继续执行下一个 Handler
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// unzipFile 将 ZIP 文件解压到目标目录
|
||
func unzipFile(src, dest string) error {
|
||
r, err := zip.OpenReader(src)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer r.Close()
|
||
|
||
for _, f := range r.File {
|
||
// 1. 构建目标路径,确保文件不会逃逸到目标目录之外(安全检查)
|
||
fpath := filepath.Join(dest, f.Name)
|
||
if !filepath.HasPrefix(fpath, filepath.Clean(dest)+string(os.PathSeparator)) {
|
||
return fmt.Errorf("非法文件路径: %s", fpath)
|
||
}
|
||
|
||
// 2. 如果是目录,创建目录
|
||
if f.FileInfo().IsDir() {
|
||
os.MkdirAll(fpath, os.ModePerm)
|
||
continue
|
||
}
|
||
|
||
// 3. 打开文件进行读写
|
||
if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
|
||
return err
|
||
}
|
||
|
||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
rc, err := f.Open()
|
||
if err != nil {
|
||
outFile.Close()
|
||
return err
|
||
}
|
||
|
||
// 4. 复制文件内容
|
||
_, err = io.Copy(outFile, rc)
|
||
|
||
// 确保关闭文件
|
||
outFile.Close()
|
||
rc.Close()
|
||
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// handleUpload 是处理文件上传的函数
|
||
func handleUpload(c *gin.Context) {
|
||
// 1. 获取文本字段 (html_path)
|
||
htmlPath := c.PostForm("html_path")
|
||
if htmlPath == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing html_path field"})
|
||
return
|
||
}
|
||
|
||
timeDict := c.PostForm("time_dict")
|
||
tsNow := time.Now().Format("20060102150405")
|
||
if timeDict == "" {
|
||
timeDict = tsNow
|
||
}
|
||
|
||
// 2. 获取上传的文件 (zip_file)
|
||
file, err := c.FormFile("zip_file")
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing zip_file or failed to get file"})
|
||
return
|
||
}
|
||
// 目标解压目录(以当前时间戳或传入 time_dict 作为子目录)
|
||
destDir := filepath.Join("D:\\test-html-page", timeDict)
|
||
if err := os.MkdirAll(destDir, os.ModePerm); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create dest dir: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 1. 将 ZIP 文件保存到服务器的临时路径 (注意:我们先保存到临时目录)
|
||
tempZipPath := filepath.Join(os.TempDir(), file.Filename)
|
||
if err := c.SaveUploadedFile(file, tempZipPath); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save temp zip file: " + err.Error()})
|
||
return
|
||
}
|
||
// 确保函数结束时删除临时文件
|
||
defer os.Remove(tempZipPath)
|
||
|
||
// 2. ⭐ 调用解压函数
|
||
if err := unzipFile(tempZipPath, destDir); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to unzip file: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 3. 构建用户最终访问的 URL
|
||
// 为了简化,我们只返回访问路径。完整的访问 URL 需要下一步的 Gin 配置
|
||
accessPath := filepath.ToSlash(filepath.Join("/html", timeDict, htmlPath))
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "HTML files extracted and ready to serve.",
|
||
"final_access_url": accessPath,
|
||
})
|
||
}
|
||
|
||
// main 函数是 Go 程序的入口
|
||
func main() {
|
||
// 1. 切换到 Release Mode (生产模式)
|
||
// 这样做可以禁用调试输出,并优化性能。
|
||
gin.SetMode(gin.ReleaseMode)
|
||
|
||
connectDatabase()
|
||
|
||
// 确保 JWT Secret 的检查和获取逻辑不变
|
||
jwtSecret := os.Getenv("JWT_SECRET")
|
||
if jwtSecret == "" {
|
||
log.Fatal("FATAL: JWT_SECRET environment variable is not set. Please check your .env file.")
|
||
}
|
||
|
||
router := gin.Default()
|
||
|
||
// 2. 解决 "You trusted all proxies" 安全警告
|
||
// 当你的应用部署在 Nginx 或负载均衡器后面时,Gin 需要知道哪些 IP 是安全的代理。
|
||
// 以下配置信任了 Loopback 和所有私有网络范围(RFC1918),这是云部署的常见安全实践。
|
||
// 如果你知道你的代理 IP,使用更严格的配置更好。
|
||
router.SetTrustedProxies([]string{
|
||
"127.0.0.1",
|
||
"::1",
|
||
"10.0.0.0/8", // 私有网络 A 类
|
||
"172.16.0.0/12", // 私有网络 B 类
|
||
"192.168.0.0/16", // 私有网络 C 类
|
||
})
|
||
// --- 依赖注入 (DI) ---
|
||
todoRepo := repositories.NewGormTodoRepository(DB)
|
||
todoService := services.NewTodoService(todoRepo)
|
||
todoHandler := handlers.NewTodoHandler(todoService)
|
||
|
||
userRepo := repositories.NewGormUserRepository(DB)
|
||
authService := services.NewAuthService(userRepo, jwtSecret)
|
||
authHandler := handlers.NewAuthHandler(authService)
|
||
|
||
// 注册公开路由
|
||
router.GET("/hello", func(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{"message": "Hello from Gin, ready for Vue!"})
|
||
})
|
||
|
||
router.Static("/html", "D:/test-html-page")
|
||
router.POST("/register", authHandler.RegisterHandler)
|
||
router.POST("/login", authHandler.LoginHandler)
|
||
|
||
// 注册需要 AuthRequired 中间件的路由
|
||
authMiddleware := AuthRequired(authService)
|
||
|
||
router.POST("/upload", authMiddleware, handleUpload)
|
||
|
||
// Todo 受保护的 CRUD 路由
|
||
router.POST("/todos", authMiddleware, todoHandler.CreateTodoHandler)
|
||
router.GET("/todos", authMiddleware, todoHandler.FindAllTodosHandler)
|
||
router.GET("/todos/:id", authMiddleware, todoHandler.FindTodoByIDHandler)
|
||
router.PATCH("/todos/:id", authMiddleware, todoHandler.UpdateTodoHandler)
|
||
router.DELETE("/todos/:id", authMiddleware, todoHandler.DeleteTodoByIDHandler)
|
||
|
||
// 运行服务器
|
||
log.Printf("Server starting on :%s...", "8090") // 格式化输出
|
||
if err := router.Run(":8090"); err != nil {
|
||
panic("Server failed to start: " + err.Error())
|
||
}
|
||
}
|