package handlers import ( "context" "database/sql" "net/http" "strings" "time" "github.com/gin-gonic/gin" "query-database/api/internal/auth" "query-database/api/internal/judge" ) type executeReq struct { ExerciseID string `json:"exerciseId"` SQL string `json:"sql"` } type executeResp struct { Ok bool `json:"ok"` DurationMs int64 `json:"durationMs"` Columns []string `json:"columns"` Rows []map[string]any `json:"rows"` Verdict string `json:"verdict"` Hint string `json:"hint"` } func (h *Handlers) ExecuteSQL(c *gin.Context) { userID := auth.UserID(c) var req executeReq if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"}) return } sqlText := strings.TrimSpace(req.SQL) if req.ExerciseID == "" || sqlText == "" { c.JSON(http.StatusBadRequest, gin.H{"message": "exerciseId 或 sql 不能为空"}) return } if err := validateUserQuery(sqlText); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } active, err := h.getActiveDatabase(userID) if err != nil || active == nil { c.JSON(http.StatusBadRequest, gin.H{"message": "请先在数据库管理页激活一个数据库"}) return } var answerSQL string var dbKey string if err := h.sqlite.QueryRow(`SELECT answer_sql, database_key FROM exercises WHERE id = ?`, req.ExerciseID).Scan(&answerSQL, &dbKey); err != nil { c.JSON(http.StatusNotFound, gin.H{"message": "题目不存在"}) return } mod, ok := moduleDatabaseByKey(dbKey) if !ok { c.JSON(http.StatusInternalServerError, gin.H{"message": "题目数据库配置错误"}) return } if active.Source != "mock" || active.SchemaName != mod.SchemaName { c.JSON(http.StatusOK, executeResp{ Ok: false, DurationMs: 0, Columns: []string{}, Rows: []map[string]any{}, Verdict: "fail", Hint: "当前激活数据库与本题不匹配,请切换到:" + mod.Name, }) return } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := h.mysql.Conn(ctx) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "MySQL 连接失败"}) return } defer conn.Close() if _, err := conn.ExecContext(ctx, "USE "+quoteIdent(active.SchemaName)); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": "数据库不可用"}) return } started := time.Now() userRes, err := runQuery(ctx, conn, sqlText) dur := time.Since(started).Milliseconds() if err != nil { c.JSON(http.StatusOK, executeResp{Ok: false, DurationMs: dur, Verdict: "fail", Hint: "SQL 执行失败:" + err.Error(), Columns: []string{}, Rows: []map[string]any{}}) return } ansRes, err := runQuery(ctx, conn, answerSQL) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "标准答案执行失败"}) return } v := judge.Compare(ansRes, userRes) verdict := "fail" if v.Pass { verdict = "pass" } c.JSON(http.StatusOK, executeResp{ Ok: true, DurationMs: dur, Columns: userRes.Columns, Rows: toRowMaps(userRes), Verdict: verdict, Hint: v.Hint, }) } func validateUserQuery(sqlText string) error { parts := splitStatements(sqlText) if len(parts) != 1 { return errString("一次只能执行一条语句") } s := strings.ToLower(strings.TrimSpace(parts[0])) if strings.HasPrefix(s, "select ") || s == "select" { return nil } if strings.HasPrefix(s, "with ") { return nil } if strings.HasPrefix(s, "show ") || strings.HasPrefix(s, "describe ") || strings.HasPrefix(s, "explain ") { return nil } return errString("仅允许 SELECT/WITH/SHOW/DESCRIBE/EXPLAIN") } type errString string func (e errString) Error() string { return string(e) } func runQuery(ctx context.Context, conn *sql.Conn, sqlText string) (judge.QueryResult, error) { rows, err := conn.QueryContext(ctx, sqlText) if err != nil { return judge.QueryResult{}, err } defer rows.Close() cols, err := rows.Columns() if err != nil { return judge.QueryResult{}, err } out := judge.QueryResult{Columns: cols, Rows: make([][]any, 0)} for rows.Next() { values := make([]any, len(cols)) ptrs := make([]any, len(cols)) for i := range values { ptrs[i] = &values[i] } if err := rows.Scan(ptrs...); err != nil { return judge.QueryResult{}, err } for i := range values { if b, ok := values[i].([]byte); ok { values[i] = string(b) } } out.Rows = append(out.Rows, values) if len(out.Rows) >= 200 { break } } return out, nil } func toRowMaps(r judge.QueryResult) []map[string]any { out := make([]map[string]any, 0, len(r.Rows)) for _, row := range r.Rows { m := make(map[string]any, len(r.Columns)) for i, c := range r.Columns { if i < len(row) { m[c] = row[i] } } out = append(out, m) } return out }