# veda-programming / database.py
# Author: vedaco
# Last commit: "Update database.py" (0fe7d00, verified)
"""Database for conversations and distillation data"""
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import Dict, Iterable, List, Optional

from config import DATABASE_PATH
class VedaDatabase:
    """SQLite-backed store for conversations, distillation data and training history.

    Every public method opens its own short-lived connection through
    ``_connection()``, which commits on success, rolls back if a statement
    raises, and always closes the connection.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Ensure all tables exist in the target database.

        Args:
            db_path: Path of the SQLite database file. Defaults to
                ``DATABASE_PATH`` from ``config``; pass an explicit path to
                use a different database (e.g. in tests).
        """
        self.db_path = DATABASE_PATH if db_path is None else db_path
        self._init_db()

    def _get_conn(self) -> sqlite3.Connection:
        """Return a new connection whose rows support access by column name."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self):
        """Yield a connection inside a transaction and always close it afterwards.

        ``with conn`` commits when the body finishes normally and rolls back
        when it raises, but it does not close the connection -- hence the
        explicit ``finally``.
        """
        conn = self._get_conn()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Create the conversations, distillation_data and training_history tables if missing."""
        with self._connection() as conn:
            # Regular conversations table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    user_input TEXT NOT NULL,
                    assistant_response TEXT NOT NULL,
                    feedback INTEGER DEFAULT 0
                )
            ''')
            # Distillation data table (teacher responses)
            conn.execute('''
                CREATE TABLE IF NOT EXISTS distillation_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    user_input TEXT NOT NULL,
                    teacher_response TEXT NOT NULL,
                    student_response TEXT,
                    used_for_training BOOLEAN DEFAULT 0,
                    quality_score REAL DEFAULT 0
                )
            ''')
            # Training history
            conn.execute('''
                CREATE TABLE IF NOT EXISTS training_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    training_type TEXT,
                    samples_used INTEGER,
                    epochs INTEGER,
                    final_loss REAL
                )
            ''')

    # ===== Conversations =====

    def save_conversation(self, user_input: str, response: str) -> int:
        """Store one user/assistant exchange and return its row id."""
        with self._connection() as conn:
            cursor = conn.execute('''
                INSERT INTO conversations (user_input, assistant_response)
                VALUES (?, ?)
            ''', (user_input, response))
            return cursor.lastrowid

    def update_feedback(self, conv_id: int, feedback: int):
        """Set the feedback value of conversation ``conv_id``.

        ``get_good_conversations`` and ``get_stats`` treat values > 0 as
        positive and < 0 as negative. An unknown id updates nothing.
        """
        with self._connection() as conn:
            conn.execute(
                'UPDATE conversations SET feedback = ? WHERE id = ?',
                (feedback, conv_id),
            )

    def get_good_conversations(self, limit: int = 100) -> List[Dict]:
        """Return up to ``limit`` conversations with feedback > 0, newest first.

        Each dict has the keys ``user_input`` and ``assistant_response``.
        """
        with self._connection() as conn:
            # CURRENT_TIMESTAMP has one-second resolution; ``id`` breaks ties
            # so rows inserted in the same second still come back newest first.
            rows = conn.execute('''
                SELECT user_input, assistant_response
                FROM conversations
                WHERE feedback > 0
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
            ''', (limit,)).fetchall()
        return [dict(row) for row in rows]

    # ===== Distillation =====

    def save_distillation_data(
        self,
        user_input: str,
        teacher_response: str,
        student_response: Optional[str] = None,
        quality_score: float = 0.0,
    ) -> int:
        """Store a teacher response (plus optional student response) and return its row id."""
        with self._connection() as conn:
            cursor = conn.execute('''
                INSERT INTO distillation_data
                    (user_input, teacher_response, student_response, quality_score)
                VALUES (?, ?, ?, ?)
            ''', (user_input, teacher_response, student_response, quality_score))
            return cursor.lastrowid

    def get_unused_distillation_data(self, limit: int = 500) -> List[Dict]:
        """Return up to ``limit`` samples not yet used for training, newest first.

        Each dict has the keys ``id``, ``user_input`` and ``teacher_response``.
        """
        with self._connection() as conn:
            # ``id`` breaks timestamp ties (one-second resolution).
            rows = conn.execute('''
                SELECT id, user_input, teacher_response
                FROM distillation_data
                WHERE used_for_training = 0
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
            ''', (limit,)).fetchall()
        return [dict(row) for row in rows]

    def mark_distillation_used(self, ids: Iterable[int]):
        """Mark the given distillation rows as used for training.

        ``ids`` may be any iterable of row ids. One bound parameter per
        statement (``executemany``) avoids SQLite's limit on the number of
        variables in a single statement; an empty ``ids`` is a no-op that
        does not touch the database.
        """
        params = [(data_id,) for data_id in ids]
        if not params:
            return
        with self._connection() as conn:
            conn.executemany(
                'UPDATE distillation_data SET used_for_training = 1 WHERE id = ?',
                params,
            )

    @staticmethod
    def _distillation_counts(conn: sqlite3.Connection) -> Dict:
        """Count total / unused / used distillation rows using an open connection."""
        # SUM over zero rows is NULL, hence COALESCE to keep the counts integers.
        total, unused, used = conn.execute('''
            SELECT COUNT(*),
                   COALESCE(SUM(used_for_training = 0), 0),
                   COALESCE(SUM(used_for_training = 1), 0)
            FROM distillation_data
        ''').fetchone()
        return {"total": total, "unused": unused, "used": used}

    def get_distillation_count(self) -> Dict:
        """Return a dict with the ``total``, ``unused`` and ``used`` distillation row counts."""
        with self._connection() as conn:
            return self._distillation_counts(conn)

    # ===== Stats =====

    def get_stats(self) -> Dict:
        """Return conversation feedback counts plus distillation counts.

        All counts are read over a single connection.
        """
        with self._connection() as conn:
            total, positive, negative = conn.execute('''
                SELECT COUNT(*),
                       COALESCE(SUM(feedback > 0), 0),
                       COALESCE(SUM(feedback < 0), 0)
                FROM conversations
            ''').fetchone()
            distill = self._distillation_counts(conn)
        return {
            "total": total,
            "positive": positive,
            "negative": negative,
            "distillation_total": distill["total"],
            "distillation_unused": distill["unused"],
        }

    def save_training_history(
        self,
        training_type: str,
        samples_used: int,
        epochs: int,
        final_loss: float,
    ):
        """Append one row to the training_history table."""
        with self._connection() as conn:
            conn.execute('''
                INSERT INTO training_history (training_type, samples_used, epochs, final_loss)
                VALUES (?, ?, ?, ?)
            ''', (training_type, samples_used, epochs, final_loss))
# Module-level instance for importers; constructing it runs _init_db(),
# which creates any missing tables.
db: VedaDatabase = VedaDatabase()