query-database/api/internal/handlers/databases.go
2026-03-25 15:46:20 +08:00

297 lines
7.6 KiB
Go

package handlers
import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"io"
	"net/http"
	"path/filepath"
	"regexp"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"

	"query-database/api/internal/auth"
	"query-database/api/internal/mockdata"
)
// userDatabaseItem is the JSON shape of one element of the ListUserDatabases
// response array.
type userDatabaseItem struct {
	// ID is the user_databases row id.
	ID string `json:"id"`
	// Name is the display name stored for the database.
	Name string `json:"name"`
	// Source records where the database came from; code in this file
	// writes "mock" or "imported".
	Source string `json:"source"`
	// IsActive is true when the row's is_active column is 1.
	IsActive bool `json:"isActive"`
}
// ListMockDatabases responds with the built-in mock databases returned by
// mockdata.List, as a JSON array of {"key", "name", "description"} objects in
// the registry's order. The array is empty (not null) when there are none.
func (h *Handlers) ListMockDatabases(c *gin.Context) {
	catalogue := mockdata.List()
	resp := make([]gin.H, len(catalogue))
	for i := range catalogue {
		entry := catalogue[i]
		resp[i] = gin.H{
			"key":         entry.Key,
			"name":        entry.Name,
			"description": entry.Description,
		}
	}
	c.JSON(http.StatusOK, resp)
}
// ListUserDatabases responds with every user_databases row owned by the
// authenticated user, newest first (by created_at), as a JSON array of
// userDatabaseItem. The array is empty (not null) when the user has none.
//
// Any query, scan, or iteration error yields 500 {"message": "服务异常"}.
func (h *Handlers) ListUserDatabases(c *gin.Context) {
	userID := auth.UserID(c)
	rows, err := h.sqlite.Query(
		`SELECT id, name, source, is_active FROM user_databases WHERE user_id = ? ORDER BY created_at DESC`,
		userID,
	)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	defer rows.Close()

	out := make([]userDatabaseItem, 0)
	for rows.Next() {
		var item userDatabaseItem
		var active int
		if err := rows.Scan(&item.ID, &item.Name, &item.Source, &active); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
			return
		}
		item.IsActive = active == 1
		out = append(out, item)
	}
	// rows.Next returns false both at the end of the result set and on an
	// iteration error; without this check a failed read would be reported as
	// a successful, silently truncated list.
	if err := rows.Err(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	c.JSON(http.StatusOK, out)
}
// activateReq is the JSON body accepted by ActivateUserDatabase.
type activateReq struct {
	// ID is the user_databases row id to make active.
	ID string `json:"id"`
}
// ActivateUserDatabase makes one of the caller's registered databases the
// active one.
//
// Request body: {"id": "<user_databases id>"}; surrounding whitespace in id
// is ignored. All of the caller's rows are deactivated and the chosen row is
// activated inside a single SQLite transaction. Any failure, including an id
// that does not belong to the caller, therefore leaves the previous selection
// untouched.
//
// Responses: 204 on success, 400 for a malformed body or empty id, 404 when
// no row with that id belongs to the caller, 500 on storage errors.
func (h *Handlers) ActivateUserDatabase(c *gin.Context) {
	userID := auth.UserID(c)
	var req activateReq
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"})
		return
	}
	// Validate and query with the same trimmed value, so the id that passes
	// the emptiness check is exactly the one looked up.
	id := strings.TrimSpace(req.ID)
	if id == "" {
		c.JSON(http.StatusBadRequest, gin.H{"message": "id 不能为空"})
		return
	}
	tx, err := h.sqlite.Begin()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	// Undoes the deactivation below on every early return; harmless after Commit.
	defer tx.Rollback()
	if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	res, err := tx.Exec(`UPDATE user_databases SET is_active = 1 WHERE id = ? AND user_id = ?`, id, userID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	n, err := res.RowsAffected()
	if err != nil {
		// Without a reliable count, "not found" cannot be told apart from success.
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	if n == 0 {
		c.JSON(http.StatusNotFound, gin.H{"message": "数据库不存在"})
		return
	}
	if err := tx.Commit(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	c.Status(http.StatusNoContent)
}
// activateMockReq is the JSON body accepted by ActivateMockDatabase.
type activateMockReq struct {
	// Key identifies the mock database (a mockdata key) to activate.
	Key string `json:"key"`
}
// ActivateMockDatabase makes one of the built-in mock databases the caller's
// active database.
//
// Request body: {"key": "<mock key>"}; surrounding whitespace is ignored. The
// key is validated against the mockdata.List registry (via
// activateMockForUser) instead of a hard-coded list. Every mock that
// ListMockDatabases advertises can therefore be activated.
//
// Responses: 204 on success, 400 for a malformed body or unknown key, 500
// when the activation cannot be stored.
func (h *Handlers) ActivateMockDatabase(c *gin.Context) {
	userID := auth.UserID(c)
	var req activateMockReq
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"})
		return
	}
	key := strings.TrimSpace(req.Key)
	if key == "" {
		c.JSON(http.StatusBadRequest, gin.H{"message": "key 不合法"})
		return
	}
	err := h.activateMockForUser(userID, key, time.Now().UTC().Format(time.RFC3339))
	switch {
	case err == sql.ErrNoRows:
		// activateMockForUser returns sql.ErrNoRows, before touching the
		// database, when key is not in the mock registry. That is a client
		// error, not a server failure.
		c.JSON(http.StatusBadRequest, gin.H{"message": "key 不合法"})
		return
	case err != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"message": "设置失败"})
		return
	}
	c.Status(http.StatusNoContent)
}
// activateMockForUser makes the mock database identified by key the only
// active database for userID. On first use it creates the user's
// user_databases row for that mock (source 'mock', created_at = now).
// Later calls reactivate the existing row.
//
// It returns sql.ErrNoRows when key is not present in mockdata.List; this is
// checked before any database access. Otherwise it returns the first error
// from the SQLite transaction. Because every write happens in that
// transaction and is rolled back on error, a failure leaves the user's
// previous active database unchanged.
func (h *Handlers) activateMockForUser(userID string, key string, now string) error {
	var picked *mockdata.MockDatabase
	catalogue := mockdata.List()
	for i := 0; i < len(catalogue) && picked == nil; i++ {
		if catalogue[i].Key == key {
			picked = &catalogue[i]
		}
	}
	if picked == nil {
		return sql.ErrNoRows
	}

	tx, err := h.sqlite.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback()

	if _, err = tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil {
		return err
	}

	var existingID string
	lookupErr := tx.QueryRow(
		`SELECT id FROM user_databases WHERE user_id = ? AND source = 'mock' AND schema_name = ? LIMIT 1`,
		userID,
		picked.SchemaName,
	).Scan(&existingID)

	switch {
	case lookupErr == nil:
		// The user has used this mock before: just switch its row back on.
		_, err = tx.Exec(`UPDATE user_databases SET is_active = 1 WHERE id = ?`, existingID)
	case lookupErr == sql.ErrNoRows:
		// First activation of this mock for the user: register it as active.
		_, err = tx.Exec(
			`INSERT INTO user_databases (id, user_id, name, source, schema_name, is_active, created_at) VALUES (?, ?, ?, 'mock', ?, 1, ?)`,
			uuid.NewString(),
			userID,
			picked.Name,
			picked.SchemaName,
			now,
		)
	default:
		return lookupErr
	}
	if err != nil {
		return err
	}
	return tx.Commit()
}
// ImportDatabase creates a new MySQL schema for the caller from an uploaded
// .sql file, then registers it in SQLite as the caller's active database.
//
// Multipart form fields:
//   - name: display name; "我的数据库" is used when blank.
//   - file: a file with a .sql extension, at most 2 MiB.
//
// Responses:
//   - 204 on success.
//   - 400 for a missing, mis-named, unreadable, or empty file, or when the SQL
//     fails to execute (the MySQL error text is returned to the client).
//   - 413 when the file exceeds the size limit.
//   - 500 when the import cannot be recorded.
//
// If anything fails after the schema name is chosen, the schema is dropped
// again (best effort). Failed imports therefore do not accumulate orphaned
// databases on the MySQL server.
//
// SECURITY(review): the uploaded SQL is executed with this service's MySQL
// credentials. Statements like USE or DROP DATABASE can reach other schemas
// unless that MySQL account is restricted; confirm the deployment's grants.
func (h *Handlers) ImportDatabase(c *gin.Context) {
	// maxSQLBytes caps the size of an uploaded script.
	const maxSQLBytes = 2 << 20

	userID := auth.UserID(c)
	name := strings.TrimSpace(c.PostForm("name"))
	if name == "" {
		name = "我的数据库"
	}
	file, hdr, err := c.Request.FormFile("file")
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "请选择文件"})
		return
	}
	defer file.Close()
	if strings.ToLower(filepath.Ext(hdr.Filename)) != ".sql" {
		c.JSON(http.StatusBadRequest, gin.H{"message": "仅支持 .sql 文件"})
		return
	}

	// Read one byte past the limit so an oversized file is rejected rather
	// than silently truncated. A truncated script can end mid-statement and
	// still be partially imported.
	buf := new(bytes.Buffer)
	read, err := io.CopyN(buf, file, maxSQLBytes+1)
	if err != nil && err != io.EOF {
		c.JSON(http.StatusBadRequest, gin.H{"message": "读取失败"})
		return
	}
	if read > maxSQLBytes {
		c.JSON(http.StatusRequestEntityTooLarge, gin.H{"message": "文件过大"})
		return
	}
	sqlText := buf.String()
	if strings.TrimSpace(sqlText) == "" {
		c.JSON(http.StatusBadRequest, gin.H{"message": "SQL 为空"})
		return
	}

	schema := makeSchemaName(userID)
	// Until the SQLite row is committed, nothing references the schema, so
	// any early return below must remove it.
	registered := false
	defer func() {
		if !registered {
			dropImportedSchema(h.mysql, schema)
		}
	}()
	if err := createSchemaAndInit(h.mysql, schema, sqlText); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "导入失败:" + err.Error()})
		return
	}

	now := time.Now().UTC().Format(time.RFC3339)
	tx, err := h.sqlite.Begin()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	defer tx.Rollback()
	if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	id := uuid.NewString()
	if _, err := tx.Exec(
		`INSERT INTO user_databases (id, user_id, name, source, schema_name, is_active, created_at)
		 VALUES (?, ?, ?, 'imported', ?, 1, ?)`,
		id,
		userID,
		name,
		schema,
		now,
	); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "保存失败"})
		return
	}
	if err := tx.Commit(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
		return
	}
	registered = true
	c.Status(http.StatusNoContent)
}

