Spaces:
Running
Running
| """Database for conversations and distillation data""" | |
| import sqlite3 | |
| from datetime import datetime | |
| from typing import List, Dict | |
| from config import DATABASE_PATH | |
class VedaDatabase:
    """SQLite-backed store for conversations, distillation samples and training runs.

    Every public method opens its own short-lived connection and always
    closes it, even when a statement fails. Writes run inside a transaction
    that is committed on success and rolled back on error.
    """

    # Upper bound on "?" placeholders sent in a single statement. SQLite
    # rejects statements with more bound parameters than its compile-time
    # limit, so long id lists are processed in batches of this size.
    _MAX_SQL_PARAMS = 500

    def __init__(self, db_path: str = None):
        """Create the handler and make sure all tables exist.

        Args:
            db_path: Location of the SQLite file. Defaults to
                ``DATABASE_PATH`` from the project configuration.
        """
        self._db_path = DATABASE_PATH if db_path is None else db_path
        self._init_db()

    def _get_conn(self):
        """Open a new connection whose rows can be accessed by column name."""
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _write(self, sql: str, params=()) -> int:
        """Run one modifying statement in its own transaction.

        Returns:
            The cursor's ``lastrowid`` after the statement.
        """
        conn = self._get_conn()
        try:
            # The connection context manager commits on success and rolls
            # back on error; it does not close, hence the outer finally.
            with conn:
                return conn.execute(sql, params).lastrowid
        finally:
            conn.close()

    def _read(self, sql: str, params=()) -> list:
        """Run one query and return all of its rows."""
        conn = self._get_conn()
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def _init_db(self):
        """Create the three tables if they do not exist yet."""
        conn = self._get_conn()
        try:
            with conn:
                # Regular conversations table
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS conversations (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                        user_input TEXT NOT NULL,
                        assistant_response TEXT NOT NULL,
                        feedback INTEGER DEFAULT 0
                    )
                ''')
                # Distillation data table (teacher responses)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS distillation_data (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                        user_input TEXT NOT NULL,
                        teacher_response TEXT NOT NULL,
                        student_response TEXT,
                        used_for_training BOOLEAN DEFAULT 0,
                        quality_score REAL DEFAULT 0
                    )
                ''')
                # Training history
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS training_history (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                        training_type TEXT,
                        samples_used INTEGER,
                        epochs INTEGER,
                        final_loss REAL
                    )
                ''')
        finally:
            conn.close()

    # ===== Conversations =====

    def save_conversation(self, user_input: str, response: str) -> int:
        """Store one user/assistant exchange and return the new row id."""
        return self._write(
            '''
            INSERT INTO conversations (user_input, assistant_response)
            VALUES (?, ?)
            ''',
            (user_input, response),
        )

    def update_feedback(self, conv_id: int, feedback: int):
        """Set the feedback value of the conversation with id ``conv_id``."""
        self._write(
            'UPDATE conversations SET feedback = ? WHERE id = ?',
            (feedback, conv_id),
        )

    def get_good_conversations(self, limit: int = 100) -> List[Dict]:
        """Return up to ``limit`` conversations with positive feedback.

        Rows are ordered newest first and contain the keys ``user_input``
        and ``assistant_response``.
        """
        rows = self._read(
            '''
            SELECT user_input, assistant_response
            FROM conversations
            WHERE feedback > 0
            ORDER BY timestamp DESC
            LIMIT ?
            ''',
            (limit,),
        )
        return [dict(row) for row in rows]

    # ===== Distillation =====

    def save_distillation_data(
        self,
        user_input: str,
        teacher_response: str,
        student_response: str = None,
        quality_score: float = 0.0
    ) -> int:
        """Store one distillation sample and return the new row id."""
        return self._write(
            '''
            INSERT INTO distillation_data
            (user_input, teacher_response, student_response, quality_score)
            VALUES (?, ?, ?, ?)
            ''',
            (user_input, teacher_response, student_response, quality_score),
        )

    def get_unused_distillation_data(self, limit: int = 500) -> List[Dict]:
        """Get teacher responses not yet used for training.

        Returns up to ``limit`` rows, newest first, each with the keys
        ``id``, ``user_input`` and ``teacher_response``.
        """
        rows = self._read(
            '''
            SELECT id, user_input, teacher_response
            FROM distillation_data
            WHERE used_for_training = 0
            ORDER BY timestamp DESC
            LIMIT ?
            ''',
            (limit,),
        )
        return [dict(row) for row in rows]

    def mark_distillation_used(self, ids: List[int]):
        """Mark distillation data as used for training.

        An empty ``ids`` is a no-op. Long lists are applied in batches
        inside a single transaction, so either every id is marked or none.
        """
        pending = list(ids)
        if not pending:
            return
        batch_size = self._MAX_SQL_PARAMS
        conn = self._get_conn()
        try:
            with conn:
                for start in range(0, len(pending), batch_size):
                    batch = pending[start:start + batch_size]
                    # Only "?" markers are interpolated into the SQL text;
                    # the id values themselves are always bound parameters.
                    placeholders = ",".join("?" * len(batch))
                    conn.execute(
                        f'''
                        UPDATE distillation_data
                        SET used_for_training = 1
                        WHERE id IN ({placeholders})
                        ''',
                        batch,
                    )
        finally:
            conn.close()

    def get_distillation_count(self) -> Dict:
        """Get count of distillation data.

        Returns:
            A dict with the keys ``total``, ``unused`` and ``used``.
        """
        # One statement yields all three numbers from the same snapshot.
        # SUM() is NULL on an empty table, hence COALESCE.
        row = self._read(
            '''
            SELECT COUNT(*),
                   COALESCE(SUM(used_for_training = 0), 0),
                   COALESCE(SUM(used_for_training = 1), 0)
            FROM distillation_data
            '''
        )[0]
        return {"total": row[0], "unused": row[1], "used": row[2]}

    # ===== Stats =====

    def get_stats(self) -> Dict:
        """Return conversation and distillation counters.

        Returns:
            A dict with the keys ``total``, ``positive``, ``negative``
            (conversation counts by feedback sign), ``distillation_total``
            and ``distillation_unused``.
        """
        row = self._read(
            '''
            SELECT COUNT(*),
                   COALESCE(SUM(feedback > 0), 0),
                   COALESCE(SUM(feedback < 0), 0)
            FROM conversations
            '''
        )[0]
        distill = self.get_distillation_count()
        return {
            "total": row[0],
            "positive": row[1],
            "negative": row[2],
            "distillation_total": distill["total"],
            "distillation_unused": distill["unused"],
        }

    def save_training_history(
        self,
        training_type: str,
        samples_used: int,
        epochs: int,
        final_loss: float
    ):
        """Append one training run to the ``training_history`` table."""
        self._write(
            '''
            INSERT INTO training_history (training_type, samples_used, epochs, final_loss)
            VALUES (?, ?, ?, ?)
            ''',
            (training_type, samples_used, epochs, final_loss),
        )
# Shared module-level handle. Constructing it runs the table-creation
# statements, so the schema exists as soon as this module is imported.
db = VedaDatabase()