|
|
import os |
|
|
import shutil |
|
|
import time |
|
|
import csv |
|
|
import uuid |
|
|
from itertools import cycle |
|
|
from typing import List, Tuple, Optional |
|
|
from datetime import datetime |
|
|
import gradio as gr |
|
|
|
|
|
from .data_fetcher import read_hacker_news_rss, format_published_time |
|
|
from .model_trainer import ( |
|
|
authenticate_hf, |
|
|
train_with_dataset, |
|
|
get_top_hits, |
|
|
load_embedding_model, |
|
|
upload_model_to_hub |
|
|
) |
|
|
from .config import AppConfig |
|
|
from .vibe_logic import VibeChecker |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
class HackerNewsFineTuner:
    """
    Encapsulates all application logic and state for a single user session.
    """

    def __init__(self, config: AppConfig = AppConfig):
        """
        Set up per-session state and artifact directories.

        Args:
            config: Application configuration. NOTE: the default passes the
                AppConfig *class* itself rather than an instance; all settings
                are read as attributes (e.g. ``config.ARTIFACTS_DIR``), so
                both a class and an instance work here.
        """
        self.config = config

        # A fresh UUID isolates this session's artifacts from other users.
        self.session_id = str(uuid.uuid4())

        self.session_root = self.config.ARTIFACTS_DIR / self.session_id
        self.output_dir = self.session_root / "embedding_gemma_finetuned"
        self.dataset_export_file = self.session_root / "training_dataset.csv"

        # pathlib idiom for os.makedirs(..., exist_ok=True); parents=True also
        # creates the intermediate session_root directory.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")

        # Lazily populated session state: the model and vibe checker are set
        # by refresh_data_and_model(); datasets accumulate via the UI.
        self.model: Optional[SentenceTransformer] = None
        self.vibe_checker: Optional[VibeChecker] = None
        self.titles: List[str] = []
        self.last_hn_dataset: List[List[str]] = []
        self.imported_dataset: List[List[str]] = []

        authenticate_hf(self.config.HF_TOKEN)
|
|
|
|
|
def _update_vibe_checker(self): |
|
|
"""Initializes or updates the VibeChecker with the current model state.""" |
|
|
if self.model: |
|
|
self.vibe_checker = VibeChecker( |
|
|
model=self.model, |
|
|
query_anchor=self.config.QUERY_ANCHOR, |
|
|
task_name=self.config.TASK_NAME |
|
|
) |
|
|
else: |
|
|
self.vibe_checker = None |
|
|
|
|
|
|
|
|
|
|
|
def refresh_data_and_model(self) -> Tuple[List[str], str]: |
|
|
""" |
|
|
Reloads model and fetches data. |
|
|
Returns: |
|
|
- List of titles (for the UI) |
|
|
- Status message string |
|
|
""" |
|
|
print(f"[{self.session_id}] Reloading model and data...") |
|
|
|
|
|
self.last_hn_dataset = [] |
|
|
self.imported_dataset = [] |
|
|
|
|
|
|
|
|
try: |
|
|
self.model = load_embedding_model(self.config.MODEL_NAME) |
|
|
self._update_vibe_checker() |
|
|
except Exception as e: |
|
|
error_msg = f"CRITICAL ERROR: Model failed to load. {e}" |
|
|
print(error_msg) |
|
|
self.model = None |
|
|
self._update_vibe_checker() |
|
|
return [], error_msg |
|
|
|
|
|
|
|
|
news_feed, status_msg = read_hacker_news_rss(self.config) |
|
|
titles_out = [] |
|
|
status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}" |
|
|
|
|
|
if news_feed is not None and news_feed.entries: |
|
|
titles_out = [item.title for item in news_feed.entries] |
|
|
else: |
|
|
titles_out = ["Error fetching news."] |
|
|
gr.Warning(f"Data reload failed. {status_msg}") |
|
|
|
|
|
self.titles = titles_out |
|
|
|
|
|
|
|
|
return self.titles, status_value |
|
|
|
|
|
|
|
|
    def import_additional_dataset(self, file_path: str) -> str:
        """
        Load a triplet dataset (Anchor, Positive, Negative) from an uploaded CSV.

        Only rows with exactly 3 columns are kept; other rows are silently
        skipped. A header row whose first cell is 'anchor' (case-insensitive)
        is tolerated and consumed. On success the data is stored in
        self.imported_dataset, which takes precedence during training().

        Returns:
            A human-readable status message (success or failure).
        """
        if not file_path:
            return "Please upload a CSV file."
        new_dataset, num_imported = [], 0
        try:
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                try:
                    # Sniff the first row to decide whether it is a header.
                    header = next(reader)

                    # Not a header: rewind so the row-reading loop below
                    # sees the first line as data again.
                    if not (header and header[0].lower().strip() == 'anchor'):
                        f.seek(0)
                except StopIteration:
                    return "Error: Uploaded file is empty."

                for row in reader:
                    if len(row) == 3:
                        new_dataset.append([s.strip() for s in row])
                        num_imported += 1
                if num_imported == 0:
                    # Caught by the except below and reported as a failure.
                    raise ValueError("No valid rows found.")
                self.imported_dataset = new_dataset
                return f"Imported {num_imported} triplets."
        except Exception as e:
            return f"Import failed: {e}"
