import os | |
# Set memory optimization environment variables BEFORE importing torch | |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" | |
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid tokenizer warning | |
import torch | |
import pandas as pd | |
import glob | |
import threading | |
import logging | |
import gradio as gr | |
from datasets import Dataset | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import PeftModel, prepare_model_for_kbit_training, LoraConfig, get_peft_model | |
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling | |
import warnings | |
import subprocess | |
import gc | |
import psutil | |
warnings.filterwarnings("ignore") | |
# ===== CONFIG ===== | |
AVAILABLE_MODELS = { | |
"microsoft/Phi-3-mini-4k-instruct": { | |
"name": "Phi-3 Mini 4K", | |
"target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"] | |
}, | |
"openai/gpt-oss-20b": { | |
"name": "OpenAI GPT-OSS 20B", | |
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"] # MoE architecture, may need adjustment | |
} | |
} | |
# Default model | |
current_model_id = "openai/gpt-oss-20b" | |
MODEL_ID = current_model_id | |
OUTPUT_DIR = f"./{current_model_id.split('/')[-1]}-relevance-checkpoints"  # kept in sync with the default model; switch_model() rebuilds it on model change
HF_TOKEN = os.environ.get("HF_TOKEN", None) | |
HF_USERNAME = os.environ.get("HF_USERNAME", "") | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
def kill_gpu_processes(): | |
"""Kill other processes using GPU to free memory""" | |
if not torch.cuda.is_available(): | |
return | |
current_pid = os.getpid() | |
killed = [] | |
try: | |
# Get GPU processes using nvidia-smi | |
result = subprocess.run(['nvidia-smi', '--query-compute-apps=pid', '--format=csv,noheader'], | |
capture_output=True, text=True) | |
if result.returncode == 0: | |
pids = [int(pid.strip()) for pid in result.stdout.strip().split('\n') if pid.strip()] | |
for pid in pids: | |
if pid == current_pid: | |
continue | |
try: | |
process = psutil.Process(pid) | |
process_name = process.name() | |
process.terminate() | |
killed.append(f"PID {pid} ({process_name})") | |
logger.info(f"Killed GPU process: {pid} ({process_name})") | |
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    # Process already exited or cannot be terminated by this user
                    pass
except Exception as e: | |
logger.warning(f"Could not kill GPU processes: {e}") | |
return killed | |
def clear_gpu_memory(): | |
"""Aggressively clear GPU memory""" | |
if torch.cuda.is_available(): | |
# Kill other GPU processes first | |
kill_gpu_processes() | |
# Clear PyTorch cache | |
torch.cuda.empty_cache() | |
torch.cuda.synchronize() | |
# Force garbage collection | |
gc.collect() | |
        # Best-effort sweep over live objects; note that `del obj` only drops the loop's own
        # reference, so the real cleanup comes from gc.collect() + torch.cuda.empty_cache() below
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) and obj.is_cuda:
                    del obj
            except Exception:
                pass
gc.collect() | |
torch.cuda.empty_cache() | |
logger.info("GPU memory cleared") | |
# Globals | |
current_model = None | |
current_tokenizer = None | |
train_df = None | |
test_df = None | |
training_status = {"status": "idle", "progress": 0, "logs": [], "model_repo": None} | |
trained_models = [] # Keep track of all trained models | |
training_lock = threading.Lock() | |
def format_prompt(query, title, content): | |
"""Format the prompt for the model""" | |
# Convert all inputs to strings and handle NaN/None values | |
query = str(query) if query is not None and str(query) != 'nan' else "" | |
title = str(title) if title is not None and str(title) != 'nan' else "" | |
content = str(content) if content is not None and str(content) != 'nan' else "" | |
# Truncate content if too long | |
if len(content) > 1000: | |
content = content[:1000] + "..." | |
return f"""You would get a query and document's title and content and return yes (if the document is relevant to the query) or no (if the document is not relevant to the query). | |
Answer only yes / no. | |
Document: | |
####DOCUMENT START | |
title: {title} | |
content: {content} | |
####DOCUMENT END | |
Query: | |
####Query START | |
{query} | |
####Query END | |
ANSWER: """ | |
def load_model_and_tokenizer(checkpoint_path=None, model_id=None): | |
"""Load model and tokenizer""" | |
global current_model, current_tokenizer, current_model_id | |
if model_id is None: | |
model_id = current_model_id | |
# Clear existing model from memory first | |
if current_model is not None: | |
del current_model | |
current_model = None | |
# Aggressively clear GPU memory before loading new model | |
clear_gpu_memory() | |
logger.info(f"Loading model from {checkpoint_path or model_id} ...") | |
# Load tokenizer | |
current_tokenizer = AutoTokenizer.from_pretrained(model_id) | |
if current_tokenizer.pad_token is None: | |
current_tokenizer.pad_token = current_tokenizer.eos_token | |
current_tokenizer.padding_side = "left" | |
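    # Left padding keeps each sequence's final position aligned with the end of its prompt,
    # which the batched inference below relies on when it reads logits[:, -1, :]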
# Load model with appropriate settings | |
model_kwargs = { | |
"trust_remote_code": True | |
} | |
# Model loading configuration | |
if model_id == "openai/gpt-oss-20b": | |
# GPT-OSS-20B already uses MXFP4 quantization natively | |
# Don't apply additional quantization | |
model_kwargs["torch_dtype"] = torch.bfloat16 | |
model_kwargs["device_map"] = "auto" | |
model_kwargs["low_cpu_mem_usage"] = True | |
logger.info("Loading GPT-OSS-20B with native MXFP4 quantization") | |
else: | |
# For other models, use 4-bit quantization | |
model_kwargs["load_in_4bit"] = True | |
model_kwargs["torch_dtype"] = torch.float16 | |
model_kwargs["device_map"] = "auto" | |
try: | |
if checkpoint_path and os.path.exists(checkpoint_path): | |
# Load base model | |
base_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) | |
# Load PEFT adapter | |
current_model = PeftModel.from_pretrained(base_model, checkpoint_path) | |
logger.info("Loaded PEFT model from checkpoint") | |
else: | |
# Load base model | |
current_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) | |
logger.info("Loaded base model") | |
except RuntimeError as e: | |
if "meta tensor" in str(e): | |
logger.warning("Meta tensor error detected, trying alternative loading approach...") | |
# Try loading without device map first, then move to GPU | |
model_kwargs.pop("device_map", None) | |
model_kwargs.pop("max_memory", None) | |
if checkpoint_path and os.path.exists(checkpoint_path): | |
base_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) | |
base_model = base_model.to("cuda") | |
current_model = PeftModel.from_pretrained(base_model, checkpoint_path) | |
else: | |
current_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) | |
current_model = current_model.to("cuda") | |
logger.info("Loaded model with alternative approach") | |
else: | |
raise | |
current_model.eval() | |
return current_model, current_tokenizer | |
def get_gpu_memory_status(): | |
"""Get current GPU memory usage""" | |
if not torch.cuda.is_available(): | |
return "No GPU available" | |
allocated = torch.cuda.memory_allocated() / 1024**3 | |
reserved = torch.cuda.memory_reserved() / 1024**3 | |
total = torch.cuda.get_device_properties(0).total_memory / 1024**3 | |
return f"GPU Memory: {allocated:.1f}GB allocated / {reserved:.1f}GB reserved / {total:.1f}GB total" | |
def get_trained_models_list(): | |
"""Get formatted list of all trained models""" | |
if not trained_models: | |
return "No models trained yet in this session.\n\nPreviously trained models on HuggingFace:\n- amos1088/phi3-dpo-relevance" | |
text = "## Trained Models in This Session:\n\n" | |
for i, model in enumerate(trained_models, 1): | |
text += f"{i}. **{model['repo']}**\n" | |
text += f" - Accuracy: {model['accuracy']:.2%}\n" | |
text += f" - Predictions: Yes {model['yes_ratio']:.1%}, No {model['no_ratio']:.1%}\n" | |
text += f" - LR: {model.get('lr', 'N/A')}, Model: {model['model_id'].split('/')[-1]}\n" | |
text += f" - Link: https://huggingface.co/{model['repo']}\n\n" | |
return text | |
def switch_model(model_id): | |
"""Switch to a different model""" | |
global current_model, current_tokenizer, current_model_id, OUTPUT_DIR | |
# Clear current model | |
if current_model is not None: | |
del current_model | |
current_model = None | |
if current_tokenizer is not None: | |
del current_tokenizer | |
current_tokenizer = None | |
# Update model ID and output directory | |
current_model_id = model_id | |
model_name = model_id.split("/")[-1] | |
OUTPUT_DIR = f"./{model_name}-relevance-checkpoints" | |
# Clear CUDA cache | |
clear_gpu_memory() | |
logger.info(f"Switched to model: {model_id}") | |
return f"Model switched to: {AVAILABLE_MODELS[model_id]['name']}\n{get_gpu_memory_status()}" | |
def collate_fn(batch): | |
"""Collate function for DataLoader""" | |
prompts = [item["prompt"] for item in batch] | |
labels = [item["label"] for item in batch] # "yes" or "no" | |
# Create full texts (prompt + label) | |
full_texts = [prompt + label for prompt, label in zip(prompts, labels)] | |
# Tokenize full sequences | |
model_inputs = current_tokenizer( | |
full_texts, | |
return_tensors="pt", | |
padding=True, | |
truncation=True, | |
max_length=512 | |
) | |
# Create labels by masking the prompt portion (-100 means ignore in loss) | |
labels = [] | |
for i, prompt in enumerate(prompts): | |
# Tokenize prompt alone to find where answer starts | |
prompt_ids = current_tokenizer( | |
prompt, | |
add_special_tokens=True, | |
truncation=True, | |
max_length=512 | |
).input_ids | |
# Create label sequence | |
label_ids = model_inputs.input_ids[i].clone() | |
# Mask prompt tokens (set to -100) | |
label_ids[:len(prompt_ids)] = -100 | |
labels.append(label_ids) | |
labels = torch.stack(labels) | |
return { | |
"input_ids": model_inputs.input_ids, | |
"attention_mask": model_inputs.attention_mask, | |
"labels": labels | |
} | |
def prepare_finetuning_dataset(df): | |
"""Convert 4-category labels to standard fine-tuning format""" | |
ft_data = [] | |
# Map 4 categories to yes/no | |
label_mapping = { | |
'easy_positive': 'yes', | |
'hard_positive': 'yes', | |
'easy_negative': 'no', | |
'hard_negative': 'no', | |
'yes': 'yes', | |
'no': 'no' | |
} | |
for _, row in df.iterrows(): | |
# Handle both old and new column names | |
if 'query_text' in row: | |
query = str(row['query_text']) if pd.notna(row['query_text']) else '' | |
title = str(row['title']) if pd.notna(row['title']) else '' | |
content = str(row['text']) if pd.notna(row['text']) else '' | |
else: | |
query = str(row.get('query', '')) if pd.notna(row.get('query', '')) else '' | |
title = str(row.get('title', '')) if pd.notna(row.get('title', '')) else '' | |
content = str(row.get('content', '')) if pd.notna(row.get('content', '')) else '' | |
# Create prompt if not exists | |
if 'prompt' in row: | |
prompt = row['prompt'] | |
else: | |
prompt = format_prompt(query, title, content) | |
# Get mapped label | |
original_label = row['label'] | |
mapped_label = label_mapping.get(original_label, original_label) | |
# Create the full text with prompt and answer | |
text = prompt + mapped_label | |
ft_data.append({ | |
'text': text, | |
'prompt': prompt, | |
'label': mapped_label, | |
'original_label': original_label # Keep original for analysis | |
}) | |
return pd.DataFrame(ft_data) | |
def train_model(train_df, val_df, epochs=5, batch_size=32, lr=5e-6, max_samples=None): | |
"""Standard fine-tuning for document relevance classification""" | |
global current_model, current_tokenizer | |
# Clear GPU memory before training | |
logger.info("Clearing GPU memory before training...") | |
clear_gpu_memory() | |
# Load model and tokenizer if not already loaded | |
if current_model is None or current_tokenizer is None: | |
load_model_and_tokenizer() | |
# Limit training samples if specified (for memory constraints) | |
if max_samples and len(train_df) > max_samples: | |
logger.info(f"Limiting training data from {len(train_df)} to {max_samples} samples") | |
train_df = train_df.sample(n=max_samples, random_state=42) | |
val_df = val_df.head(min(len(val_df), max_samples // 5)) # Proportional validation set | |
# Convert to fine-tuning format | |
logger.info("Preparing fine-tuning dataset...") | |
ft_train_df = prepare_finetuning_dataset(train_df) | |
ft_val_df = prepare_finetuning_dataset(val_df) | |
# Create datasets | |
train_dataset = Dataset.from_pandas(ft_train_df) | |
val_dataset = Dataset.from_pandas(ft_val_df) | |
# Prepare model for training | |
if hasattr(current_model, 'is_loaded_in_4bit') and current_model.is_loaded_in_4bit: | |
# For 4-bit models | |
current_model = prepare_model_for_kbit_training(current_model) | |
else: | |
# For full precision models (including GPT-OSS with native quantization) | |
# Enable gradient checkpointing to save memory | |
if hasattr(current_model, 'gradient_checkpointing_enable'): | |
current_model.gradient_checkpointing_enable() | |
# For GPT-OSS, also enable some memory optimizations | |
if current_model_id == "openai/gpt-oss-20b" and hasattr(current_model.config, 'use_cache'): | |
current_model.config.use_cache = False | |
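            # Caching is incompatible with gradient checkpointing, so turning it off here avoids
            # transformers having to disable it (with a warning) during training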
# Configure LoRA - can use larger rank on A100 | |
target_modules = AVAILABLE_MODELS[current_model_id]["target_modules"] | |
if current_model_id == "openai/gpt-oss-20b": | |
# Conservative for large model | |
lora_r = 8 | |
lora_alpha = 16 | |
else: | |
# Can use larger rank for smaller models | |
lora_r = 16 | |
lora_alpha = 32 | |
peft_config = LoraConfig( | |
r=lora_r, | |
lora_alpha=lora_alpha, | |
lora_dropout=0.05, | |
bias="none", | |
task_type="CAUSAL_LM", | |
target_modules=target_modules | |
) | |
logger.info(f"Starting fine-tuning with {len(train_df)} train samples, {len(val_df)} val samples") | |
logger.info(f"Learning rate: {lr}, Effective batch size: {batch_size}, Epochs: {epochs}") | |
# Create output directory | |
os.makedirs(OUTPUT_DIR, exist_ok=True) | |
# Training configuration optimized for standard fine-tuning | |
target_batch_size = batch_size # Target effective batch size | |
if current_model_id == "openai/gpt-oss-20b": | |
# For GPT-OSS-20B: use smaller per-device batch with gradient accumulation | |
actual_batch_size = 2 # Per-device batch size | |
seq_length = 512 # Standard sequence length | |
grad_accum = target_batch_size // actual_batch_size # 16 gradient accumulation steps | |
else: | |
# For smaller models like Phi-3 - can use larger per-device batch | |
actual_batch_size = min(16, target_batch_size) # Up to 16 per device | |
grad_accum = max(1, target_batch_size // actual_batch_size) # Accumulate if needed | |
seq_length = 512 | |
# Tokenize the datasets | |
def tokenize_function(examples): | |
# Tokenize the full texts (prompt + answer) | |
model_inputs = current_tokenizer( | |
examples['text'], | |
truncation=True, | |
padding="max_length", | |
max_length=seq_length, | |
return_tensors=None | |
) | |
# For causal LM, labels are the same as input_ids | |
model_inputs["labels"] = model_inputs["input_ids"].copy() | |
# Store metadata for evaluation | |
model_inputs["original_labels"] = examples['original_label'] | |
model_inputs["mapped_labels"] = examples['label'] | |
return model_inputs | |
# Tokenize datasets | |
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names) | |
tokenized_val = val_dataset.map(tokenize_function, batched=True, remove_columns=val_dataset.column_names) | |
# Standard training arguments | |
training_args = TrainingArguments( | |
output_dir=OUTPUT_DIR, | |
num_train_epochs=epochs, | |
per_device_train_batch_size=actual_batch_size, | |
per_device_eval_batch_size=actual_batch_size, | |
gradient_accumulation_steps=grad_accum, | |
gradient_checkpointing=True, | |
learning_rate=lr, | |
lr_scheduler_type="cosine", | |
        warmup_steps=500,  # NOTE: with small training sets (e.g. the 2000-sample cap for GPT-OSS) this can exceed the total number of optimizer steps
logging_steps=10, | |
save_strategy="epoch", | |
eval_strategy="epoch", | |
bf16=True, | |
fp16=False, | |
weight_decay=0.01, | |
optim="adamw_torch", | |
save_total_limit=3, | |
load_best_model_at_end=True, | |
metric_for_best_model="eval_loss", | |
greater_is_better=False, | |
report_to=[], | |
run_name="standard-ft-relevance", | |
dataloader_num_workers=2, | |
) | |
# Create data collator | |
data_collator = DataCollatorForLanguageModeling( | |
tokenizer=current_tokenizer, | |
mlm=False, # Causal LM, not masked LM | |
pad_to_multiple_of=8 | |
) | |
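    # With mlm=False this collator rebuilds labels from input_ids and sets pad-token positions to
    # -100, so the labels produced in tokenize_function are effectively recomputed at batch time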
# Apply LoRA to the model | |
current_model = get_peft_model(current_model, peft_config) | |
current_model.print_trainable_parameters() | |
# Create standard trainer | |
trainer = Trainer( | |
model=current_model, | |
args=training_args, | |
train_dataset=tokenized_train, | |
eval_dataset=tokenized_val, | |
data_collator=data_collator, | |
tokenizer=current_tokenizer, | |
) | |
    # Custom logging callback (currently not registered with the Trainer; StatusCallback below handles log updates)
def log_callback(logs): | |
if "loss" in logs: | |
step = logs.get("epoch", 0) | |
loss = logs["loss"] | |
with training_lock: | |
training_status["logs"].append(f"Step {step}, Loss: {loss:.4f}") | |
if "eval_loss" in logs: | |
training_status["logs"].append(f"Eval Loss: {logs['eval_loss']:.4f}") | |
# Custom callback for status updates | |
from transformers import TrainerCallback | |
def compute_accuracy_metrics(trainer, eval_dataset, ft_val_df, num_samples=100): | |
"""Compute accuracy metrics and confusion matrix on a subset of eval data""" | |
logger.info("Computing accuracy metrics...") | |
try: | |
# Get the original dataframe for easier access to prompts and labels | |
eval_df = ft_val_df | |
# Sample subset for faster evaluation | |
sample_size = min(num_samples, len(eval_df)) | |
sample_df = eval_df.sample(n=sample_size, random_state=42) | |
# Initialize confusion matrix counters | |
confusion_matrix = { | |
'easy_positive': {'yes': 0, 'no': 0}, | |
'hard_positive': {'yes': 0, 'no': 0}, | |
'easy_negative': {'yes': 0, 'no': 0}, | |
'hard_negative': {'yes': 0, 'no': 0} | |
} | |
predictions_yes = 0 | |
predictions_no = 0 | |
correct = 0 | |
for idx, row in sample_df.iterrows(): | |
prompt = row['prompt'] | |
true_label = row['label'] # This is the mapped label (yes/no) | |
original_label = row['original_label'] # Get original 4-category label | |
# Tokenize and run inference | |
inputs = current_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) | |
inputs = {k: v.to(trainer.model.device) for k, v in inputs.items()} | |
with torch.no_grad(): | |
outputs = trainer.model(**inputs) | |
logits = outputs.logits[0, -1, :] | |
# Get token IDs | |
yes_token_id = current_tokenizer.encode("yes", add_special_tokens=False)[0] | |
no_token_id = current_tokenizer.encode("no", add_special_tokens=False)[0] | |
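            # This uses the first sub-token of "yes"/"no"; whether that matches what the model would
            # actually generate (e.g. " yes" with a leading space) depends on the tokenizer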
yes_logit = logits[yes_token_id].item() | |
no_logit = logits[no_token_id].item() | |
# Get prediction | |
prediction = "yes" if yes_logit > no_logit else "no" | |
if prediction == "yes": | |
predictions_yes += 1 | |
else: | |
predictions_no += 1 | |
if prediction == true_label: | |
correct += 1 | |
# Update confusion matrix if we have original label | |
if original_label and original_label in confusion_matrix: | |
confusion_matrix[original_label][prediction] += 1 | |
accuracy = correct / len(sample_df) | |
yes_ratio = predictions_yes / len(sample_df) | |
no_ratio = predictions_no / len(sample_df) | |
# Calculate per-category accuracies | |
category_accuracies = {} | |
for category in confusion_matrix: | |
total = confusion_matrix[category]['yes'] + confusion_matrix[category]['no'] | |
if total > 0: | |
if category in ['easy_positive', 'hard_positive']: | |
# For positive categories, correct prediction is 'yes' | |
category_accuracies[category] = confusion_matrix[category]['yes'] / total | |
else: | |
# For negative categories, correct prediction is 'no' | |
category_accuracies[category] = confusion_matrix[category]['no'] / total | |
else: | |
category_accuracies[category] = 0.0 | |
return { | |
'accuracy': accuracy, | |
'yes_ratio': yes_ratio, | |
'no_ratio': no_ratio, | |
'total_samples': len(sample_df), | |
'confusion_matrix': confusion_matrix, | |
'category_accuracies': category_accuracies | |
} | |
except Exception as e: | |
logger.error(f"Error computing accuracy metrics: {e}") | |
import traceback | |
traceback.print_exc() | |
return { | |
'accuracy': 0.0, | |
'yes_ratio': 0.0, | |
'no_ratio': 0.0, | |
'total_samples': 0, | |
'confusion_matrix': {}, | |
'category_accuracies': {} | |
} | |
class StatusCallback(TrainerCallback): | |
def __init__(self, trainer, eval_dataset, ft_val_df): | |
self.trainer = trainer | |
self.eval_dataset = eval_dataset | |
self.ft_val_df = ft_val_df | |
self.eval_every_n_steps = 50 # Evaluate every 50 steps | |
self.last_eval_step = 0 | |
def on_step_end(self, args, state, control, **kwargs): | |
"""Called at the end of each training step""" | |
if state.global_step > 0 and state.global_step % self.eval_every_n_steps == 0 and state.global_step != self.last_eval_step: | |
self.last_eval_step = state.global_step | |
logger.info(f"on_step_end: Step {state.global_step} - Computing metrics") | |
with training_lock: | |
training_status["logs"].append(f"[on_step_end] Computing accuracy at step {state.global_step}...") | |
try: | |
metrics = compute_accuracy_metrics(self.trainer, self.eval_dataset, self.ft_val_df) | |
training_status["logs"].append( | |
f"Step {state.global_step} Metrics: " | |
f"Accuracy={metrics['accuracy']:.2%}, " | |
f"Yes={metrics['yes_ratio']:.1%}, " | |
f"No={metrics['no_ratio']:.1%}" | |
) | |
# Add confusion matrix info if available | |
if 'confusion_matrix' in metrics and metrics['confusion_matrix']: | |
training_status["logs"].append("\n=== Confusion Matrix ===") | |
for category, preds in metrics['confusion_matrix'].items(): | |
total = preds['yes'] + preds['no'] | |
if total > 0: | |
acc = metrics['category_accuracies'][category] | |
training_status["logs"].append( | |
f"{category}: Yes={preds['yes']}, No={preds['no']} (Acc: {acc:.1%})" | |
) | |
except Exception as e: | |
logger.error(f"Error in on_step_end: {e}") | |
training_status["logs"].append(f"Error computing metrics: {e}") | |
def on_log(self, args, state, control, logs=None, **kwargs): | |
if logs: | |
with training_lock: | |
if "loss" in logs: | |
training_status["logs"].append(f"Step {state.global_step}: Loss = {logs['loss']:.4f}") | |
if "eval_loss" in logs: | |
training_status["logs"].append(f"Eval Loss = {logs['eval_loss']:.4f}") | |
                # Compute accuracy metrics periodically (skipped when on_step_end has already
                # evaluated this step, to avoid running the evaluation twice)
                if state.global_step > 0 and state.global_step % self.eval_every_n_steps == 0 and state.global_step != self.last_eval_step:
logger.info(f"Step {state.global_step}: Computing accuracy metrics...") | |
training_status["logs"].append(f"Computing accuracy at step {state.global_step}...") | |
metrics = compute_accuracy_metrics(self.trainer, self.eval_dataset, self.ft_val_df) | |
training_status["logs"].append( | |
f"Step {state.global_step} Metrics: " | |
f"Accuracy={metrics['accuracy']:.2%}, " | |
f"Yes={metrics['yes_ratio']:.1%}, " | |
f"No={metrics['no_ratio']:.1%}" | |
) | |
# Add confusion matrix info if available | |
if 'confusion_matrix' in metrics: | |
training_status["logs"].append("\n=== Confusion Matrix ===") | |
for category, preds in metrics['confusion_matrix'].items(): | |
total = preds['yes'] + preds['no'] | |
if total > 0: | |
acc = metrics['category_accuracies'][category] | |
training_status["logs"].append( | |
f"{category}: Yes={preds['yes']}, No={preds['no']} (Acc: {acc:.1%})" | |
) | |
# Warn if model is biased | |
if metrics['yes_ratio'] < 0.2 or metrics['no_ratio'] < 0.2: | |
training_status["logs"].append( | |
f"⚠️ WARNING: Model is heavily biased! " | |
f"(Yes: {metrics['yes_ratio']:.1%}, No: {metrics['no_ratio']:.1%})" | |
) | |
# Update progress | |
if state.global_step > 0: | |
total_steps = len(train_dataset) // batch_size * epochs | |
training_status["progress"] = min(int((state.global_step / total_steps) * 100), 99) | |
# Add callback with trainer and eval dataset | |
status_callback = StatusCallback(trainer, val_dataset, ft_val_df) | |
trainer.add_callback(status_callback) | |
# Train | |
try: | |
logger.info("Starting fine-tuning...") | |
trainer.train() | |
# Save final model | |
save_path = os.path.join(OUTPUT_DIR, "final") | |
trainer.save_model(save_path) | |
current_tokenizer.save_pretrained(save_path) | |
logger.info(f"Model saved to {save_path}") | |
# Compute final metrics | |
logger.info("Computing final accuracy metrics...") | |
final_metrics = compute_accuracy_metrics(trainer, val_dataset, ft_val_df, num_samples=200) | |
logger.info(f"Final Accuracy: {final_metrics['accuracy']:.2%}") | |
logger.info(f"Final Prediction Distribution - Yes: {final_metrics['yes_ratio']:.1%}, No: {final_metrics['no_ratio']:.1%}") | |
with training_lock: | |
training_status["logs"].append(f"\n=== FINAL RESULTS ===") | |
training_status["logs"].append(f"Overall Accuracy: {final_metrics['accuracy']:.2%}") | |
training_status["logs"].append(f"Yes predictions: {final_metrics['yes_ratio']:.1%}") | |
training_status["logs"].append(f"No predictions: {final_metrics['no_ratio']:.1%}") | |
# Add final confusion matrix | |
if 'confusion_matrix' in final_metrics: | |
training_status["logs"].append("\n=== Final Confusion Matrix ===") | |
for category, preds in final_metrics['confusion_matrix'].items(): | |
total = preds['yes'] + preds['no'] | |
if total > 0: | |
acc = final_metrics['category_accuracies'][category] | |
training_status["logs"].append( | |
f"{category}: Yes={preds['yes']}, No={preds['no']} (Accuracy: {acc:.1%})" | |
) | |
# Update global model reference | |
current_model = trainer.model | |
current_model.eval() | |
# Push to hub if token available | |
if HF_TOKEN and HF_USERNAME: | |
try: | |
# Generate unique repo name with timestamp | |
from datetime import datetime | |
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') | |
model_short_name = current_model_id.split("/")[-1] | |
# Create descriptive repo name with training details | |
repo_name = f"{HF_USERNAME}/{model_short_name}-relevance-ft-{timestamp}" | |
# Create model card with training information | |
model_card_content = f"""--- | |
tags: | |
- document-relevance | |
- lora
- {model_short_name} | |
datasets: | |
- custom-relevance-dataset | |
metrics: | |
- accuracy | |
model-index: | |
- name: {repo_name.split('/')[-1]} | |
results: | |
- task: | |
type: text-classification | |
name: Document Relevance Classification | |
metrics: | |
- type: accuracy | |
value: {final_metrics['accuracy']:.4f} | |
name: Validation Accuracy | |
- type: yes_ratio | |
value: {final_metrics['yes_ratio']:.4f} | |
name: Yes Prediction Ratio | |
- type: no_ratio | |
value: {final_metrics['no_ratio']:.4f} | |
name: No Prediction Ratio | |
--- | |
# {model_short_name} Document Relevance Classifier | |
This model was trained using standard fine-tuning for document relevance classification. | |
## Training Configuration | |
- Base Model: {current_model_id} | |
- Training Type: Standard Fine-tuning | |
- Learning Rate: {training_args.learning_rate} | |
- Batch Size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps} | |
- Epochs: {training_args.num_train_epochs} | |
- Training Samples: {len(train_df)} | |
- Validation Samples: {len(val_df)} | |
## Performance Metrics | |
- **Accuracy**: {final_metrics['accuracy']:.2%} | |
- **Yes Predictions**: {final_metrics['yes_ratio']:.1%} | |
- **No Predictions**: {final_metrics['no_ratio']:.1%} | |
## Usage | |
```python | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import PeftModel | |
# Load base model | |
model = AutoModelForCausalLM.from_pretrained("{current_model_id}") | |
tokenizer = AutoTokenizer.from_pretrained("{current_model_id}") | |
# Load adapter | |
model = PeftModel.from_pretrained(model, "{HF_USERNAME}/{repo_name.split('/')[-1]}") | |
``` | |
## Training Date | |
{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
""" | |
# Push model with model card | |
                current_model.push_to_hub(
                    repo_name,
                    token=HF_TOKEN,
                    commit_message=f"Standard fine-tuning with lr={training_args.learning_rate}, accuracy={final_metrics['accuracy']:.2%}"
                )
                current_tokenizer.push_to_hub(repo_name, token=HF_TOKEN)
# Save model card | |
try: | |
from huggingface_hub import HfApi | |
api = HfApi() | |
api.upload_file( | |
path_or_fileobj=model_card_content.encode(), | |
path_in_repo="README.md", | |
repo_id=repo_name, | |
repo_type="model", | |
token=HF_TOKEN | |
) | |
                except Exception as e:
                    logger.warning(f"Model card upload failed (non-fatal, card is optional): {e}")
logger.info(f"Model pushed to hub: {repo_name}") | |
# Save repo name to training status and trained models list | |
with training_lock: | |
training_status["logs"].append(f"Model saved to: https://huggingface.co/{repo_name}") | |
training_status["model_repo"] = repo_name | |
# Add to trained models list | |
trained_models.append({ | |
"repo": repo_name, | |
"timestamp": timestamp, | |
"accuracy": final_metrics['accuracy'], | |
"yes_ratio": final_metrics['yes_ratio'], | |
"no_ratio": final_metrics['no_ratio'], | |
"lr": training_args.learning_rate, | |
"model_id": current_model_id | |
}) | |
except Exception as e: | |
logger.error(f"Failed to push to hub: {e}") | |
except Exception as e: | |
logger.error(f"Training failed: {e}") | |
raise | |
def run_training(csv_path, shuffle_flag=False, split_ratio=0.8): | |
"""Run training from CSV file with balanced sampling""" | |
global train_df, test_df, training_status | |
try: | |
with training_lock: | |
training_status["status"] = "loading" | |
training_status["logs"] = ["Loading data..."] | |
# Load CSV | |
df = pd.read_csv(csv_path) | |
logger.info(f"Loaded {len(df)} samples from {csv_path}") | |
# Clean data - replace NaN values with empty strings | |
text_columns = ['query_text', 'title', 'text', 'query', 'content'] | |
for col in text_columns: | |
if col in df.columns: | |
# Count NaN values before cleaning | |
nan_count = df[col].isna().sum() | |
if nan_count > 0: | |
logger.warning(f"Found {nan_count} NaN values in column '{col}' - replacing with empty strings") | |
df[col] = df[col].fillna('') | |
df[col] = df[col].astype(str) | |
# Check required columns for new format | |
new_format_cols = ['query_text', 'title', 'text', 'label'] | |
old_format_cols = ['query', 'title', 'content', 'label'] | |
if all(col in df.columns for col in new_format_cols): | |
# New format with 4 categories | |
logger.info("Using new CSV format with 4 categories") | |
# Validate labels | |
valid_labels = ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative'] | |
if not all(label in valid_labels for label in df['label'].unique()): | |
raise ValueError(f"Labels must be one of: {valid_labels}") | |
# Create prompt column | |
if 'prompt' not in df.columns: | |
logger.info("Creating prompt column from query_text/title/text") | |
df['prompt'] = df.apply( | |
lambda row: format_prompt(row['query_text'], row['title'], row['text']), | |
axis=1 | |
) | |
# Log label distribution (assuming data is already balanced) | |
logger.info(f"Label distribution: {df['label'].value_counts().to_dict()}") | |
elif all(col in df.columns for col in old_format_cols): | |
# Old format | |
logger.info("Using old CSV format") | |
if 'prompt' not in df.columns: | |
df['prompt'] = df.apply( | |
lambda row: format_prompt(row['query'], row['title'], row['content']), | |
axis=1 | |
) | |
# Validate labels for old format | |
if not all(label in ['yes', 'no'] for label in df['label'].unique()): | |
raise ValueError("Labels must be 'yes' or 'no' for old format") | |
else: | |
raise ValueError(f"CSV must have columns: {new_format_cols} or {old_format_cols}") | |
# Shuffle if requested | |
if shuffle_flag: | |
df = df.sample(frac=1).reset_index(drop=True) | |
# Split data | |
split_idx = int(len(df) * split_ratio) | |
train_df = df.iloc[:split_idx].reset_index(drop=True) | |
test_df = df.iloc[split_idx:].reset_index(drop=True) | |
logger.info(f"Train: {len(train_df)}, Test: {len(test_df)}") | |
logger.info(f"Train label distribution: {train_df['label'].value_counts().to_dict()}") | |
# Start training | |
with training_lock: | |
training_status["status"] = "training" | |
training_status["logs"].append("Starting training...") | |
training_status["logs"].append(f"Train samples: {len(train_df)}") | |
training_status["logs"].append(f"Test samples: {len(test_df)}") | |
# Conservative sample size for GPT-OSS-20B | |
if current_model_id == "openai/gpt-oss-20b": | |
max_samples = 2000 # Start conservative | |
else: | |
max_samples = None | |
train_model(train_df, test_df, epochs=5, batch_size=32, lr=5e-6, max_samples=max_samples) | |
with training_lock: | |
training_status["status"] = "completed" | |
training_status["logs"].append("Training completed!") | |
training_status["progress"] = 100 | |
except Exception as e: | |
logger.error(f"Training failed: {str(e)}") | |
with training_lock: | |
training_status["status"] = "failed" | |
training_status["logs"].append(f"Error: {str(e)}") | |
training_status["error"] = str(e) | |
def run_inference(query, document_title, document_content, checkpoint="latest"): | |
"""Run inference on a single example using logit comparison""" | |
global current_model, current_tokenizer | |
# Validate inputs | |
if not query or not str(query).strip(): | |
return "Error: Query cannot be empty" | |
if not document_title or not str(document_title).strip(): | |
return "Error: Document title cannot be empty" | |
if not document_content or not str(document_content).strip(): | |
return "Error: Document content cannot be empty" | |
# Convert to strings to handle any data type | |
query = str(query) | |
document_title = str(document_title) | |
document_content = str(document_content) | |
# Load model if needed | |
if current_model is None: | |
if checkpoint == "latest": | |
# Find latest checkpoint | |
final_path = os.path.join(OUTPUT_DIR, "final") | |
if os.path.exists(final_path): | |
checkpoint_path = final_path | |
else: | |
checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint-*")) | |
if checkpoints: | |
checkpoint_path = max(checkpoints, key=os.path.getctime) | |
else: | |
checkpoint_path = None | |
load_model_and_tokenizer(checkpoint_path) | |
elif checkpoint == "current": | |
# Use current model without loading (for CSV inference) | |
pass | |
else: | |
load_model_and_tokenizer(checkpoint) | |
# Ensure model is in eval mode | |
current_model.eval() | |
# Create prompt | |
prompt = format_prompt(query, document_title, document_content) | |
# Tokenize | |
inputs = current_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) | |
inputs = {k: v.to(current_model.device) for k, v in inputs.items()} | |
# Get logits for yes/no decision | |
with torch.no_grad(): | |
# Disable cache for inference to avoid DynamicCache issues | |
if hasattr(current_model, 'config'): | |
current_model.config.use_cache = False | |
try: | |
outputs = current_model(**inputs, use_cache=False) | |
except TypeError: | |
# If use_cache parameter not accepted, try without it | |
outputs = current_model(**inputs) | |
logits = outputs.logits[0, -1, :] # Last token logits | |
# Get token IDs for "yes" and "no" | |
yes_token_id = current_tokenizer.encode("yes", add_special_tokens=False)[0] | |
no_token_id = current_tokenizer.encode("no", add_special_tokens=False)[0] | |
# Get logits for yes/no | |
yes_logit = logits[yes_token_id].item() | |
no_logit = logits[no_token_id].item() | |
# Apply softmax to get probabilities | |
probs = torch.softmax(torch.tensor([no_logit, yes_logit]), dim=0) | |
no_prob, yes_prob = probs.tolist() | |
# Return prediction with confidence | |
prediction = "yes" if yes_prob > no_prob else "no" | |
confidence = max(yes_prob, no_prob) | |
return f"{prediction} (confidence: {confidence:.1%})" | |
def run_inference_by_row(split_choice, row_idx, checkpoint="latest"): | |
"""Run inference on a specific row from train/test split""" | |
global train_df, test_df | |
df = train_df if split_choice == "train" else test_df | |
if df is None or row_idx >= len(df): | |
return "Invalid selection. Please load data first." | |
row = df.iloc[int(row_idx)] | |
# Get data from row - handle both old and new formats | |
if all(col in row for col in ['query_text', 'title', 'text']): | |
query = str(row['query_text']) | |
title = str(row['title']) | |
content = str(row['text']) | |
elif all(col in row for col in ['query', 'title', 'content']): | |
query = str(row['query']) | |
title = str(row['title']) | |
content = str(row['content']) | |
else: | |
return "Row missing required columns" | |
prediction_with_confidence = run_inference(query, title, content, checkpoint) | |
# Check if inference returned an error | |
if prediction_with_confidence.startswith("Error:"): | |
return prediction_with_confidence | |
actual = row['label'] | |
# Extract just the prediction (yes/no) from the result | |
prediction = prediction_with_confidence.split()[0] # Gets "yes" or "no" from "yes (confidence: X%)" | |
# Handle 4-category labels | |
if actual in ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative']: | |
# Map to yes/no for comparison | |
label_mapping = { | |
'easy_positive': 'yes', | |
'hard_positive': 'yes', | |
'easy_negative': 'no', | |
'hard_negative': 'no' | |
} | |
mapped_actual = label_mapping[actual] | |
is_correct = prediction == mapped_actual | |
return f"Prediction: {prediction_with_confidence}\nActual: {actual} (mapped to: {mapped_actual})\nCorrect: {is_correct}" | |
else: | |
# Old format with yes/no | |
is_correct = prediction == actual | |
return f"Prediction: {prediction_with_confidence}\nActual: {actual}\nCorrect: {is_correct}" | |
def run_csv_inference(csv_file, model_choice="finetuned", checkpoint_path=None, inference_batch_size=16, progress=gr.Progress()): | |
"""Run inference on an entire CSV file""" | |
global current_model, current_tokenizer | |
if csv_file is None: | |
return None, "Please upload a CSV file" | |
try: | |
# Clear CUDA cache before starting | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
# Load CSV | |
df = pd.read_csv(csv_file.name) | |
logger.info(f"Loaded {len(df)} samples from CSV for inference") | |
# Clean data - replace NaN values with empty strings | |
text_columns = ['query_text', 'title', 'text', 'query', 'content'] | |
for col in text_columns: | |
if col in df.columns: | |
df[col] = df[col].fillna('') | |
df[col] = df[col].astype(str) | |
# Check required columns - support both old and new formats | |
new_format_cols = ['query_text', 'title', 'text'] | |
old_format_cols = ['query', 'title', 'content'] | |
if all(col in df.columns for col in new_format_cols): | |
# New format | |
query_col = 'query_text' | |
content_col = 'text' | |
elif all(col in df.columns for col in old_format_cols): | |
# Old format | |
query_col = 'query' | |
content_col = 'content' | |
else: | |
return None, f"CSV must have columns: {new_format_cols} or {old_format_cols}" | |
# Load appropriate model | |
if model_choice == "base": | |
# Load base model without any checkpoint | |
load_model_and_tokenizer(checkpoint_path=None) | |
logger.info("Using base model for inference") | |
else: | |
# Load fine-tuned model | |
if checkpoint_path and checkpoint_path.strip(): | |
# Use custom checkpoint path (could be local or HF Hub) | |
load_model_and_tokenizer(checkpoint_path=checkpoint_path) | |
else: | |
# Try to find model - check local first, then HF Hub | |
final_path = os.path.join(OUTPUT_DIR, "final") | |
if os.path.exists(final_path): | |
load_model_and_tokenizer(checkpoint_path=final_path) | |
else: | |
checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint-*")) | |
if checkpoints: | |
latest_checkpoint = max(checkpoints, key=os.path.getctime) | |
load_model_and_tokenizer(checkpoint_path=latest_checkpoint) | |
else: | |
# Try to load from HF Hub using known model names | |
hub_attempts = [] | |
if HF_USERNAME: | |
hub_attempts.append(f"{HF_USERNAME}/phi3-dpo-relevance") | |
# Also try the known pushed model | |
hub_attempts.extend(["amos1088/phi3-dpo-relevance"]) | |
model_loaded = False | |
for hub_model_name in hub_attempts: | |
try: | |
logger.info(f"Trying to load from HF Hub: {hub_model_name}") | |
# Load base model first | |
load_model_and_tokenizer(checkpoint_path=None) | |
# Then load adapter from Hub | |
from peft import PeftModel | |
current_model = PeftModel.from_pretrained(current_model, hub_model_name) | |
current_model.eval() | |
logger.info(f"Successfully loaded fine-tuned model from HF Hub: {hub_model_name}") | |
model_loaded = True | |
break | |
except Exception as e: | |
logger.warning(f"Failed to load {hub_model_name}: {e}") | |
continue | |
if not model_loaded: | |
return None, "No fine-tuned model found locally or on HF Hub. Please train a model first." | |
logger.info("Using fine-tuned model for inference") | |
# Prepare results | |
predictions = [] | |
confidences = [] | |
# Batch inference for speed | |
# For GPT-OSS-20B, we can handle larger batches for inference than training | |
# since we don't need gradients or optimizer states | |
if current_model_id == "openai/gpt-oss-20b": | |
# Inference uses less memory than training, can use larger batches | |
batch_size = min(inference_batch_size, 64) | |
else: | |
batch_size = inference_batch_size | |
logger.info(f"Using batch size {batch_size} for inference") | |
total_batches = (len(df) + batch_size - 1) // batch_size | |
progress(0, desc="Starting batch inference...") | |
for batch_idx in range(0, len(df), batch_size): | |
batch_end = min(batch_idx + batch_size, len(df)) | |
batch_df = df.iloc[batch_idx:batch_end] | |
# Update progress | |
            progress(min(1.0, batch_end / len(df)),
                     desc=f"Processing batch {batch_idx//batch_size + 1}/{total_batches}")
# Prepare batch prompts | |
prompts = [] | |
for _, row in batch_df.iterrows(): | |
prompt = format_prompt( | |
str(row[query_col]), | |
str(row['title']), | |
str(row[content_col]) | |
) | |
prompts.append(prompt) | |
# Tokenize batch | |
inputs = current_tokenizer( | |
prompts, | |
return_tensors="pt", | |
padding=True, | |
truncation=True, | |
max_length=256 | |
) | |
inputs = {k: v.to(current_model.device) for k, v in inputs.items()} | |
# Run batch inference | |
with torch.no_grad(): | |
                with torch.autocast(device_type="cuda"):  # Mixed precision for speed
outputs = current_model(**inputs) | |
logits = outputs.logits[:, -1, :] # Last token logits for each sample | |
# Get yes/no token IDs | |
yes_token_id = current_tokenizer.encode("yes", add_special_tokens=False)[0] | |
no_token_id = current_tokenizer.encode("no", add_special_tokens=False)[0] | |
# Process each sample in batch | |
for i in range(len(batch_df)): | |
yes_logit = logits[i, yes_token_id].item() | |
no_logit = logits[i, no_token_id].item() | |
# Apply softmax | |
probs = torch.softmax(torch.tensor([no_logit, yes_logit]), dim=0) | |
no_prob, yes_prob = probs.tolist() | |
# Get prediction | |
prediction = "yes" if yes_prob > no_prob else "no" | |
confidence = max(yes_prob, no_prob) | |
predictions.append(prediction) | |
confidences.append(confidence) | |
# Clear cache after each batch | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
# Clear CUDA cache after inference loop | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
# Add predictions to dataframe | |
df['prediction'] = predictions | |
df['confidence'] = confidences | |
# If labels exist, calculate accuracy and add is_right column | |
if 'label' in df.columns: | |
# Handle 4-category labels | |
if df['label'].iloc[0] in ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative']: | |
# Map 4 categories to yes/no for accuracy calculation | |
label_mapping = { | |
'easy_positive': 'yes', | |
'hard_positive': 'yes', | |
'easy_negative': 'no', | |
'hard_negative': 'no' | |
} | |
df['mapped_label'] = df['label'].map(label_mapping) | |
df['is_right'] = df['prediction'] == df['mapped_label'] | |
# Calculate per-category accuracy | |
accuracy_text = "\n=== Overall Results ===" | |
accuracy = df['is_right'].mean() | |
accuracy_text += f"\nOverall Accuracy: {accuracy:.2%} ({df['is_right'].sum()}/{len(df)} correct)" | |
accuracy_text += "\n\n=== Per-Category Results ===" | |
for category in ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative']: | |
category_df = df[df['label'] == category] | |
if len(category_df) > 0: | |
category_acc = category_df['is_right'].mean() | |
category_count = len(category_df) | |
correct_count = category_df['is_right'].sum() | |
accuracy_text += f"\n{category}: {category_acc:.2%} ({correct_count}/{category_count} correct)" | |
else: | |
# Old format with yes/no labels | |
df['is_right'] = df['prediction'] == df['label'] | |
accuracy = df['is_right'].mean() | |
accuracy_text = f"\nAccuracy: {accuracy:.2%} ({df['is_right'].sum()}/{len(df)} correct)" | |
else: | |
accuracy_text = "" | |
# Save results | |
output_filename = f"inference_results_{model_choice}_{'%Y%m%d_%H%M%S'}.csv" | |
import datetime | |
output_filename = f"inference_results_{model_choice}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" | |
df.to_csv(output_filename, index=False) | |
# Create summary | |
summary = f"Completed inference on {len(df)} samples using {model_choice} model" | |
summary += accuracy_text | |
summary += f"\n\nPrediction distribution:\n" | |
for pred, count in df['prediction'].value_counts().items(): | |
summary += f" {pred}: {count} ({count/len(df):.1%})\n" | |
summary += f"\nResults saved to: {output_filename}" | |
return output_filename, summary | |
except Exception as e: | |
logger.error(f"CSV inference failed: {str(e)}") | |
return None, f"Error during inference: {str(e)}" | |
# ==== GRADIO UI ==== | |
with gr.Blocks(title="Document Relevance Classifier") as demo:
    gr.Markdown("# Document Relevance Classifier")
    gr.Markdown("Train and test a model (Phi-3 Mini or GPT-OSS-20B) to classify whether documents are relevant to queries")
with gr.Tab("Training"): | |
csv_input = gr.File(label="Upload Training CSV", file_types=[".csv"]) | |
gr.Markdown("CSV should have columns: query_text, title, text, label (easy_positive/hard_positive/easy_negative/hard_negative)\n**Note: Data should be pre-balanced with equal samples per category**") | |
shuffle_flag = gr.Checkbox(label="Shuffle Dataset", value=True) | |
split_slider = gr.Slider(0.5, 0.9, value=0.8, step=0.05, label="Train Split %") | |
start_btn = gr.Button("Start Training", variant="primary") | |
status_text = gr.Textbox(label="Training Status", lines=10) | |
def start_training_ui(csv_file, shuffle, split_ratio): | |
if csv_file is None: | |
return "Please upload a CSV file first." | |
threading.Thread( | |
target=run_training, | |
args=(csv_file.name, shuffle, split_ratio) | |
).start() | |
return "Training started... Check status below." | |
def get_training_status(): | |
with training_lock: | |
status = training_status["status"] | |
progress = training_status["progress"] | |
logs = "\n".join(training_status["logs"][-20:]) | |
gpu_status = get_gpu_memory_status() | |
return f"Status: {status}\nProgress: {progress}%\n{gpu_status}\n\nLogs:\n{logs}" | |
start_btn.click( | |
start_training_ui, | |
inputs=[csv_input, shuffle_flag, split_slider], | |
outputs=status_text | |
) | |
# Auto-refresh status | |
status_text.change( | |
lambda: get_training_status(), | |
inputs=[], | |
outputs=status_text, | |
every=10 | |
) | |
with gr.Tab("Inference"): | |
query_input = gr.Textbox(label="Query", placeholder="Enter your search query") | |
title_input = gr.Textbox(label="Document Title", placeholder="Enter document title") | |
content_input = gr.Textbox( | |
label="Document Content", | |
placeholder="Enter document content", | |
lines=5 | |
) | |
checkpoint_dropdown = gr.Textbox( | |
label="Checkpoint", | |
value="latest", | |
placeholder="latest or path to checkpoint" | |
) | |
inference_btn = gr.Button("Classify", variant="primary") | |
output_text = gr.Textbox(label="Relevance", lines=2) | |
inference_btn.click( | |
run_inference, | |
inputs=[query_input, title_input, content_input, checkpoint_dropdown], | |
outputs=output_text | |
) | |
with gr.Tab("Test on Dataset"): | |
split_choice = gr.Radio( | |
choices=["train", "test"], | |
value="test", | |
label="Dataset Split" | |
) | |
row_idx_input = gr.Number( | |
label="Row Index (0-based)", | |
value=0, | |
precision=0 | |
) | |
checkpoint_input2 = gr.Textbox( | |
label="Checkpoint", | |
value="latest" | |
) | |
row_inference_btn = gr.Button("Test Row", variant="primary") | |
row_output = gr.Textbox(label="Result", lines=4) | |
row_inference_btn.click( | |
run_inference_by_row, | |
inputs=[split_choice, row_idx_input, checkpoint_input2], | |
outputs=row_output | |
) | |
with gr.Tab("CSV Inference"): | |
gr.Markdown("### Batch Inference on CSV Files") | |
gr.Markdown("Run inference on multiple documents from a CSV file. Compare base model vs fine-tuned model performance.") | |
csv_inference_input = gr.File( | |
label="Upload CSV for Inference", | |
file_types=[".csv"], | |
file_count="single" | |
) | |
gr.Markdown("CSV must have columns: `query_text`, `title`, `text`. Optional: `label` (for accuracy calculation)") | |
with gr.Row(): | |
model_choice_radio = gr.Radio( | |
choices=["base", "finetuned"], | |
value="finetuned", | |
label="Model Selection", | |
info="Choose between base model or fine-tuned model" | |
) | |
custom_checkpoint_path = gr.Textbox( | |
label="Custom Checkpoint Path (Optional)", | |
placeholder="e.g., amos1088/phi3-dpo-relevance or local path", | |
value="amos1088/phi3-dpo-relevance", | |
visible=True | |
) | |
inference_batch_size_slider = gr.Slider( | |
minimum=8, | |
maximum=128, | |
value=32, | |
step=8, | |
label="Inference Batch Size", | |
info="Default matches training batch size (32). Larger = faster but uses more memory" | |
) | |
csv_inference_btn = gr.Button("Run Batch Inference", variant="primary") | |
with gr.Row(): | |
csv_output_file = gr.File( | |
label="Download Results", | |
visible=True | |
) | |
csv_results_text = gr.Textbox( | |
label="Inference Summary", | |
lines=10, | |
max_lines=20 | |
) | |
# Handle inference | |
csv_inference_btn.click( | |
run_csv_inference, | |
inputs=[csv_inference_input, model_choice_radio, custom_checkpoint_path, inference_batch_size_slider], | |
outputs=[csv_output_file, csv_results_text] | |
) | |
with gr.Tab("Trained Models"): | |
gr.Markdown("### Model Training History") | |
gr.Markdown("View all models trained in this session and their performance metrics") | |
models_display = gr.Markdown(get_trained_models_list()) | |
refresh_btn = gr.Button("Refresh Model List", variant="secondary") | |
# Update CSV inference dropdown with trained models | |
model_selector = gr.Dropdown( | |
label="Select Trained Model for Inference", | |
choices=["Latest"] + [m["repo"] for m in trained_models], | |
value="Latest", | |
interactive=True | |
) | |
def refresh_models_list(): | |
models_text = get_trained_models_list() | |
choices = ["Latest", "amos1088/phi3-dpo-relevance"] + [m["repo"] for m in trained_models] | |
return models_text, gr.update(choices=choices) | |
refresh_btn.click( | |
refresh_models_list, | |
outputs=[models_display, model_selector] | |
) | |
# Auto-refresh after training | |
training_status_display = gr.Textbox(visible=False) | |
training_status_display.change( | |
lambda: (get_trained_models_list(), gr.update(choices=["Latest"] + [m["repo"] for m in trained_models])), | |
outputs=[models_display, model_selector] | |
) | |
with gr.Tab("Model Settings"): | |
gr.Markdown("### Model Selection") | |
gr.Markdown("Choose which model to use for training and inference") | |
# Current model display | |
current_model_display = gr.Textbox( | |
label="Current Model", | |
value=f"{AVAILABLE_MODELS[current_model_id]['name']} ({current_model_id})", | |
interactive=False | |
) | |
# Model selection dropdown | |
model_choices = list(AVAILABLE_MODELS.keys()) | |
model_dropdown = gr.Dropdown( | |
choices=model_choices, | |
value=current_model_id, | |
label="Select Model", | |
interactive=True | |
) | |
# Switch model button | |
switch_btn = gr.Button("Switch Model", variant="primary") | |
switch_status = gr.Textbox(label="Status", lines=2) | |
# Model info | |
gr.Markdown("### Model Information") | |
model_info = gr.Textbox( | |
label="Model Details", | |
value=f"Target modules: {AVAILABLE_MODELS[current_model_id]['target_modules']}", | |
lines=3, | |
interactive=False | |
) | |
def handle_model_switch(selected_model): | |
try: | |
# Switch model | |
status = switch_model(selected_model) | |
# Update displays | |
model_name = AVAILABLE_MODELS[selected_model]['name'] | |
current_display = f"{model_name} ({selected_model})" | |
info = f"Target modules: {AVAILABLE_MODELS[selected_model]['target_modules']}\nOutput directory: {OUTPUT_DIR}" | |
return status, current_display, info | |
except Exception as e: | |
return f"Error: {str(e)}", current_model_display.value, model_info.value | |
switch_btn.click( | |
handle_model_switch, | |
inputs=[model_dropdown], | |
outputs=[switch_status, current_model_display, model_info] | |
) | |
if __name__ == "__main__": | |
    # Ensure the gradio temp directory exists (os is already imported at the top of the file)
    os.makedirs("/tmp/gradio", exist_ok=True)
demo.launch(server_name="0.0.0.0", server_port=7860) |