// Package db opens the application's SQLite database, creates its schema,
// and seeds the built-in SQL practice exercises.
package db
import (
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/google/uuid"
	_ "modernc.org/sqlite" // registers the "sqlite" database/sql driver
)

// OpenSQLite opens the SQLite database at path through the "sqlite"
// database/sql driver and verifies with a ping that a connection can be made.
//
// For a plain file path, the parent directory is created first (mode 0755),
// so callers may point at a data directory that does not exist yet. A failure
// to create it is returned rather than ignored. URI-style DSNs ("file:...")
// are handed to the driver unchanged. For them, filepath.Dir would yield a
// bogus directory name such as "file:data".
//
// The returned pool is limited to one open connection, so every statement
// shares a single SQLite handle. As a consequence, while a *sql.Tx from this
// pool is open, calling db.Exec/db.Query directly (instead of the Tx's
// methods) blocks until the transaction ends.
//
// On error, nil is returned and any *sql.DB created along the way is closed.
func OpenSQLite(path string) (*sql.DB, error) {
	if dir := filepath.Dir(path); dir != "." && !strings.HasPrefix(path, "file:") {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return nil, fmt.Errorf("create database directory %q: %w", dir, err)
		}
	}

	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, fmt.Errorf("open sqlite database %q: %w", path, err)
	}
	db.SetMaxOpenConns(1)
	// NOTE(review): for an in-memory DSN (":memory:"), each connection is its
	// own database. Recycling the only connection after this lifetime would
	// therefore discard all data. Confirm in-memory DSNs are not used beyond
	// short-lived tests.
	db.SetConnMaxLifetime(10 * time.Minute)

	if err := db.Ping(); err != nil {
		// The caller never receives db, so release the pool here rather than
		// leaking it. The ping error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("ping sqlite database %q: %w", path, err)
	}
	return db, nil
}