|
|
|
|
|
def export_dataset(self) -> Optional[str]: |
|
|
if not self.last_hn_dataset: |
|
|
gr.Warning("No dataset generated yet.") |
|
|
return None |
|
|
|
|
|
file_path = self.dataset_export_file |
|
|
try: |
|
|
with open(file_path, 'w', newline='', encoding='utf-8') as f: |
|
|
writer = csv.writer(f) |
|
|
writer.writerow(['Anchor', 'Positive', 'Negative']) |
|
|
writer.writerows(self.last_hn_dataset) |
|
|
gr.Info(f"Dataset exported.") |
|
|
return str(file_path) |
|
|
except Exception as e: |
|
|
gr.Error(f"Export failed: {e}") |
|
|
return None |
|
|
|
|
|
def download_model(self) -> Optional[str]: |
|
|
if not os.path.exists(self.output_dir): |
|
|
gr.Warning("No model trained yet.") |
|
|
return None |
|
|
|
|
|
timestamp = int(time.time()) |
|
|
try: |
|
|
base_name = self.session_root / f"model_finetuned_{timestamp}" |
|
|
archive_path = shutil.make_archive( |
|
|
base_name=str(base_name), |
|
|
format='zip', |
|
|
root_dir=self.output_dir, |
|
|
) |
|
|
gr.Info(f"Model zipped.") |
|
|
return archive_path |
|
|
except Exception as e: |
|
|
gr.Error(f"Zip failed: {e}") |
|
|
return None |
|
|
|
|
|
def upload_model(self, repo_name: str, oauth_token_str: str) -> str: |
|
|
""" |
|
|
Calls the model trainer upload function using the session's output directory. |
|
|
""" |
|
|
if not os.path.exists(self.output_dir): |
|
|
return "❌ Error: No trained model found in this session. Run training first." |
|
|
if not repo_name.strip(): |
|
|
return "❌ Error: Please specify a repository name." |
|
|
|
|
|
return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str) |
|
|
|
|
|
|
|
|
|
|
|
def _create_hn_dataset(self, pos_ids: List[int], neg_ids: List[int]) -> List[List[str]]: |
|
|
""" |
|
|
Creates triplets (Anchor, Positive, Negative) from the selected indices. |
|
|
Uses cycling to balance the dataset if the number of positives != negatives. |
|
|
""" |
|
|
if not pos_ids or not neg_ids: |
|
|
return [] |
|
|
|
|
|
|
|
|
pos_titles = [self.titles[i] for i in pos_ids] |
|
|
neg_titles = [self.titles[i] for i in neg_ids] |
|
|
|
|
|
dataset = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(pos_titles) >= len(neg_titles): |
|
|
|
|
|
neg_cycle = cycle(neg_titles) |
|
|
for p_title in pos_titles: |
|
|
dataset.append([self.config.QUERY_ANCHOR, p_title, next(neg_cycle)]) |
|
|
else: |
|
|
|
|
|
pos_cycle = cycle(pos_titles) |
|
|
for n_title in neg_titles: |
|
|
dataset.append([self.config.QUERY_ANCHOR, next(pos_cycle), n_title]) |
|
|
|
|
|
return dataset |
|
|
|
|
|
def training(self, pos_ids: List[int], neg_ids: List[int]) -> str: |
|
|
""" |
|
|
Main training entry point. |
|
|
Args: |
|
|
pos_ids: Indices of stories marked as "Favorite" |
|
|
neg_ids: Indices of stories marked as "Dislike" |
|
|
""" |
|
|
if self.model is None: |
|
|
raise gr.Error("Model not loaded.") |
|
|
|
|
|
if self.imported_dataset: |
|
|
self.last_hn_dataset = self.imported_dataset |
|
|
else: |
|
|
|
|
|
if not pos_ids: |
|
|
raise gr.Error("Please select at least one 'Favorite' story.") |
|
|
if not neg_ids: |
|
|
raise gr.Error("Please select at least one 'Dislike' story.") |
|
|
|
|
|
|
|
|
self.last_hn_dataset = self._create_hn_dataset(pos_ids, neg_ids) |
|
|
|
|
|
if not self.last_hn_dataset: |
|
|
raise gr.Error("Dataset generation failed (Empty dataset).") |
|
|
|
|
|
def semantic_search_fn() -> str: |
|
|
return get_top_hits(model=self.model, target_titles=self.titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR) |
|
|
|
|
|
result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n" |
|
|
print(f"[{self.session_id}] Starting Training with {len(self.last_hn_dataset)} examples...") |
|
|
|
|
|
train_with_dataset( |
|
|
model=self.model, |
|
|
dataset=self.last_hn_dataset, |
|
|
output_dir=self.output_dir, |
|
|
task_name=self.config.TASK_NAME, |
|
|
search_fn=semantic_search_fn |
|
|
) |
|
|
|
|
|
self._update_vibe_checker() |
|
|
print(f"[{self.session_id}] Training Complete.") |
|
|
|
|
|
result += "### Search (After):\n" + f"{semantic_search_fn()}" |
|
|
return result |
|
|
|
|
|
def is_model_tuned(self) -> bool: |
|
|
return True if self.last_hn_dataset else False |
|
|
|
|
|
|
|
|
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]: |
|
|
model_name = "<unsaved>" |
|
|
if self.last_hn_dataset: |
|
|
model_name = f"./{self.output_dir}" |
|
|
|
|
|
info_text = (f"**Session:** {self.session_id[:6]}<br>" |
|
|
f"**Base Model:** `{self.config.MODEL_NAME}`<br>" |
|
|
f"**Tuned Model:** `{model_name}`") |
|
|
|
|
|
if not self.vibe_checker: |
|
|
return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_css("gray")), info_text |
|
|
if not news_text or len(news_text.split()) < 3: |
|
|
return "N/A", "Text too short", gr.update(value=self._generate_vibe_css("gray")), info_text |
|
|
|
|
|
try: |
|
|
vibe_result = self.vibe_checker.check(news_text) |
|
|
status = vibe_result.status_html.split('>')[1].split('<')[0] |
|
|
return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_css(vibe_result.color_hsl)), info_text |
|
|
except Exception as e: |
|
|
return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_css("gray")), info_text |
|
|
|
|
|
def _generate_vibe_css(self, color: str) -> str: |
|
|
"""Generates a style block to update the Mood Lamp textbox background.""" |
|
|
return f"<style>#mood_lamp input {{ background-color: {color} !important; transition: background-color 0.5s ease; }}</style>" |
|
|
|
|
|
|
|
|
def fetch_and_display_mood_feed(self) -> str: |
|
|
if not self.vibe_checker: |
|
|
return "Model not ready. Please wait or reload." |
|
|
|
|
|
feed, status = read_hacker_news_rss(self.config) |
|
|
if not feed or not feed.entries: |
|
|
return f"**Feed Error:** {status}" |
|
|
|
|
|
scored_entries = [] |
|
|
for entry in feed.entries: |
|
|
title = entry.get('title') |
|
|
if not title: continue |
|
|
|
|
|
vibe_result = self.vibe_checker.check(title) |
|
|
scored_entries.append({ |
|
|
"title": title, |
|
|
"link": entry.get('link', '#'), |
|
|
"comments": entry.get('comments', '#'), |
|
|
"published": format_published_time(entry.published_parsed), |
|
|
"mood": vibe_result |
|
|
}) |
|
|
|
|
|
scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True) |
|
|
|
|
|
model_name = "<unsaved>" |
|
|
if self.last_hn_dataset: |
|
|
model_name = f"./{self.output_dir}" |
|
|
|
|
|
md = (f"## Hacker News Top Stories\n" |
|
|
f"**Session:** {self.session_id[:6]}<br>" |
|
|
f"**Base Model:** `{self.config.MODEL_NAME}`<br>" |
|
|
f"**Tuned Model:** `{model_name}`<br>" |
|
|
f"**Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" |
|
|
"| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n") |
|
|
|
|
|
for item in scored_entries: |
|
|
md += (f"| {item['mood'].status_html} " |
|
|
f"| {item['mood'].raw_score:.4f} " |
|
|
f"| [{item['title']}]({item['link']}) " |
|
|
f"| [Comments]({item['comments']}) " |
|
|
f"| {item['published']} |\n") |
|
|
return md |
|
|
|