192 lines
4.8 KiB
Go
192 lines
4.8 KiB
Go
package handlers
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"query-database/api/internal/auth"
	"query-database/api/internal/judge"
)
|
|
|
|
// executeReq is the JSON request body accepted by ExecuteSQL.
type executeReq struct {
	// ExerciseID selects the exercise whose reference answer the submitted
	// query is judged against. Must be non-empty.
	ExerciseID string `json:"exerciseId"`

	// SQL is the statement submitted by the user. The handler trims
	// surrounding whitespace and rejects an empty result.
	SQL string `json:"sql"`
}
|
|
|
|
// executeResp is the JSON payload ExecuteSQL returns with HTTP 200, both for
// judged submissions and for failures that are reported to the user as a
// verdict rather than as an HTTP error.
type executeResp struct {
	// Ok is true when the user's query executed and was judged.
	Ok bool `json:"ok"`

	// DurationMs is the wall-clock time of the user's query in milliseconds
	// (0 when the query was not run).
	DurationMs int64 `json:"durationMs"`

	// Columns holds the column names of the user's result set.
	Columns []string `json:"columns"`

	// Rows holds the user's result set, one map per row keyed by column name.
	Rows []map[string]any `json:"rows"`

	// Verdict is "pass" or "fail".
	Verdict string `json:"verdict"`

	// Hint carries a human-readable explanation (judge feedback or the
	// failure reason).
	Hint string `json:"hint"`
}
|
|
|
|
func (h *Handlers) ExecuteSQL(c *gin.Context) {
|
|
userID := auth.UserID(c)
|
|
var req executeReq
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "参数错误"})
|
|
return
|
|
}
|
|
sqlText := strings.TrimSpace(req.SQL)
|
|
if req.ExerciseID == "" || sqlText == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "exerciseId 或 sql 不能为空"})
|
|
return
|
|
}
|
|
if err := validateUserQuery(sqlText); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
|
return
|
|
}
|
|
|
|
active, err := h.getActiveDatabase(userID)
|
|
if err != nil || active == nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "请先在数据库管理页激活一个数据库"})
|
|
return
|
|
}
|
|
|
|
var answerSQL string
|
|
var dbKey string
|
|
if err := h.sqlite.QueryRow(`SELECT answer_sql, database_key FROM exercises WHERE id = ?`, req.ExerciseID).Scan(&answerSQL, &dbKey); err != nil {
|
|
c.JSON(http.StatusNotFound, gin.H{"message": "题目不存在"})
|
|
return
|
|
}
|
|
mod, ok := moduleDatabaseByKey(dbKey)
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "题目数据库配置错误"})
|
|
return
|
|
}
|
|
|
|
if active.Source != "mock" || active.SchemaName != mod.SchemaName {
|
|
c.JSON(http.StatusOK, executeResp{
|
|
Ok: false,
|
|
DurationMs: 0,
|
|
Columns: []string{},
|
|
Rows: []map[string]any{},
|
|
Verdict: "fail",
|
|
Hint: "当前激活数据库与本题不匹配,请切换到:" + mod.Name,
|
|
})
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := h.mysql.Conn(ctx)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "MySQL 连接失败"})
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
if _, err := conn.ExecContext(ctx, "USE "+quoteIdent(active.SchemaName)); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"message": "数据库不可用"})
|
|
return
|
|
}
|
|
|
|
started := time.Now()
|
|
userRes, err := runQuery(ctx, conn, sqlText)
|
|
dur := time.Since(started).Milliseconds()
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, executeResp{Ok: false, DurationMs: dur, Verdict: "fail", Hint: "SQL 执行失败:" + err.Error(), Columns: []string{}, Rows: []map[string]any{}})
|
|
return
|
|
}
|
|
|
|
ansRes, err := runQuery(ctx, conn, answerSQL)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "标准答案执行失败"})
|
|
return
|
|
}
|
|
|
|
v := judge.Compare(ansRes, userRes)
|
|
verdict := "fail"
|
|
if v.Pass {
|
|
verdict = "pass"
|
|
}
|
|
|
|
c.JSON(http.StatusOK, executeResp{
|
|
Ok: true,
|
|
DurationMs: dur,
|
|
Columns: userRes.Columns,
|
|
Rows: toRowMaps(userRes),
|
|
Verdict: verdict,
|
|
Hint: v.Hint,
|
|
})
|
|
}
|
|
|
|
func validateUserQuery(sqlText string) error {
|
|
parts := splitStatements(sqlText)
|
|
if len(parts) != 1 {
|
|
return errString("一次只能执行一条语句")
|
|
}
|
|
s := strings.ToLower(strings.TrimSpace(parts[0]))
|
|
if strings.HasPrefix(s, "select ") || s == "select" {
|
|
return nil
|
|
}
|
|
if strings.HasPrefix(s, "with ") {
|
|
return nil
|
|
}
|
|
if strings.HasPrefix(s, "show ") || strings.HasPrefix(s, "describe ") || strings.HasPrefix(s, "explain ") {
|
|
return nil
|
|
}
|
|
return errString("仅允许 SELECT/WITH/SHOW/DESCRIBE/EXPLAIN")
|
|
}
|
|
|
|
// errString is a string-backed error type: the string value itself is the
// error message.
type errString string

// Error returns the message, satisfying the error interface.
func (msg errString) Error() string {
	return string(msg)
}
|
|
|
|
func runQuery(ctx context.Context, conn *sql.Conn, sqlText string) (judge.QueryResult, error) {
|
|
rows, err := conn.QueryContext(ctx, sqlText)
|
|
if err != nil {
|
|
return judge.QueryResult{}, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return judge.QueryResult{}, err
|
|
}
|
|
|
|
out := judge.QueryResult{Columns: cols, Rows: make([][]any, 0)}
|
|
for rows.Next() {
|
|
values := make([]any, len(cols))
|
|
ptrs := make([]any, len(cols))
|
|
for i := range values {
|
|
ptrs[i] = &values[i]
|
|
}
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return judge.QueryResult{}, err
|
|
}
|
|
for i := range values {
|
|
if b, ok := values[i].([]byte); ok {
|
|
values[i] = string(b)
|
|
}
|
|
}
|
|
out.Rows = append(out.Rows, values)
|
|
if len(out.Rows) >= 200 {
|
|
break
|
|
}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func toRowMaps(r judge.QueryResult) []map[string]any {
|
|
out := make([]map[string]any, 0, len(r.Rows))
|
|
for _, row := range r.Rows {
|
|
m := make(map[string]any, len(r.Columns))
|
|
for i, c := range r.Columns {
|
|
if i < len(row) {
|
|
m[c] = row[i]
|
|
}
|
|
}
|
|
out = append(out, m)
|
|
}
|
|
return out
|
|
}
|