package handlers import ( "bytes" "context" "database/sql" "io" "net/http" "path/filepath" "regexp" "strings" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" "query-database/api/internal/auth" "query-database/api/internal/mockdata" ) type userDatabaseItem struct { ID string `json:"id"` Name string `json:"name"` Source string `json:"source"` IsActive bool `json:"isActive"` } func (h *Handlers) ListMockDatabases(c *gin.Context) { ms := mockdata.List() out := make([]gin.H, 0, len(ms)) for _, m := range ms { out = append(out, gin.H{"key": m.Key, "name": m.Name, "description": m.Description}) } c.JSON(http.StatusOK, out) } func (h *Handlers) ListUserDatabases(c *gin.Context) { userID := auth.UserID(c) rows, err := h.sqlite.Query( `SELECT id, name, source, is_active FROM user_databases WHERE user_id = ? ORDER BY created_at DESC`, userID, ) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } defer rows.Close() out := make([]userDatabaseItem, 0) for rows.Next() { var id, name, source string var active int if err := rows.Scan(&id, &name, &source, &active); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } out = append(out, userDatabaseItem{ID: id, Name: name, Source: source, IsActive: active == 1}) } c.JSON(http.StatusOK, out) } type activateReq struct { ID string `json:"id"` } func (h *Handlers) ActivateUserDatabase(c *gin.Context) { userID := auth.UserID(c) var req activateReq if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"}) return } if strings.TrimSpace(req.ID) == "" { c.JSON(http.StatusBadRequest, gin.H{"message": "id 不能为空"}) return } tx, err := h.sqlite.Begin() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } defer tx.Rollback() if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } res, err := tx.Exec(`UPDATE user_databases SET is_active = 1 WHERE id = ? AND user_id = ?`, req.ID, userID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } n, _ := res.RowsAffected() if n == 0 { c.JSON(http.StatusNotFound, gin.H{"message": "数据库不存在"}) return } if err := tx.Commit(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } c.Status(http.StatusNoContent) } type activateMockReq struct { Key string `json:"key"` } func (h *Handlers) ActivateMockDatabase(c *gin.Context) { userID := auth.UserID(c) var req activateMockReq if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"}) return } key := strings.TrimSpace(req.Key) if key != "shop" && key != "hr" { c.JSON(http.StatusBadRequest, gin.H{"message": "key 不合法"}) return } if err := h.activateMockForUser(userID, key, time.Now().UTC().Format(time.RFC3339)); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "设置失败"}) return } c.Status(http.StatusNoContent) } func (h *Handlers) activateMockForUser(userID string, key string, now string) error { ms := mockdata.List() var picked *mockdata.MockDatabase for i := range ms { if ms[i].Key == key { picked = &ms[i] break } } if picked == nil { return sql.ErrNoRows } tx, err := h.sqlite.Begin() if err != nil { return err } defer tx.Rollback() if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil { return err } var existingID string err = tx.QueryRow( `SELECT id FROM user_databases WHERE user_id = ? AND source = 'mock' AND schema_name = ? LIMIT 1`, userID, picked.SchemaName, ).Scan(&existingID) if err != nil { if err == sql.ErrNoRows { newID := uuid.NewString() if _, err := tx.Exec( `INSERT INTO user_databases (id, user_id, name, source, schema_name, is_active, created_at) VALUES (?, ?, ?, 'mock', ?, 1, ?)`, newID, userID, picked.Name, picked.SchemaName, now, ); err != nil { return err } } else { return err } } else { if _, err := tx.Exec(`UPDATE user_databases SET is_active = 1 WHERE id = ?`, existingID); err != nil { return err } } return tx.Commit() } func (h *Handlers) ImportDatabase(c *gin.Context) { userID := auth.UserID(c) name := strings.TrimSpace(c.PostForm("name")) if name == "" { name = "我的数据库" } file, hdr, err := c.Request.FormFile("file") if err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": "请选择文件"}) return } defer file.Close() if strings.ToLower(filepath.Ext(hdr.Filename)) != ".sql" { c.JSON(http.StatusBadRequest, gin.H{"message": "仅支持 .sql 文件"}) return } buf := new(bytes.Buffer) if _, err := io.CopyN(buf, file, 2<<20); err != nil && err != io.EOF { c.JSON(http.StatusBadRequest, gin.H{"message": "读取失败"}) return } sqlText := buf.String() if strings.TrimSpace(sqlText) == "" { c.JSON(http.StatusBadRequest, gin.H{"message": "SQL 为空"}) return } schema := makeSchemaName(userID) if err := createSchemaAndInit(h.mysql, schema, sqlText); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": "导入失败:" + err.Error()}) return } now := time.Now().UTC().Format(time.RFC3339) tx, err := h.sqlite.Begin() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } defer tx.Rollback() if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } id := uuid.NewString() if _, err := tx.Exec( `INSERT INTO user_databases (id, user_id, name, source, schema_name, is_active, created_at) VALUES (?, ?, ?, 'imported', ?, 1, ?)`, id, userID, name, schema, now, ); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "保存失败"}) return } if err := tx.Commit(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"}) return } c.Status(http.StatusNoContent) } var reSchemaSan = regexp.MustCompile(`[^a-z0-9_]+`) func makeSchemaName(userID string) string { base := strings.ToLower(userID) base = strings.ReplaceAll(base, "-", "_") base = reSchemaSan.ReplaceAllString(base, "") if len(base) > 12 { base = base[:12] } return "udb_" + base + "_" + strings.ReplaceAll(uuid.NewString(), "-", "")[:8] } func quoteIdent(s string) string { return "`" + strings.ReplaceAll(s, "`", "``") + "`" } func createSchemaAndInit(mysql *sql.DB, schema string, initSQL string) error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() conn, err := mysql.Conn(ctx) if err != nil { return err } defer conn.Close() if _, err := conn.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+quoteIdent(schema)+" CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci"); err != nil { return err } if _, err := conn.ExecContext(ctx, "USE "+quoteIdent(schema)); err != nil { return err } stmts := splitStatements(initSQL) for _, s := range stmts { if _, err := conn.ExecContext(ctx, s); err != nil { return err } } return nil } func splitStatements(sqlText string) []string { parts := strings.Split(sqlText, ";") out := make([]string, 0, len(parts)) for _, p := range parts { s := strings.TrimSpace(p) if s == "" { continue } out = append(out, s) } return out }