query-database/api/internal/handlers/sql.go
2026-03-25 15:46:20 +08:00

192 lines
4.8 KiB
Go

package handlers
import (
"context"
"database/sql"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"query-database/api/internal/auth"
"query-database/api/internal/judge"
)
type (
	// executeReq is the JSON body accepted by ExecuteSQL.
	executeReq struct {
		// ExerciseID selects the exercise whose reference answer grades the submission.
		ExerciseID string `json:"exerciseId"`
		// SQL is the learner's statement text.
		SQL string `json:"sql"`
	}

	// executeResp is the JSON body ExecuteSQL sends for a graded (or
	// ungradable) submission. Field order here fixes the JSON key order.
	executeResp struct {
		// Ok reports whether the learner's SQL was actually run and graded.
		Ok bool `json:"ok"`
		// DurationMs is the wall-clock time of the learner's query in milliseconds.
		DurationMs int64 `json:"durationMs"`
		// Columns lists the result column names of the learner's query.
		Columns []string `json:"columns"`
		// Rows holds the learner's result rows keyed by column name.
		Rows []map[string]any `json:"rows"`
		// Verdict is "pass" or "fail".
		Verdict string `json:"verdict"`
		// Hint is a human-readable explanation accompanying the verdict.
		Hint string `json:"hint"`
	}
)
// ExecuteSQL grades a learner's SQL submission for one exercise.
//
// The JSON body is {"exerciseId": ..., "sql": ...}. The SQL must pass
// validateUserQuery, and the caller's active database must be the "mock"
// schema the exercise is bound to. The learner's SQL and the exercise's
// reference answer are then run on one MySQL connection and compared with
// judge.Compare.
//
// Responses:
//   - 400: malformed body, empty fields, a rejected statement, no active
//     database, or the exercise schema cannot be selected with USE;
//   - 404: the exercise lookup fails;
//   - 500: bad exercise database config, no MySQL connection, or the
//     reference answer fails;
//   - 200 with executeResp otherwise. Ok is false when the active database
//     does not match the exercise or the learner's SQL fails; Verdict is
//     "pass" only when judge.Compare accepts the learner's result.
func (h *Handlers) ExecuteSQL(c *gin.Context) {
	userID := auth.UserID(c)

	var req executeReq
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"})
		return
	}
	sqlText := strings.TrimSpace(req.SQL)
	if req.ExerciseID == "" || sqlText == "" {
		c.JSON(http.StatusBadRequest, gin.H{"message": "exerciseId 或 sql 不能为空"})
		return
	}
	if err := validateUserQuery(sqlText); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
		return
	}

	active, err := h.getActiveDatabase(userID)
	if err != nil || active == nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "请先在数据库管理页激活一个数据库"})
		return
	}

	var answerSQL, dbKey string
	if err := h.sqlite.QueryRow(`SELECT answer_sql, database_key FROM exercises WHERE id = ?`, req.ExerciseID).Scan(&answerSQL, &dbKey); err != nil {
		c.JSON(http.StatusNotFound, gin.H{"message": "题目不存在"})
		return
	}
	mod, ok := moduleDatabaseByKey(dbKey)
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "题目数据库配置错误"})
		return
	}

	// Grading only makes sense against the mock schema the exercise targets;
	// otherwise tell the learner which database to switch to.
	if active.Source != "mock" || active.SchemaName != mod.SchemaName {
		c.JSON(http.StatusOK, executeResp{
			Ok:         false,
			DurationMs: 0,
			Columns:    []string{},
			Rows:       []map[string]any{},
			Verdict:    "fail",
			Hint:       "当前激活数据库与本题不匹配,请切换到:" + mod.Name,
		})
		return
	}

	// Derive from the request context so a client disconnect cancels the
	// queries. The 5s budget is shared by the learner query and the answer.
	ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
	defer cancel()

	// Pin a single connection so the USE below applies to both queries.
	// NOTE(review): the connection returns to the pool with this schema still
	// selected — confirm other users of h.mysql do not rely on a default one.
	conn, err := h.mysql.Conn(ctx)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "MySQL 连接失败"})
		return
	}
	defer conn.Close()
	if _, err := conn.ExecContext(ctx, "USE "+quoteIdent(active.SchemaName)); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": "数据库不可用"})
		return
	}

	// Only the learner's query is timed; that is what DurationMs reports.
	started := time.Now()
	userRes, err := runQuery(ctx, conn, sqlText)
	dur := time.Since(started).Milliseconds()
	if err != nil {
		c.JSON(http.StatusOK, executeResp{Ok: false, DurationMs: dur, Verdict: "fail", Hint: "SQL 执行失败:" + err.Error(), Columns: []string{}, Rows: []map[string]any{}})
		return
	}
	ansRes, err := runQuery(ctx, conn, answerSQL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": "标准答案执行失败"})
		return
	}

	v := judge.Compare(ansRes, userRes)
	verdict := "fail"
	if v.Pass {
		verdict = "pass"
	}
	c.JSON(http.StatusOK, executeResp{
		Ok:         true,
		DurationMs: dur,
		Columns:    userRes.Columns,
		Rows:       toRowMaps(userRes),
		Verdict:    verdict,
		Hint:       v.Hint,
	})
}
func validateUserQuery(sqlText string) error {
parts := splitStatements(sqlText)
if len(parts) != 1 {
return errString("一次只能执行一条语句")
}
s := strings.ToLower(strings.TrimSpace(parts[0]))
if strings.HasPrefix(s, "select ") || s == "select" {
return nil
}
if strings.HasPrefix(s, "with ") {
return nil
}
if strings.HasPrefix(s, "show ") || strings.HasPrefix(s, "describe ") || strings.HasPrefix(s, "explain ") {
return nil
}
return errString("仅允许 SELECT/WITH/SHOW/DESCRIBE/EXPLAIN")
}
// errString is a plain string that satisfies the error interface, so a
// user-facing message can be returned directly as an error value.
type errString string

