go-todo-api/main.go
2025-12-02 18:58:25 +08:00

267 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"archive/zip" // 保持,因为 handleUpload 使用
"fmt"
"go-todo-api/constants"
"go-todo-api/handlers"
"go-todo-api/models"
"go-todo-api/repositories"
"go-todo-api/services"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/joho/godotenv" // 保持,用于加载配置
"gorm.io/driver/postgres"
"gorm.io/gorm"
// Go 标准库
)
// DB 是全局数据库连接变量
var DB *gorm.DB
// ❗ 新增:在 init() 中加载 .env 文件
func init() {
if err := godotenv.Load(); err != nil {
log.Println("Note: No .env file found, relying on environment variables.")
}
}
// 修正 connectDatabase 函数,使其读取环境变量
func connectDatabase() {
// ❗ 核心修正:使用 os.Getenv 获取配置
dbHost := os.Getenv("DB_HOST")
dbPort := os.Getenv("DB_PORT")
dbUser := os.Getenv("DB_USER")
dbPassword := os.Getenv("DB_PASSWORD")
dbName := os.Getenv("DB_NAME")
// 检查关键配置(如果未设置,则使用 os.Getenv 的默认空字符串,这会导致连接失败,但这是期望的安全行为)
if dbHost == "" || dbUser == "" || dbPassword == "" || dbName == "" {
log.Fatal("FATAL: Database environment variables (DB_HOST, DB_USER, etc.) are not fully set.")
}
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=Asia/Shanghai",
dbHost, dbUser, dbPassword, dbName, dbPort)
database, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
log.Fatal("Failed to connect to PostgreSQL database! \n", err)
}
// ⭐ 核心步骤:自动迁移
err = database.AutoMigrate(&models.Todo{}, &models.User{})
if err != nil {
log.Fatal("Failed to auto-migrate database schema! \n", err)
}
DB = database
log.Println("PostgreSQL connection successful!") // 保持成功提示
}
// ❗ 修改 AuthRequired 函数签名,使其接受 AuthService 接口
func AuthRequired(authService services.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
// 1. 从 Authorization 请求头中提取令牌
tokenString := c.GetHeader("Authorization")
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing or invalid token format"})
c.Abort() // 阻止请求继续执行
return
}
// 移除 "Bearer " 前缀
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
// 2. ⭐ 调用 Service 层验证令牌
// Service 层会处理解析、验证签名和过期时间
subject, err := authService.AuthenticateToken(tokenString) // ❗ 必须在这里声明 subject
// 3. 处理解析错误 (如签名不匹配、令牌过期等)
if err != nil {
// Service 层的错误信息已经很通用了
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
c.Abort()
return
}
// 4. ⭐ 核心修改:将用户 ID 存储到 Gin Context 中
c.Set(constants.UserIDKey, subject) // ❗ 使用 constants.UserIDKey
// 5. 令牌有效,继续执行下一个 Handler
c.Next()
}
}
// unzipFile 将 ZIP 文件解压到目标目录
func unzipFile(src, dest string) error {
r, err := zip.OpenReader(src)
if err != nil {
return err
}
defer r.Close()
for _, f := range r.File {
// 1. 构建目标路径,确保文件不会逃逸到目标目录之外(安全检查)
fpath := filepath.Join(dest, f.Name)
if !filepath.HasPrefix(fpath, filepath.Clean(dest)+string(os.PathSeparator)) {
return fmt.Errorf("非法文件路径: %s", fpath)
}
// 2. 如果是目录,创建目录
if f.FileInfo().IsDir() {
os.MkdirAll(fpath, os.ModePerm)
continue
}
// 3. 打开文件进行读写
if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
return err
}
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return err
}
rc, err := f.Open()
if err != nil {
outFile.Close()
return err
}
// 4. 复制文件内容
_, err = io.Copy(outFile, rc)
// 确保关闭文件
outFile.Close()
rc.Close()
if err != nil {
return err
}
}
return nil
}
// handleUpload 是处理文件上传的函数
func handleUpload(c *gin.Context) {
// 1. 获取文本字段 (html_path)
htmlPath := c.PostForm("html_path")
if htmlPath == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing html_path field"})
return
}
timeDict := c.PostForm("time_dict")
tsNow := time.Now().Format("20060102150405")
if timeDict == "" {
timeDict = tsNow
}
// 2. 获取上传的文件 (zip_file)
file, err := c.FormFile("zip_file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing zip_file or failed to get file"})
return
}
// 目标解压目录(以当前时间戳或传入 time_dict 作为子目录)
destDir := filepath.Join("D:\\test-html-page", timeDict)
if err := os.MkdirAll(destDir, os.ModePerm); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create dest dir: " + err.Error()})
return
}
// 1. 将 ZIP 文件保存到服务器的临时路径 (注意:我们先保存到临时目录)
tempZipPath := filepath.Join(os.TempDir(), file.Filename)
if err := c.SaveUploadedFile(file, tempZipPath); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save temp zip file: " + err.Error()})
return
}
// 确保函数结束时删除临时文件
defer os.Remove(tempZipPath)
// 2. ⭐ 调用解压函数
if err := unzipFile(tempZipPath, destDir); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to unzip file: " + err.Error()})
return
}
// 3. 构建用户最终访问的 URL
// 为了简化,我们只返回访问路径。完整的访问 URL 需要下一步的 Gin 配置
accessPath := filepath.ToSlash(filepath.Join("/html", timeDict, htmlPath))
c.JSON(http.StatusOK, gin.H{
"message": "HTML files extracted and ready to serve.",
"final_access_url": accessPath,
})
}
// main 函数是 Go 程序的入口
func main() {
// 1. 切换到 Release Mode (生产模式)
// 这样做可以禁用调试输出,并优化性能。
gin.SetMode(gin.ReleaseMode)
connectDatabase()
// 确保 JWT Secret 的检查和获取逻辑不变
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
log.Fatal("FATAL: JWT_SECRET environment variable is not set. Please check your .env file.")
}
router := gin.Default()
// 2. 解决 "You trusted all proxies" 安全警告
// 当你的应用部署在 Nginx 或负载均衡器后面时Gin 需要知道哪些 IP 是安全的代理。
// 以下配置信任了 Loopback 和所有私有网络范围RFC1918这是云部署的常见安全实践。
// 如果你知道你的代理 IP使用更严格的配置更好。
router.SetTrustedProxies([]string{
"127.0.0.1",
"::1",
"10.0.0.0/8", // 私有网络 A 类
"172.16.0.0/12", // 私有网络 B 类
"192.168.0.0/16", // 私有网络 C 类
})
// --- 依赖注入 (DI) ---
todoRepo := repositories.NewGormTodoRepository(DB)
todoService := services.NewTodoService(todoRepo)
todoHandler := handlers.NewTodoHandler(todoService)
userRepo := repositories.NewGormUserRepository(DB)
authService := services.NewAuthService(userRepo, jwtSecret)
authHandler := handlers.NewAuthHandler(authService)
// 注册公开路由
router.GET("/hello", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Hello from Gin, ready for Vue!"})
})
router.Static("/html", "D:/test-html-page")
router.POST("/register", authHandler.RegisterHandler)
router.POST("/login", authHandler.LoginHandler)
// 注册需要 AuthRequired 中间件的路由
authMiddleware := AuthRequired(authService)
router.POST("/upload", authMiddleware, handleUpload)
// Todo 受保护的 CRUD 路由
router.POST("/todos", authMiddleware, todoHandler.CreateTodoHandler)
router.GET("/todos", authMiddleware, todoHandler.FindAllTodosHandler)
router.GET("/todos/:id", authMiddleware, todoHandler.FindTodoByIDHandler)
router.PATCH("/todos/:id", authMiddleware, todoHandler.UpdateTodoHandler)
router.DELETE("/todos/:id", authMiddleware, todoHandler.DeleteTodoByIDHandler)
// 运行服务器
log.Printf("Server starting on :%s...", "8090") // 格式化输出
if err := router.Run(":8090"); err != nil {
panic("Server failed to start: " + err.Error())
}
}