query-database/api/main.go
2026-03-25 15:46:20 +08:00

117 lines
2.7 KiB
Go

package main
import (
"log"
"net/http"
"os"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"query-database/api/internal/auth"
"query-database/api/internal/config"
"query-database/api/internal/db"
"query-database/api/internal/handlers"
"query-database/api/internal/mockdata"
)
// main is the entry point of the query-database API server.
//
// Startup order:
//  1. load configuration;
//  2. open the SQLite database, apply its migrations and seed it;
//  3. open the MySQL database and ensure the mock schemas exist;
//  4. build the Gin router (panic recovery, CORS, public and authenticated
//     routes);
//  5. serve HTTP on ":"+cfg.Port, defaulting to port 8080.
//
// Any startup failure is fatal: it is logged with context describing the
// failed step and the process exits with status 1.
func main() {
	cfg := config.Load()

	sqliteDB, err := db.OpenSQLite(cfg.SQLitePath)
	if err != nil {
		log.Fatalf("open sqlite %q: %v", cfg.SQLitePath, err)
	}
	if err := db.MigrateSQLite(sqliteDB); err != nil {
		log.Fatalf("migrate sqlite: %v", err)
	}
	if err := db.SeedSQLite(sqliteDB); err != nil {
		log.Fatalf("seed sqlite: %v", err)
	}

	mysqlDB, err := db.OpenMySQL(cfg)
	if err != nil {
		log.Fatalf("open mysql: %v", err)
	}
	if err := mockdata.EnsureMockSchemas(mysqlDB); err != nil {
		log.Fatalf("ensure mock schemas: %v", err)
	}

	a := auth.New(cfg.JWTSecret)
	h := handlers.New(handlers.Deps{
		Cfg:    cfg,
		Auth:   a,
		SQLite: sqliteDB,
		MySQL:  mysqlDB,
	})

	// gin.New() starts with no middleware. Only panic recovery and CORS are
	// added; per-request logging is intentionally disabled, so no logger
	// middleware is registered.
	r := gin.New()
	r.Use(gin.Recovery())
	r.Use(cors.New(cors.Config{
		AllowOrigins:     allowedOrigins(cfg.CORSAllowAll, cfg.CORSOrigins),
		AllowMethods:     []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
		AllowHeaders:     []string{"Authorization", "Content-Type"},
		AllowCredentials: false,
		MaxAge:           12 * time.Hour,
	}))

	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	api := r.Group("/api")

	// Public endpoints.
	api.POST("/auth/register", h.Register)
	api.POST("/auth/login", h.Login)

	// Everything below requires a valid session/token checked by RequireAuth.
	authed := api.Group("")
	authed.Use(a.RequireAuth(sqliteDB))
	authed.GET("/me", h.Me)
	authed.GET("/mock-databases", h.ListMockDatabases)
	authed.POST("/onboarding/complete", h.CompleteOnboarding)
	authed.POST("/module/switch", h.SwitchModule)
	authed.GET("/exercises", h.ListExercises)
	authed.GET("/exercises/:id", h.GetExercise)
	authed.POST("/sql/execute", h.ExecuteSQL)
	authed.POST("/progress/upsert", h.UpsertProgress)
	authed.GET("/user-databases", h.ListUserDatabases)
	authed.POST("/user-databases/activate", h.ActivateUserDatabase)
	authed.POST("/user-databases/activate-mock", h.ActivateMockDatabase)
	authed.POST("/user-databases/import", h.ImportDatabase)

	port := cfg.Port
	if port == "" {
		port = "8080"
	}

	// Use an explicit http.Server instead of r.Run so timeouts can be set.
	// ReadHeaderTimeout bounds how long a client may take to send request
	// headers (Slowloris protection). ReadTimeout/WriteTimeout are
	// deliberately left unset: endpoints such as /sql/execute and
	// /user-databases/import may need longer than a fixed whole-request
	// deadline — revisit if that assumption does not hold.
	srv := &http.Server{
		Addr:              ":" + port,
		Handler:           r,
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	log.Printf("listening on %s", srv.Addr)
	if err := srv.ListenAndServe(); err != nil {
		log.Println(err)
		os.Exit(1)
	}
}

// allowedOrigins builds the CORS origin allowlist from configuration.
//
// If allowAll is "1" (surrounding whitespace ignored), every origin is
// allowed via []string{"*"}. Otherwise origins is read as a comma-separated
// list: each entry is trimmed, a trailing "/" is removed (a browser's Origin
// header never ends in one, so such an entry could never match), and empty
// entries are skipped. If no usable entry remains, the list defaults to
// "http://localhost:5173".
func allowedOrigins(allowAll, origins string) []string {
	if strings.TrimSpace(allowAll) == "1" {
		return []string{"*"}
	}
	var out []string
	for _, p := range strings.Split(origins, ",") {
		v := strings.TrimSuffix(strings.TrimSpace(p), "/")
		if v != "" {
			out = append(out, v)
		}
	}
	if len(out) == 0 {
		return []string{"http://localhost:5173"}
	}
	return out
}