Spaces:
Sleeping
Sleeping
| """Knowledge Distillation Engine - Student learns from Teacher""" | |
| from typing import Optional, Tuple, Dict | |
| from teacher import teacher | |
| from database import db | |
| from config import ( | |
| DISTILLATION_ENABLED, | |
| AUTO_LEARN_FROM_TEACHER, | |
| MIN_RESPONSE_LENGTH, | |
| MIN_SAMPLES_FOR_DISTILL_TRAINING, | |
| ) | |
class DistillationEngine:
    """
    Manages knowledge distillation from teacher to student.

    Modes:
        1. auto: Always ask teacher, save responses, use student for speed
        2. fallback: Only ask teacher if student response is poor
        3. compare: Show both responses for comparison
        4. disabled: Never consult the teacher; student responses pass through
    """

    # Modes accepted by set_mode(). "disabled" falls through every branch of
    # process_with_distillation(), so the student response is always returned.
    VALID_MODES = ("auto", "fallback", "compare", "disabled")

    def __init__(self):
        # `teacher` is the module-level singleton imported from teacher.py.
        self.teacher = teacher
        self.mode = "fallback"  # one of VALID_MODES
        self.teacher_call_count = 0
        self.student_call_count = 0

    def should_ask_teacher(self, student_response: str) -> bool:
        """Decide if we should ask the teacher based on student response quality.

        Returns True only when distillation is enabled, the teacher is
        reachable, and the student's response looks low quality (empty, too
        short, or containing a canned uncertainty/error phrase).
        """
        if not DISTILLATION_ENABLED:
            return False
        if not self.teacher.is_available():
            return False
        # Heuristics for low-quality response
        if not student_response:
            return True
        if len(student_response.strip()) < MIN_RESPONSE_LENGTH:
            return True
        # Check for canned error / uncertainty phrases (case-insensitive).
        low_quality_indicators = [
            "I'm not sure",
            "I don't know",
            "Could you try rephrasing",
            "Error:",
            "not sure how to respond",
        ]
        lowered = student_response.lower()
        return any(indicator.lower() in lowered for indicator in low_quality_indicators)

    def get_teacher_response(
        self,
        user_input: str,
        conversation_history: Optional[list] = None,
        student_response: Optional[str] = None,
    ) -> Optional[str]:
        """Get response from teacher and optionally save it for training.

        Args:
            user_input: The user's message.
            conversation_history: Prior turns forwarded to the teacher.
            student_response: The student's answer, stored alongside the
                teacher's for later comparison/training.

        Returns:
            The teacher's reply, or None if the teacher produced nothing.
        """
        teacher_response = self.teacher.ask(
            user_message=user_input,
            conversation_history=conversation_history,
        )
        if teacher_response:
            self.teacher_call_count += 1
            if AUTO_LEARN_FROM_TEACHER:
                # Save for future training
                db.save_distillation_data(
                    user_input=user_input,
                    teacher_response=teacher_response,
                    student_response=student_response,
                    quality_score=1.0,  # Teacher responses are high quality
                )
        return teacher_response

    def process_with_distillation(
        self,
        user_input: str,
        student_response: str,
        conversation_history: Optional[list] = None,
    ) -> Tuple[str, str]:
        """
        Process a response with potential teacher assistance.

        Returns:
            Tuple of (final_response, source) where source is "student",
            "teacher", or "both".
        """
        self.student_call_count += 1
        if self.mode == "auto":
            # Always query the teacher for its learning side effect (the
            # response is persisted by get_teacher_response), but return the
            # student's answer for speed.
            self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            return student_response, "student"
        elif self.mode == "fallback":
            # Only ask teacher if student response is poor
            if self.should_ask_teacher(student_response):
                teacher_resp = self.get_teacher_response(
                    user_input, conversation_history, student_response
                )
                if teacher_resp:
                    return teacher_resp, "teacher"
            return student_response, "student"
        elif self.mode == "compare":
            # Return both for comparison (useful for debugging/evaluation)
            teacher_resp = self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            if teacher_resp:
                combined = f"**🎓 Teacher (Dolphin):**\n{teacher_resp}\n\n---\n\n**🧠 Student (Veda):**\n{student_response}"
                return combined, "both"
            return student_response, "student"
        # "disabled" (or any unknown mode) falls through to the student.
        return student_response, "student"

    def set_mode(self, mode: str) -> bool:
        """Set distillation mode: 'auto', 'fallback', 'compare', or 'disabled'.

        Returns:
            True if the mode was recognized and applied, False otherwise.
        """
        if mode in self.VALID_MODES:
            self.mode = mode
            return True
        return False

    def get_stats(self) -> Dict:
        """Get distillation statistics as a plain dict (mode, call counts,
        teacher availability, and training-sample readiness)."""
        distill_data = db.get_distillation_count()
        return {
            "mode": self.mode,
            "teacher_calls": self.teacher_call_count,
            "student_calls": self.student_call_count,
            "teacher_available": self.teacher.is_available(),
            "distillation_samples": distill_data["total"],
            "unused_samples": distill_data["unused"],
            "ready_for_training": distill_data["unused"] >= MIN_SAMPLES_FOR_DISTILL_TRAINING,
        }

    def get_training_data(self) -> str:
        """Get accumulated teacher responses formatted as training text.

        Returns:
            One "<USER> ...\\n<ASSISTANT> ...\\n\\n" record per unused sample,
            or "" when there is nothing pending.
        """
        unused = db.get_unused_distillation_data()
        if not unused:
            return ""
        # join avoids quadratic string concatenation on large sample sets.
        return "".join(
            f"<USER> {item['user_input']}\n<ASSISTANT> {item['teacher_response']}\n\n"
            for item in unused
        )

    def mark_training_complete(self, ids: list) -> None:
        """Mark distillation data rows as used after a training run."""
        if ids:
            db.mark_distillation_used(ids)
# Global engine instance — module-level singleton shared by importers.
distillation_engine = DistillationEngine()