// MigrateSQLite creates the application schema if it does not exist yet:
//   - users: accounts, with a unique email.
//   - user_databases: indexed by user_id.
//   - exercises: indexed by level.
//   - progress: unique per (user_id, exercise_id).
//
// Timestamps are TEXT columns.
//
// Every statement uses IF NOT EXISTS, so calling this on each start-up is
// safe. All statements run in a single transaction. If any of them fails,
// the transaction is rolled back and the error names the failing statement
// (1-based) and wraps the driver error.
//
// NOTE(review): user_id / exercise_id columns have no FOREIGN KEY
// constraints. Referential integrity is presumably left to application
// code; confirm.
func MigrateSQLite(db *sql.DB) error {
	stmts := []string{
		`CREATE TABLE IF NOT EXISTS users (
			id TEXT PRIMARY KEY,
			email TEXT NOT NULL UNIQUE,
			password_hash TEXT NOT NULL,
			name TEXT NOT NULL DEFAULT '',
			module_key TEXT NOT NULL DEFAULT 'shop',
			experience_level TEXT NOT NULL DEFAULT 'beginner',
			onboarding_completed INTEGER NOT NULL DEFAULT 0,
			created_at TEXT NOT NULL
		);`,
		`CREATE TABLE IF NOT EXISTS user_databases (
			id TEXT PRIMARY KEY,
			user_id TEXT NOT NULL,
			name TEXT NOT NULL,
			source TEXT NOT NULL,
			schema_name TEXT NOT NULL,
			is_active INTEGER NOT NULL DEFAULT 0,
			created_at TEXT NOT NULL
		);`,
		`CREATE INDEX IF NOT EXISTS idx_user_databases_user_id ON user_databases(user_id);`,
		`CREATE TABLE IF NOT EXISTS exercises (
			id TEXT PRIMARY KEY,
			title TEXT NOT NULL,
			level TEXT NOT NULL,
			prompt TEXT NOT NULL,
			answer_sql TEXT NOT NULL,
			database_key TEXT NOT NULL,
			created_at TEXT NOT NULL
		);`,
		`CREATE INDEX IF NOT EXISTS idx_exercises_level ON exercises(level);`,
		`CREATE TABLE IF NOT EXISTS progress (
			id TEXT PRIMARY KEY,
			user_id TEXT NOT NULL,
			exercise_id TEXT NOT NULL,
			draft_sql TEXT NOT NULL DEFAULT '',
			is_solved INTEGER NOT NULL DEFAULT 0,
			updated_at TEXT NOT NULL
		);`,
		`CREATE UNIQUE INDEX IF NOT EXISTS uniq_progress_user_exercise ON progress(user_id, exercise_id);`,
	}

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin migration: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded.
	defer func() { _ = tx.Rollback() }()

	// Execute through tx, never db: the pool may hold a single connection
	// (see OpenSQLite), which the transaction already owns.
	for i, stmt := range stmts {
		if _, err := tx.Exec(stmt); err != nil {
			return fmt.Errorf("migration statement %d of %d: %w", i+1, len(stmts), err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit migration: %w", err)
	}
	return nil
}

func SeedSQLite(db *sql.DB) error {
|
||
var cnt int
|
||
if err := db.QueryRow(`SELECT COUNT(1) FROM exercises`).Scan(&cnt); err != nil {
|
||
return err
|
||
}
|
||
if cnt > 0 {
|
||
return nil
|
||
}
|
||
|
||
now := time.Now().UTC().Format(time.RFC3339)
|
||
seed := []struct {
|
||
Title string
|
||
Level string
|
||
Prompt string
|
||
AnswerSQL string
|
||
DatabaseKey string
|
||
}{
|
||
{
|
||
Title: "新手 1:查询所有商品名称与价格",
|
||
Level: "beginner",
|
||
Prompt: "在电商库中查询 products 表,返回 name 与 price 两列。",
|
||
AnswerSQL: "SELECT name, price FROM products ORDER BY id;",
|
||
DatabaseKey: "shop",
|
||
},
|
||
{
|
||
Title: "新手 2:筛选价格大于 100 的商品",
|
||
Level: "beginner",
|
||
Prompt: "在电商库中查询 price > 100 的商品,返回 id, name, price,并按 price 从高到低排序。",
|
||
AnswerSQL: "SELECT id, name, price FROM products WHERE price > 100 ORDER BY price DESC, id ASC;",
|
||
DatabaseKey: "shop",
|
||
},
|
||
{
|
||
Title: "一般 1:统计每个用户的订单数",
|
||
Level: "normal",
|
||
Prompt: "在电商库中统计每个用户的订单数量,返回 user_id 与 order_count,并按 order_count 从高到低排序。",
|
||
AnswerSQL: "SELECT user_id, COUNT(*) AS order_count FROM orders GROUP BY user_id ORDER BY order_count DESC, user_id ASC;",
|
||
DatabaseKey: "shop",
|
||
},
|
||
{
|
||
Title: "一般 2:查询每个订单的总金额",
|
||
Level: "normal",
|
||
Prompt: "在电商库中,计算每个订单的总金额(sum(quantity * unit_price)),返回 order_id 与 total_amount,按 order_id 排序。",
|
||
AnswerSQL: "SELECT order_id, SUM(quantity * unit_price) AS total_amount FROM order_items GROUP BY order_id ORDER BY order_id ASC;",
|
||
DatabaseKey: "shop",
|
||
},
|
||
{
|
||
Title: "进阶 1:找出下单金额最高的用户",
|
||
Level: "advanced",
|
||
Prompt: "在电商库中,计算每个用户的下单总金额,找出总金额最高的用户,返回 user_id 与 total_amount。",
|
||
AnswerSQL: "SELECT o.user_id, SUM(oi.quantity * oi.unit_price) AS total_amount FROM orders o JOIN order_items oi ON oi.order_id = o.id GROUP BY o.user_id ORDER BY total_amount DESC, o.user_id ASC LIMIT 1;",
|
||
DatabaseKey: "shop",
|
||
},
|
||
{
|
||
Title: "新手 3:查询所有员工姓名与部门",
|
||
Level: "beginner",
|
||
Prompt: "在人事库中查询 employees 与 departments,返回 employee_name 与 department_name,并按 employee_id 排序。",
|
||
AnswerSQL: "SELECT e.name AS employee_name, d.name AS department_name FROM employees e JOIN departments d ON d.id = e.department_id ORDER BY e.id ASC;",
|
||
DatabaseKey: "hr",
|
||
},
|
||
{
|
||
Title: "一般 3:统计每个部门员工数",
|
||
Level: "normal",
|
||
Prompt: "在人事库中统计每个部门的员工数,返回 department_name 与 employee_count,按 employee_count 从高到低排序。",
|
||
AnswerSQL: "SELECT d.name AS department_name, COUNT(e.id) AS employee_count FROM departments d LEFT JOIN employees e ON e.department_id = d.id GROUP BY d.id ORDER BY employee_count DESC, d.id ASC;",
|
||
DatabaseKey: "hr",
|
||
},
|
||
{
|
||
Title: "进阶 2:找出每个部门工资最高的员工",
|
||
Level: "advanced",
|
||
Prompt: "在人事库中,找出每个部门工资最高的员工,返回 department_name, employee_name, salary。",
|
||
AnswerSQL: "SELECT d.name AS department_name, e.name AS employee_name, e.salary FROM departments d JOIN employees e ON e.department_id = d.id WHERE e.salary = (SELECT MAX(salary) FROM employees e2 WHERE e2.department_id = d.id) ORDER BY d.id ASC, e.id ASC;",
|
||
DatabaseKey: "hr",
|
||
},
|
||
}
|
||
|
||
for _, s := range seed {
|
||
id := uuid.NewString()
|
||
if _, err := db.Exec(
|
||
`INSERT INTO exercises (id, title, level, prompt, answer_sql, database_key, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||
id,
|
||
s.Title,
|
||
s.Level,
|
||
s.Prompt,
|
||
s.AnswerSQL,
|
||
s.DatabaseKey,
|
||
now,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|