297 lines
7.6 KiB
Go
297 lines
7.6 KiB
Go
package handlers
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"io"
|
|
"net/http"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
|
|
"query-database/api/internal/auth"
|
|
"query-database/api/internal/mockdata"
|
|
)
|
|
|
|
// userDatabaseItem is one entry of the JSON array returned by
// ListUserDatabases, filled from a row of the user_databases table.
type userDatabaseItem struct {
	// ID is the row's user_databases.id value.
	ID string `json:"id"`
	// Name is the display name stored for the database.
	Name string `json:"name"`
	// Source records where the database came from; the handlers in this file
	// write 'mock' or 'imported'.
	Source string `json:"source"`
	// IsActive is true when the row's is_active column equals 1.
	IsActive bool `json:"isActive"`
}
|
|
|
|
func (h *Handlers) ListMockDatabases(c *gin.Context) {
|
|
ms := mockdata.List()
|
|
out := make([]gin.H, 0, len(ms))
|
|
for _, m := range ms {
|
|
out = append(out, gin.H{"key": m.Key, "name": m.Name, "description": m.Description})
|
|
}
|
|
c.JSON(http.StatusOK, out)
|
|
}
|
|
|
|
func (h *Handlers) ListUserDatabases(c *gin.Context) {
|
|
userID := auth.UserID(c)
|
|
rows, err := h.sqlite.Query(
|
|
`SELECT id, name, source, is_active FROM user_databases WHERE user_id = ? ORDER BY created_at DESC`,
|
|
userID,
|
|
)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
out := make([]userDatabaseItem, 0)
|
|
for rows.Next() {
|
|
var id, name, source string
|
|
var active int
|
|
if err := rows.Scan(&id, &name, &source, &active); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
out = append(out, userDatabaseItem{ID: id, Name: name, Source: source, IsActive: active == 1})
|
|
}
|
|
c.JSON(http.StatusOK, out)
|
|
}
|
|
|
|
type activateReq struct {
|
|
ID string `json:"id"`
|
|
}
|
|
|
|
func (h *Handlers) ActivateUserDatabase(c *gin.Context) {
|
|
userID := auth.UserID(c)
|
|
var req activateReq
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"})
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.ID) == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "id 不能为空"})
|
|
return
|
|
}
|
|
tx, err := h.sqlite.Begin()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
res, err := tx.Exec(`UPDATE user_databases SET is_active = 1 WHERE id = ? AND user_id = ?`, req.ID, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
c.JSON(http.StatusNotFound, gin.H{"message": "数据库不存在"})
|
|
return
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
c.Status(http.StatusNoContent)
|
|
}
|
|
|
|
type activateMockReq struct {
|
|
Key string `json:"key"`
|
|
}
|
|
|
|
func (h *Handlers) ActivateMockDatabase(c *gin.Context) {
|
|
userID := auth.UserID(c)
|
|
var req activateMockReq
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"})
|
|
return
|
|
}
|
|
key := strings.TrimSpace(req.Key)
|
|
if key != "shop" && key != "hr" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "key 不合法"})
|
|
return
|
|
}
|
|
if err := h.activateMockForUser(userID, key, time.Now().UTC().Format(time.RFC3339)); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "设置失败"})
|
|
return
|
|
}
|
|
c.Status(http.StatusNoContent)
|
|
}
|
|
|
|
func (h *Handlers) activateMockForUser(userID string, key string, now string) error {
|
|
ms := mockdata.List()
|
|
var picked *mockdata.MockDatabase
|
|
for i := range ms {
|
|
if ms[i].Key == key {
|
|
picked = &ms[i]
|
|
break
|
|
}
|
|
}
|
|
if picked == nil {
|
|
return sql.ErrNoRows
|
|
}
|
|
tx, err := h.sqlite.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil {
|
|
return err
|
|
}
|
|
var existingID string
|
|
err = tx.QueryRow(
|
|
`SELECT id FROM user_databases WHERE user_id = ? AND source = 'mock' AND schema_name = ? LIMIT 1`,
|
|
userID,
|
|
picked.SchemaName,
|
|
).Scan(&existingID)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
newID := uuid.NewString()
|
|
if _, err := tx.Exec(
|
|
`INSERT INTO user_databases (id, user_id, name, source, schema_name, is_active, created_at) VALUES (?, ?, ?, 'mock', ?, 1, ?)`,
|
|
newID,
|
|
userID,
|
|
picked.Name,
|
|
picked.SchemaName,
|
|
now,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
} else {
|
|
if _, err := tx.Exec(`UPDATE user_databases SET is_active = 1 WHERE id = ?`, existingID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (h *Handlers) ImportDatabase(c *gin.Context) {
|
|
userID := auth.UserID(c)
|
|
name := strings.TrimSpace(c.PostForm("name"))
|
|
if name == "" {
|
|
name = "我的数据库"
|
|
}
|
|
file, hdr, err := c.Request.FormFile("file")
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "请选择文件"})
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
if strings.ToLower(filepath.Ext(hdr.Filename)) != ".sql" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "仅支持 .sql 文件"})
|
|
return
|
|
}
|
|
|
|
buf := new(bytes.Buffer)
|
|
if _, err := io.CopyN(buf, file, 2<<20); err != nil && err != io.EOF {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "读取失败"})
|
|
return
|
|
}
|
|
sqlText := buf.String()
|
|
if strings.TrimSpace(sqlText) == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "SQL 为空"})
|
|
return
|
|
}
|
|
|
|
schema := makeSchemaName(userID)
|
|
if err := createSchemaAndInit(h.mysql, schema, sqlText); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "导入失败:" + err.Error()})
|
|
return
|
|
}
|
|
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
tx, err := h.sqlite.Begin()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
if _, err := tx.Exec(`UPDATE user_databases SET is_active = 0 WHERE user_id = ?`, userID); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
id := uuid.NewString()
|
|
if _, err := tx.Exec(
|
|
`INSERT INTO user_databases (id, user_id, name, source, schema_name, is_active, created_at)
|
|
VALUES (?, ?, ?, 'imported', ?, 1, ?)`,
|
|
id,
|
|
userID,
|
|
name,
|
|
schema,
|
|
now,
|
|
); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "保存失败"})
|
|
return
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "服务异常"})
|
|
return
|
|
}
|
|
c.Status(http.StatusNoContent)
|
|
}
|
|
|
|
var reSchemaSan = regexp.MustCompile(`[^a-z0-9_]+`)
|
|
|
|
func makeSchemaName(userID string) string {
|
|
base := strings.ToLower(userID)
|
|
base = strings.ReplaceAll(base, "-", "_")
|
|
base = reSchemaSan.ReplaceAllString(base, "")
|
|
if len(base) > 12 {
|
|
base = base[:12]
|
|
}
|
|
return "udb_" + base + "_" + strings.ReplaceAll(uuid.NewString(), "-", "")[:8]
|
|
}
|
|
|
|
// quoteIdent returns s as a MySQL backtick-quoted identifier. Any backtick
// inside s is doubled, so the result always reads as one identifier.
func quoteIdent(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('`')
	for i := 0; i < len(s); i++ {
		if s[i] == '`' {
			b.WriteString("``")
			continue
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('`')
	return b.String()
}
|
|
|
|
func createSchemaAndInit(mysql *sql.DB, schema string, initSQL string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := mysql.Conn(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
|
|
if _, err := conn.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+quoteIdent(schema)+" CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci"); err != nil {
|
|
return err
|
|
}
|
|
if _, err := conn.ExecContext(ctx, "USE "+quoteIdent(schema)); err != nil {
|
|
return err
|
|
}
|
|
stmts := splitStatements(initSQL)
|
|
for _, s := range stmts {
|
|
if _, err := conn.ExecContext(ctx, s); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// splitStatements splits a MySQL script into individual statements at ';'.
//
// Semicolons inside quoted text ('...', "...", `...`) or inside comments
// (# ..., "-- " ..., /* ... */) do not end a statement.
//   - Inside '...' and "...", a backslash escapes the next byte and a doubled
//     quote stands for a literal quote.
//   - Inside `...`, only a doubled backtick escapes.
//   - As in MySQL, "--" starts a comment only when followed by a whitespace or
//     control character, or by the end of input.
//
// Each returned statement is whitespace-trimmed and keeps any comments it
// contains.
//   - Fragments made only of whitespace and ordinary comments are dropped,
//     because MySQL rejects them as empty queries.
//   - MySQL executable comments (/*! ... */) count as statement content.
//
// DELIMITER directives (a mysql client feature) are not supported. The result
// is never nil.
func splitStatements(sqlText string) []string {
	out := make([]string, 0)
	n := len(sqlText)
	start := 0          // byte offset where the current statement begins
	hasContent := false // current statement has more than whitespace/comments

	flush := func(end int) {
		if hasContent {
			if s := strings.TrimSpace(sqlText[start:end]); s != "" {
				out = append(out, s)
			}
		}
		hasContent = false
	}

	for i := 0; i < n; {
		c := sqlText[i]
		switch {
		case c == ';':
			flush(i)
			i++
			start = i
		case c == '\'' || c == '"' || c == '`':
			i = skipSQLQuoted(sqlText, i)
			hasContent = true
		case c == '#' || (c == '-' && i+1 < n && sqlText[i+1] == '-' && (i+2 == n || sqlText[i+2] <= ' ')):
			// Line comment: skip through the end of the line.
			if j := strings.IndexByte(sqlText[i:], '\n'); j >= 0 {
				i += j + 1
			} else {
				i = n
			}
		case c == '/' && i+1 < n && sqlText[i+1] == '*':
			// Block comment; /*! ... */ is executed by MySQL, so it is content.
			if i+2 < n && sqlText[i+2] == '!' {
				hasContent = true
			}
			if j := strings.Index(sqlText[i+2:], "*/"); j >= 0 {
				i += 2 + j + 2
			} else {
				i = n
			}
		default:
			if c != ' ' && c != '\t' && c != '\n' && c != '\r' && c != '\f' && c != '\v' {
				hasContent = true
			}
			i++
		}
	}
	flush(n)
	return out
}

// skipSQLQuoted returns the index just past the quoted string or identifier
// that opens at s[i]. It returns len(s) if the quote is never closed.
func skipSQLQuoted(s string, i int) int {
	q := s[i]
	for j := i + 1; j < len(s); {
		switch {
		case s[j] == '\\' && q != '`':
			j += 2
		case s[j] == q && j+1 < len(s) && s[j+1] == q:
			j += 2
		case s[j] == q:
			return j + 1
		default:
			j++
		}
	}
	return len(s)
}
|