// Error returns the message text unchanged.
func (msg errString) Error() string {
	text := string(msg)
	return text
}
// runQuery executes sqlText on conn and buffers its result set.
//
// Only the first 200 rows are read and later rows are ignored. Two results
// that agree on their first 200 rows therefore look identical to callers
// (including the judge.Compare call in ExecuteSQL). []byte cell values are
// converted to string so they JSON-encode as text rather than base64 and
// compare as text.
//
// Any error while executing, scanning, or iterating is returned together
// with a zero QueryResult.
func runQuery(ctx context.Context, conn *sql.Conn, sqlText string) (judge.QueryResult, error) {
	const maxRows = 200

	rows, err := conn.QueryContext(ctx, sqlText)
	if err != nil {
		return judge.QueryResult{}, err
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return judge.QueryResult{}, err
	}

	out := judge.QueryResult{Columns: cols, Rows: make([][]any, 0)}
	for len(out.Rows) < maxRows && rows.Next() {
		// Each row needs its own values slice because it is retained in out.Rows.
		values := make([]any, len(cols))
		ptrs := make([]any, len(cols))
		for i := range values {
			ptrs[i] = &values[i]
		}
		if err := rows.Scan(ptrs...); err != nil {
			return judge.QueryResult{}, err
		}
		for i, v := range values {
			if b, ok := v.([]byte); ok {
				values[i] = string(b)
			}
		}
		out.Rows = append(out.Rows, values)
	}
	// rows.Next returning false may signal an error (e.g. the context deadline
	// expiring mid-stream) rather than the end of data. Without this check a
	// truncated result would be graded as if it were complete.
	if err := rows.Err(); err != nil {
		return judge.QueryResult{}, err
	}
	return out, nil
}
// toRowMaps turns the positional rows of r into one map per row, keyed by
// column name, for the JSON response. The result is never nil, so an empty
// result encodes as [].
//
// Cells beyond the column list are dropped, and a row shorter than the
// column list simply lacks the trailing keys. If two columns share a name,
// the value of the later column is kept.
func toRowMaps(r judge.QueryResult) []map[string]any {
	converted := make([]map[string]any, len(r.Rows))
	for idx, row := range r.Rows {
		entry := make(map[string]any, len(r.Columns))
		for pos, cell := range row {
			if pos >= len(r.Columns) {
				break
			}
			entry[r.Columns[pos]] = cell
		}
		converted[idx] = entry
	}
	return converted
}