// dropImportedSchema removes a schema created by an import that did not
// complete. It is best effort: the request has already failed, so a drop
// error is ignored rather than masking the original failure.
func dropImportedSchema(mysql *sql.DB, schema string) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	_, _ = mysql.ExecContext(ctx, "DROP DATABASE IF EXISTS "+quoteIdent(schema))
}
// reSchemaSan matches runs of characters outside [a-z0-9_], which are
// stripped from the user-derived part of a schema name.
var reSchemaSan = regexp.MustCompile(`[^a-z0-9_]+`)

// makeSchemaName returns a new MySQL schema name for an import by userID. The
// name is "udb_", then up to 12 characters of the lower-cased userID
// (hyphens turned into underscores, other characters outside [a-z0-9_]
// removed), then "_" and the first 8 non-hyphen characters of a fresh
// uuid.NewString().
func makeSchemaName(userID string) string {
	const maxUserPart = 12

	userPart := strings.ToLower(userID)
	userPart = reSchemaSan.ReplaceAllString(strings.ReplaceAll(userPart, "-", "_"), "")
	if len(userPart) > maxUserPart {
		userPart = userPart[:maxUserPart]
	}

	randomPart := strings.ReplaceAll(uuid.NewString(), "-", "")[:8]
	return "udb_" + userPart + "_" + randomPart
}
// quoteIdent wraps s in MySQL backticks, doubling any backtick inside s so
// the result is always read as a single quoted identifier.
func quoteIdent(s string) string {
	const tick = "`"
	escaped := strings.Replace(s, tick, tick+tick, -1)
	return tick + escaped + tick
}
// createSchemaAndInit creates schema (utf8mb4 / utf8mb4_general_ci) on mysql
// if it does not exist, switches to it, and executes initSQL one statement at
// a time (as split by splitStatements). Everything runs on a single
// connection, and the whole operation is bounded by a 30-second timeout.
//
// It returns the first error encountered. Statements that already ran are
// not undone, and the schema is not dropped here on failure; cleanup is the
// caller's responsibility.
//
// SECURITY(review): initSQL is executed verbatim and is user-supplied by the
// import handler. It runs with this service's MySQL privileges, so limiting
// what it can reach must be done at the MySQL account level.
func createSchemaAndInit(mysql *sql.DB, schema string, initSQL string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	conn, err := mysql.Conn(ctx)
	if err != nil {
		return err
	}
	defer func() {
		// After USE and arbitrary initSQL, this session has schema as its
		// default database. It may also carry session state set by the script
		// (variables, an open transaction, table locks). Returning
		// driver.ErrBadConn from Raw makes database/sql discard the physical
		// connection instead of returning it to the shared pool, so that state
		// cannot leak into unrelated queries.
		_ = conn.Raw(func(interface{}) error { return driver.ErrBadConn })
		// Still needed if Raw could not run; after a discard it just reports
		// sql.ErrConnDone.
		_ = conn.Close()
	}()
	if _, err := conn.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+quoteIdent(schema)+" CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci"); err != nil {
		return err
	}
	if _, err := conn.ExecContext(ctx, "USE "+quoteIdent(schema)); err != nil {
		return err
	}
	for _, stmt := range splitStatements(initSQL) {
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}
// splitStatements splits a SQL script into individual statements at ';'
// terminators. Each statement is trimmed of surrounding whitespace, and empty
// statements are dropped. The result is never nil.
//
// A ';' ends a statement only when it is outside:
//   - quoted strings ('…' or "…"), where a backslash escapes the next byte;
//   - backtick-quoted identifiers (`…`);
//   - comments: "-- " (dash-dash followed by whitespace, a control character,
//     or end of input) and "#" run to the end of the line; "/* … */" runs to
//     its terminator.
//
// Comments are kept in the statement text rather than stripped, because MySQL
// conditional comments such as /*!40101 … */ are executable. DELIMITER
// directives are not supported. For scripts with no quotes or comments, the
// result is identical to splitting on every ';'.
func splitStatements(sqlText string) []string {
	out := make([]string, 0)
	start := 0     // byte offset where the current statement begins
	var quote byte // active quote character, or 0 when not inside quotes
	emit := func(end int) {
		if s := strings.TrimSpace(sqlText[start:end]); s != "" {
			out = append(out, s)
		}
	}
	for i := 0; i < len(sqlText); i++ {
		c := sqlText[i]
		switch {
		case quote != 0:
			if c == '\\' && quote != '`' {
				i++ // skip the escaped byte, whatever it is
			} else if c == quote {
				quote = 0 // a doubled quote ('') simply closes and reopens
			}
		case c == '\'' || c == '"' || c == '`':
			quote = c
		case c == '#' ||
			(c == '-' && strings.HasPrefix(sqlText[i:], "--") &&
				(i+2 == len(sqlText) || sqlText[i+2] <= ' ')):
			// Line comment: jump to the newline (or the end of input).
			if nl := strings.IndexByte(sqlText[i:], '\n'); nl >= 0 {
				i += nl
			} else {
				i = len(sqlText)
			}
		case c == '/' && strings.HasPrefix(sqlText[i:], "/*"):
			// Block comment: jump to the closing '/' of "*/" (or the end of input).
			if end := strings.Index(sqlText[i+2:], "*/"); end >= 0 {
				i += 2 + end + 1
			} else {
				i = len(sqlText)
			}
		case c == ';':
			emit(i)
			start = i + 1
		}
	}
	emit(len(sqlText))
	return out
}