from __future__ import annotations

import json
import sqlite3
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
|
|
|
|
# Task states that end a task's lifecycle; set_task_status stamps finished_at
# whenever a task is moved into one of these.
TERMINAL_TASK_STATUSES = {
    'success',
    'failed',
    'stopped',
    'interrupted',
}
# Task states treated as still in progress (the same trio the SQL queries use
# when looking up a user's active task or interrupting tasks on restart).
ACTIVE_TASK_STATUSES = {
    'queued',
    'running',
    'waiting_captcha',
}
|
|
|
|
def utc_now() -> str:
    """Return the current UTC time as an ISO-8601 string with whole-second precision.

    The result carries an explicit ``+00:00`` offset, e.g. ``2024-01-31T12:34:56+00:00``.
    """
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat(timespec='seconds')
|
|
|
|
class Database:
    """SQLite-backed store for users, admins, courses, tasks, task logs and settings.

    Each public method opens its own short-lived connection through
    ``_session``, so an instance keeps no database handle open between calls.
    """

    def __init__(self, db_path: Path):
        """Record the database location and make sure its parent directory exists.

        Args:
            db_path: Path of the SQLite database file. It is passed through
                ``Path``, so a plain string also works.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection with ``sqlite3.Row`` rows and foreign keys enforced.

        The caller owns the returned connection and must close it; the methods
        of this class use ``_session`` instead, which closes it automatically.
        """
        connection = sqlite3.connect(self.db_path, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        # SQLite enables foreign-key enforcement per connection; the schema's
        # ON DELETE CASCADE clauses only take effect when it is on.
        connection.execute('PRAGMA foreign_keys = ON')
        return connection

    @contextmanager
    def _session(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection; commit on success, roll back on error, always close.

        ``with connection:`` on its own only ends the transaction and leaves the
        connection open until it is garbage-collected, so the explicit
        ``close()`` here is what releases the database handle.
        """
        connection = self._connect()
        try:
            with connection:
                yield connection
        finally:
            connection.close()

    def initialize(self) -> None:
        """Create the schema if needed, seed default settings, and interrupt stale tasks.

        Safe to run on every start-up: all DDL uses ``IF NOT EXISTS`` and the
        default ``max_parallel_tasks`` ('2') is inserted with ``INSERT OR IGNORE``.
        Any task still 'queued', 'running' or 'waiting_captcha' (left over from
        a previous process) is set to 'interrupted' with ``finished_at`` stamped
        now; an empty ``last_error`` is filled with a restart message.
        """
        with self._session() as connection:
            # WAL journal mode lets readers run alongside a single writer.
            connection.execute('PRAGMA journal_mode = WAL')
            connection.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    student_id TEXT NOT NULL UNIQUE,
                    display_name TEXT NOT NULL DEFAULT '',
                    encrypted_password TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );

                CREATE TABLE IF NOT EXISTS admins (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL UNIQUE,
                    password_hash TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );

                CREATE TABLE IF NOT EXISTS courses (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    course_id TEXT NOT NULL,
                    course_index TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    UNIQUE(user_id, course_id, course_index),
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );

                CREATE TABLE IF NOT EXISTS tasks (
                    id TEXT PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    status TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    finished_at TEXT,
                    requested_by_role TEXT NOT NULL,
                    requested_by_identity TEXT NOT NULL,
                    stop_requested INTEGER NOT NULL DEFAULT 0,
                    last_error TEXT NOT NULL DEFAULT '',
                    total_count INTEGER NOT NULL DEFAULT 0,
                    completed_count INTEGER NOT NULL DEFAULT 0,
                    task_payload TEXT NOT NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                );

                CREATE TABLE IF NOT EXISTS task_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT NOT NULL,
                    level TEXT NOT NULL,
                    message TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
                );

                CREATE TABLE IF NOT EXISTS settings (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                );

                CREATE INDEX IF NOT EXISTS idx_courses_user_id ON courses(user_id);
                CREATE INDEX IF NOT EXISTS idx_tasks_user_id ON tasks(user_id);
                CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
                CREATE INDEX IF NOT EXISTS idx_task_logs_task_id ON task_logs(task_id, id);
                """
            )
            connection.execute(
                'INSERT OR IGNORE INTO settings(key, value) VALUES (?, ?)',
                ('max_parallel_tasks', '2'),
            )
            # The default error text means: "App restarted; the previous round
            # of tasks was interrupted." An existing error message is preserved.
            connection.execute(
                """
                UPDATE tasks
                SET status = 'interrupted',
                    finished_at = ?,
                    last_error = CASE
                        WHEN COALESCE(last_error, '') = '' THEN '应用重启,上一轮任务被中断。'
                        ELSE last_error
                    END
                WHERE status IN ('queued', 'running', 'waiting_captcha')
                """,
                (utc_now(),),
            )

    def get_setting(self, key: str, default: str = '') -> str:
        """Return the stored value for ``key``, or ``default`` if the key is absent."""
        with self._session() as connection:
            row = connection.execute('SELECT value FROM settings WHERE key = ?', (key,)).fetchone()
            return row['value'] if row else default

    def set_setting(self, key: str, value: str) -> None:
        """Insert or overwrite the setting ``key`` with ``value``."""
        with self._session() as connection:
            connection.execute(
                """
                INSERT INTO settings(key, value) VALUES (?, ?)
                ON CONFLICT(key) DO UPDATE SET value = excluded.value
                """,
                (key, value),
            )

    def get_max_parallel_tasks(self) -> int:
        """Return the ``max_parallel_tasks`` setting clamped to the range 1..8.

        Falls back to 2 when the setting is missing or not an integer string.
        """
        raw = self.get_setting('max_parallel_tasks', '2')
        try:
            return max(1, min(8, int(raw)))
        except ValueError:
            return 2

    def create_user(self, student_id: str, encrypted_password: str, display_name: str = '') -> int:
        """Insert a user and return its new row id.

        Raises:
            sqlite3.IntegrityError: if ``student_id`` already exists (UNIQUE).
        """
        now = utc_now()
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO users(student_id, display_name, encrypted_password, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?)
                """,
                (student_id, display_name, encrypted_password, now, now),
            )
            return int(cursor.lastrowid)

    def update_user(
        self,
        user_id: int,
        *,
        student_id: str | None = None,
        encrypted_password: str | None = None,
        display_name: str | None = None,
    ) -> None:
        """Update the given non-None fields of a user and bump ``updated_at``.

        Does nothing (not even touching ``updated_at``) when every field is None.
        """
        fields: list[str] = []
        values: list[Any] = []
        if student_id is not None:
            fields.append('student_id = ?')
            values.append(student_id)
        if encrypted_password is not None:
            fields.append('encrypted_password = ?')
            values.append(encrypted_password)
        if display_name is not None:
            fields.append('display_name = ?')
            values.append(display_name)
        if not fields:
            return
        fields.append('updated_at = ?')
        values.append(utc_now())
        values.append(user_id)
        # Only the fixed column fragments above are interpolated; every value
        # is passed as a bound parameter.
        with self._session() as connection:
            connection.execute(f"UPDATE users SET {', '.join(fields)} WHERE id = ?", values)

    def delete_user(self, user_id: int) -> None:
        """Delete a user; its courses, tasks and task logs go with it via ON DELETE CASCADE."""
        with self._session() as connection:
            connection.execute('DELETE FROM users WHERE id = ?', (user_id,))

    def get_user_by_student_id(self, student_id: str) -> dict[str, Any] | None:
        """Return the full user row for ``student_id`` as a dict, or None if not found."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM users WHERE student_id = ?', (student_id,)).fetchone()
            return dict(row) if row else None

    def get_user_by_id(self, user_id: int) -> dict[str, Any] | None:
        """Return the full user row for ``user_id`` as a dict, or None if not found."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM users WHERE id = ?', (user_id,)).fetchone()
            return dict(row) if row else None

    def list_users(self) -> list[dict[str, Any]]:
        """Return all users, oldest first, with summary columns.

        Each dict has every ``users`` column plus ``course_count`` and
        ``latest_task_status`` (status of the most recently created task, or
        None when the user has no tasks).
        """
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT
                    u.*,
                    COUNT(c.id) AS course_count,
                    (
                        SELECT t.status
                        FROM tasks t
                        WHERE t.user_id = u.id
                        -- created_at has 1-second precision; rowid breaks ties
                        -- so the newest inserted task wins deterministically.
                        ORDER BY t.created_at DESC, t.rowid DESC
                        LIMIT 1
                    ) AS latest_task_status
                FROM users u
                LEFT JOIN courses c ON c.user_id = u.id
                GROUP BY u.id
                ORDER BY u.created_at ASC, u.id ASC
                """
            ).fetchall()
            return [dict(row) for row in rows]

    def list_courses_for_user(self, user_id: int) -> list[dict[str, Any]]:
        """Return a user's courses (``id``, ``course_id``, ``course_index``, ``created_at``) in insertion order."""
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT id, course_id, course_index, created_at
                FROM courses
                WHERE user_id = ?
                ORDER BY created_at ASC, id ASC
                """,
                (user_id,),
            ).fetchall()
            return [dict(row) for row in rows]

    def add_course(self, user_id: int, course_id: str, course_index: str) -> None:
        """Attach a course to a user; an identical (user, course, index) row is silently ignored."""
        with self._session() as connection:
            connection.execute(
                """
                INSERT OR IGNORE INTO courses(user_id, course_id, course_index, created_at)
                VALUES (?, ?, ?, ?)
                """,
                (user_id, course_id, course_index, utc_now()),
            )

    def delete_course(self, course_row_id: int, user_id: int | None = None) -> None:
        """Delete a course row by id; when ``user_id`` is given, only if it belongs to that user."""
        with self._session() as connection:
            if user_id is None:
                connection.execute('DELETE FROM courses WHERE id = ?', (course_row_id,))
            else:
                connection.execute('DELETE FROM courses WHERE id = ? AND user_id = ?', (course_row_id, user_id))

    def list_admins(self) -> list[dict[str, Any]]:
        """Return all admins (without password hashes), oldest first."""
        with self._session() as connection:
            rows = connection.execute(
                'SELECT id, username, created_at, updated_at FROM admins ORDER BY created_at ASC, id ASC'
            ).fetchall()
            return [dict(row) for row in rows]

    def get_admin_by_username(self, username: str) -> dict[str, Any] | None:
        """Return the full admin row (including ``password_hash``) for ``username``, or None."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM admins WHERE username = ?', (username,)).fetchone()
            return dict(row) if row else None

    def create_admin(self, username: str, password_hash: str) -> int:
        """Insert an admin and return its new row id.

        Raises:
            sqlite3.IntegrityError: if ``username`` already exists (UNIQUE).
        """
        now = utc_now()
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO admins(username, password_hash, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                """,
                (username, password_hash, now, now),
            )
            return int(cursor.lastrowid)

    def update_admin_password(self, admin_id: int, password_hash: str) -> None:
        """Replace an admin's password hash and bump ``updated_at``."""
        with self._session() as connection:
            connection.execute(
                'UPDATE admins SET password_hash = ?, updated_at = ? WHERE id = ?',
                (password_hash, utc_now(), admin_id),
            )

    def delete_admin(self, admin_id: int) -> None:
        """Delete the admin with row id ``admin_id`` (no-op if absent)."""
        with self._session() as connection:
            connection.execute('DELETE FROM admins WHERE id = ?', (admin_id,))

    def find_active_task_for_user(self, user_id: int) -> dict[str, Any] | None:
        """Return the user's newest task that is queued/running/waiting_captcha, or None."""
        with self._session() as connection:
            row = connection.execute(
                """
                SELECT * FROM tasks
                WHERE user_id = ? AND status IN ('queued', 'running', 'waiting_captcha')
                ORDER BY created_at DESC, rowid DESC
                LIMIT 1
                """,
                (user_id,),
            ).fetchone()
            return dict(row) if row else None

    def create_task(
        self,
        *,
        user_id: int,
        requested_by_role: str,
        requested_by_identity: str,
        payload: dict[str, Any],
    ) -> str:
        """Queue a new task for ``user_id`` and return its generated UUID4 id.

        ``total_count`` is set to ``len(payload['courses'])`` (0 when the key is
        absent), and the payload is stored as JSON with non-ASCII text kept verbatim.
        """
        task_id = str(uuid.uuid4())
        now = utc_now()
        with self._session() as connection:
            connection.execute(
                """
                INSERT INTO tasks(
                    id,
                    user_id,
                    status,
                    created_at,
                    requested_by_role,
                    requested_by_identity,
                    total_count,
                    completed_count,
                    task_payload
                )
                VALUES (?, ?, 'queued', ?, ?, ?, ?, 0, ?)
                """,
                (
                    task_id,
                    user_id,
                    now,
                    requested_by_role,
                    requested_by_identity,
                    len(payload.get('courses', [])),
                    json.dumps(payload, ensure_ascii=False),
                ),
            )
        return task_id

    def claim_next_queued_task(self) -> dict[str, Any] | None:
        """Move the oldest queued task to 'running' and return its full row, or None.

        The UPDATE only applies while the task is still 'queued', so if another
        worker claims the same task first this returns None (even if other
        tasks remain queued); callers are expected to poll again.
        """
        with self._session() as connection:
            row = connection.execute(
                "SELECT id FROM tasks WHERE status = 'queued' ORDER BY created_at ASC, rowid ASC LIMIT 1"
            ).fetchone()
            if not row:
                return None
            updated = connection.execute(
                """
                UPDATE tasks
                SET status = 'running', started_at = ?, last_error = ''
                WHERE id = ? AND status = 'queued'
                """,
                (utc_now(), row['id']),
            ).rowcount
            if not updated:
                return None
            claimed = connection.execute('SELECT * FROM tasks WHERE id = ?', (row['id'],)).fetchone()
            return dict(claimed) if claimed else None

    def get_task(self, task_id: str) -> dict[str, Any] | None:
        """Return the full task row for ``task_id`` as a dict, or None."""
        with self._session() as connection:
            row = connection.execute('SELECT * FROM tasks WHERE id = ?', (task_id,)).fetchone()
            return dict(row) if row else None

    def get_task_with_user(self, task_id: str) -> dict[str, Any] | None:
        """Return the task row plus its owner's ``student_id`` and ``display_name``, or None."""
        with self._session() as connection:
            row = connection.execute(
                """
                SELECT
                    t.*,
                    u.student_id,
                    u.display_name
                FROM tasks t
                JOIN users u ON u.id = t.user_id
                WHERE t.id = ?
                """,
                (task_id,),
            ).fetchone()
            return dict(row) if row else None

    def list_recent_tasks_for_user(self, user_id: int, limit: int = 12) -> list[dict[str, Any]]:
        """Return up to ``limit`` of the user's tasks, newest first, without the payload."""
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT id, user_id, status, created_at, started_at, finished_at, stop_requested,
                       last_error, total_count, completed_count
                FROM tasks
                WHERE user_id = ?
                ORDER BY created_at DESC, rowid DESC
                LIMIT ?
                """,
                (user_id, limit),
            ).fetchall()
            return [dict(row) for row in rows]

    def list_recent_tasks(self, limit: int = 20) -> list[dict[str, Any]]:
        """Return up to ``limit`` tasks across all users, newest first, with owner info."""
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT
                    t.id,
                    t.user_id,
                    t.status,
                    t.created_at,
                    t.started_at,
                    t.finished_at,
                    t.stop_requested,
                    t.last_error,
                    t.total_count,
                    t.completed_count,
                    u.student_id,
                    u.display_name
                FROM tasks t
                JOIN users u ON u.id = t.user_id
                ORDER BY t.created_at DESC, t.rowid DESC
                LIMIT ?
                """,
                (limit,),
            ).fetchall()
            return [dict(row) for row in rows]

    def set_task_status(
        self,
        task_id: str,
        status: str,
        *,
        last_error: str | None = None,
        completed_count: int | None = None,
    ) -> None:
        """Set a task's status, optionally updating ``last_error`` and ``completed_count``.

        Moving to 'running' sets ``started_at`` only if it was still NULL;
        moving to a status in ``TERMINAL_TASK_STATUSES`` stamps ``finished_at``.
        """
        # One timestamp per call so started_at/finished_at written together agree.
        now = utc_now()
        assignments = ['status = ?']
        values: list[Any] = [status]
        if last_error is not None:
            assignments.append('last_error = ?')
            values.append(last_error)
        if completed_count is not None:
            assignments.append('completed_count = ?')
            values.append(completed_count)
        if status == 'running':
            assignments.append('started_at = COALESCE(started_at, ?)')
            values.append(now)
        if status in TERMINAL_TASK_STATUSES:
            assignments.append('finished_at = ?')
            values.append(now)
        values.append(task_id)
        with self._session() as connection:
            connection.execute(f"UPDATE tasks SET {', '.join(assignments)} WHERE id = ?", values)

    def update_task_progress(self, task_id: str, completed_count: int) -> None:
        """Overwrite a task's ``completed_count``."""
        with self._session() as connection:
            connection.execute('UPDATE tasks SET completed_count = ? WHERE id = ?', (completed_count, task_id))

    def request_task_stop(self, task_id: str) -> None:
        """Set the task's ``stop_requested`` flag (polled via ``is_task_stop_requested``)."""
        with self._session() as connection:
            connection.execute('UPDATE tasks SET stop_requested = 1 WHERE id = ?', (task_id,))

    def is_task_stop_requested(self, task_id: str) -> bool:
        """Return True if a stop was requested for the task; False if not or if the task is missing."""
        with self._session() as connection:
            row = connection.execute('SELECT stop_requested FROM tasks WHERE id = ?', (task_id,)).fetchone()
            return bool(row and row['stop_requested'])

    def append_task_log(self, task_id: str, level: str, message: str) -> int:
        """Append a log line to a task and return the new log row id."""
        with self._session() as connection:
            cursor = connection.execute(
                """
                INSERT INTO task_logs(task_id, level, message, created_at)
                VALUES (?, ?, ?, ?)
                """,
                (task_id, level, message, utc_now()),
            )
            return int(cursor.lastrowid)

    def list_task_logs(self, task_id: str, after_id: int = 0, limit: int = 200) -> list[dict[str, Any]]:
        """Return up to ``limit`` log rows for a task with id > ``after_id``, in id order.

        Passing the last seen id as ``after_id`` lets callers fetch only new lines.
        """
        with self._session() as connection:
            rows = connection.execute(
                """
                SELECT id, level, message, created_at
                FROM task_logs
                WHERE task_id = ? AND id > ?
                ORDER BY id ASC
                LIMIT ?
                """,
                (task_id, after_id, limit),
            ).fetchall()
            return [dict(row) for row in rows]

    def get_system_snapshot(self) -> dict[str, Any]:
        """Return counts of users, admins, queued and running tasks plus the parallelism limit.

        ``running`` counts both 'running' and 'waiting_captcha' tasks.
        """
        # Read the setting before opening the session so only one connection
        # is open at a time.
        max_parallel_tasks = self.get_max_parallel_tasks()
        with self._session() as connection:
            return {
                'users': connection.execute('SELECT COUNT(*) FROM users').fetchone()[0],
                'admins': connection.execute('SELECT COUNT(*) FROM admins').fetchone()[0],
                'queued': connection.execute("SELECT COUNT(*) FROM tasks WHERE status = 'queued'").fetchone()[0],
                'running': connection.execute(
                    "SELECT COUNT(*) FROM tasks WHERE status IN ('running', 'waiting_captcha')"
                ).fetchone()[0],
                'max_parallel_tasks': max_parallel_tasks,
            }
|
|
|