import os
# Set memory optimization environment variables BEFORE importing torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid tokenizer warning
import torch
import pandas as pd
import glob
import threading
import logging
import gradio as gr
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, prepare_model_for_kbit_training, LoraConfig, get_peft_model
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
import warnings
import subprocess
import gc
import psutil
warnings.filterwarnings("ignore")
# ===== CONFIG =====
AVAILABLE_MODELS = {
"microsoft/Phi-3-mini-4k-instruct": {
"name": "Phi-3 Mini 4K",
"target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"]
},
"openai/gpt-oss-20b": {
"name": "OpenAI GPT-OSS 20B",
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"] # MoE architecture, may need adjustment
}
}
# Default model
current_model_id = "openai/gpt-oss-20b"
MODEL_ID = current_model_id
OUTPUT_DIR = f"./{current_model_id.split('/')[-1]}-relevance-checkpoints"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
HF_USERNAME = os.environ.get("HF_USERNAME", "")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
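# TF32 lets matmul/conv kernels trade a small amount of precision for a sizeable
# speedup on Ampere-class GPUs (e.g. A100); it has no effect on older hardware.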
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def kill_gpu_processes():
"""Kill other processes using GPU to free memory"""
if not torch.cuda.is_available():
return
current_pid = os.getpid()
killed = []
try:
# Get GPU processes using nvidia-smi
result = subprocess.run(['nvidia-smi', '--query-compute-apps=pid', '--format=csv,noheader'],
capture_output=True, text=True)
if result.returncode == 0:
pids = [int(pid.strip()) for pid in result.stdout.strip().split('\n') if pid.strip()]
for pid in pids:
if pid == current_pid:
continue
try:
process = psutil.Process(pid)
process_name = process.name()
process.terminate()
killed.append(f"PID {pid} ({process_name})")
logger.info(f"Killed GPU process: {pid} ({process_name})")
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
except Exception as e:
logger.warning(f"Could not kill GPU processes: {e}")
return killed
def clear_gpu_memory():
"""Aggressively clear GPU memory"""
if torch.cuda.is_available():
# Kill other GPU processes first
kill_gpu_processes()
# Clear PyTorch cache
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Force garbage collection
gc.collect()
# Best-effort sweep over tracked CUDA tensors (del only drops the local reference)
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) and obj.is_cuda:
del obj
except Exception:
pass
gc.collect()
torch.cuda.empty_cache()
logger.info("GPU memory cleared")
# Globals
current_model = None
current_tokenizer = None
train_df = None
test_df = None
training_status = {"status": "idle", "progress": 0, "logs": [], "model_repo": None}
trained_models = [] # Keep track of all trained models
training_lock = threading.Lock()
def format_prompt(query, title, content):
"""Format the prompt for the model"""
# Convert all inputs to strings and handle NaN/None values
query = str(query) if query is not None and str(query) != 'nan' else ""
title = str(title) if title is not None and str(title) != 'nan' else ""
content = str(content) if content is not None and str(content) != 'nan' else ""
# Truncate content if too long
if len(content) > 1000:
content = content[:1000] + "..."
return f"""You would get a query and document's title and content and return yes (if the document is relevant to the query) or no (if the document is not relevant to the query).
Answer only yes / no.
Document:
####DOCUMENT START
title: {title}
content: {content}
####DOCUMENT END
Query:
####Query START
{query}
####Query END
ANSWER: """
def load_model_and_tokenizer(checkpoint_path=None, model_id=None):
"""Load model and tokenizer"""
global current_model, current_tokenizer, current_model_id
if model_id is None:
model_id = current_model_id
# Clear existing model from memory first
if current_model is not None:
del current_model
current_model = None
# Aggressively clear GPU memory before loading new model
clear_gpu_memory()
logger.info(f"Loading model from {checkpoint_path or model_id} ...")
# Load tokenizer
current_tokenizer = AutoTokenizer.from_pretrained(model_id)
if current_tokenizer.pad_token is None:
current_tokenizer.pad_token = current_tokenizer.eos_token
current_tokenizer.padding_side = "left"
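# Left padding keeps each sequence's final real token at the last position, so
# reading logits[:, -1, :] during batched yes/no scoring works for every row.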
# Load model with appropriate settings
model_kwargs = {
"trust_remote_code": True
}
# Model loading configuration
if model_id == "openai/gpt-oss-20b":
# GPT-OSS-20B already uses MXFP4 quantization natively
# Don't apply additional quantization
model_kwargs["torch_dtype"] = torch.bfloat16
model_kwargs["device_map"] = "auto"
model_kwargs["low_cpu_mem_usage"] = True
logger.info("Loading GPT-OSS-20B with native MXFP4 quantization")
else:
# For other models, use 4-bit quantization
model_kwargs["load_in_4bit"] = True
model_kwargs["torch_dtype"] = torch.float16
model_kwargs["device_map"] = "auto"
try:
if checkpoint_path and os.path.exists(checkpoint_path):
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
# Load PEFT adapter
current_model = PeftModel.from_pretrained(base_model, checkpoint_path)
logger.info("Loaded PEFT model from checkpoint")
else:
# Load base model
current_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
logger.info("Loaded base model")
except RuntimeError as e:
if "meta tensor" in str(e):
logger.warning("Meta tensor error detected, trying alternative loading approach...")
# Try loading without device map first, then move to GPU
model_kwargs.pop("device_map", None)
model_kwargs.pop("max_memory", None)
if checkpoint_path and os.path.exists(checkpoint_path):
base_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
base_model = base_model.to("cuda")
current_model = PeftModel.from_pretrained(base_model, checkpoint_path)
else:
current_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
current_model = current_model.to("cuda")
logger.info("Loaded model with alternative approach")
else:
raise
current_model.eval()
return current_model, current_tokenizer
def get_gpu_memory_status():
"""Get current GPU memory usage"""
if not torch.cuda.is_available():
return "No GPU available"
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
return f"GPU Memory: {allocated:.1f}GB allocated / {reserved:.1f}GB reserved / {total:.1f}GB total"
def get_trained_models_list():
"""Get formatted list of all trained models"""
if not trained_models:
return "No models trained yet in this session.\n\nPreviously trained models on HuggingFace:\n- amos1088/phi3-dpo-relevance"
text = "## Trained Models in This Session:\n\n"
for i, model in enumerate(trained_models, 1):
text += f"{i}. **{model['repo']}**\n"
text += f" - Accuracy: {model['accuracy']:.2%}\n"
text += f" - Predictions: Yes {model['yes_ratio']:.1%}, No {model['no_ratio']:.1%}\n"
text += f" - LR: {model.get('lr', 'N/A')}, Model: {model['model_id'].split('/')[-1]}\n"
text += f" - Link: https://huggingface.co/{model['repo']}\n\n"
return text
def switch_model(model_id):
"""Switch to a different model"""
global current_model, current_tokenizer, current_model_id, OUTPUT_DIR
# Clear current model
if current_model is not None:
del current_model
current_model = None
if current_tokenizer is not None:
del current_tokenizer
current_tokenizer = None
# Update model ID and output directory
current_model_id = model_id
model_name = model_id.split("/")[-1]
OUTPUT_DIR = f"./{model_name}-relevance-checkpoints"
# Clear CUDA cache
clear_gpu_memory()
logger.info(f"Switched to model: {model_id}")
return f"Model switched to: {AVAILABLE_MODELS[model_id]['name']}\n{get_gpu_memory_status()}"
def collate_fn(batch):
"""Collate function for a manual DataLoader (note: the Trainer below uses DataCollatorForLanguageModeling instead)"""
prompts = [item["prompt"] for item in batch]
labels = [item["label"] for item in batch]  # "yes" or "no"
# Create full texts (prompt + label)
full_texts = [prompt + label for prompt, label in zip(prompts, labels)]
# Tokenize full sequences
model_inputs = current_tokenizer(
full_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
# Create labels by masking everything except the answer (-100 means ignore in loss).
# The tokenizer pads on the left, so the prompt does not start at position 0;
# masking all but the trailing answer tokens handles padding and prompt together.
label_ids_list = []
for i, answer in enumerate(labels):
# Tokenize the answer alone to find how many trailing tokens to keep
answer_ids = current_tokenizer(
answer,
add_special_tokens=False
).input_ids
# Create label sequence
label_ids = model_inputs.input_ids[i].clone()
# Mask prompt and padding tokens (set to -100)
label_ids[:-len(answer_ids)] = -100
label_ids_list.append(label_ids)
labels = torch.stack(label_ids_list)
return {
"input_ids": model_inputs.input_ids,
"attention_mask": model_inputs.attention_mask,
"labels": labels
}
def prepare_finetuning_dataset(df):
"""Convert 4-category labels to standard fine-tuning format"""
ft_data = []
# Map 4 categories to yes/no
label_mapping = {
'easy_positive': 'yes',
'hard_positive': 'yes',
'easy_negative': 'no',
'hard_negative': 'no',
'yes': 'yes',
'no': 'no'
}
for _, row in df.iterrows():
# Handle both old and new column names
if 'query_text' in row:
query = str(row['query_text']) if pd.notna(row['query_text']) else ''
title = str(row['title']) if pd.notna(row['title']) else ''
content = str(row['text']) if pd.notna(row['text']) else ''
else:
query = str(row.get('query', '')) if pd.notna(row.get('query', '')) else ''
title = str(row.get('title', '')) if pd.notna(row.get('title', '')) else ''
content = str(row.get('content', '')) if pd.notna(row.get('content', '')) else ''
# Create prompt if not exists
if 'prompt' in row:
prompt = row['prompt']
else:
prompt = format_prompt(query, title, content)
# Get mapped label
original_label = row['label']
mapped_label = label_mapping.get(original_label, original_label)
# Create the full text with prompt and answer
text = prompt + mapped_label
ft_data.append({
'text': text,
'prompt': prompt,
'label': mapped_label,
'original_label': original_label # Keep original for analysis
})
return pd.DataFrame(ft_data)
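# Example row produced by prepare_finetuning_dataset (illustrative): an input row
# labelled 'hard_negative' becomes
#   {'text': prompt + 'no', 'prompt': prompt, 'label': 'no', 'original_label': 'hard_negative'}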
def train_model(train_df, val_df, epochs=5, batch_size=32, lr=5e-6, max_samples=None):
"""Standard fine-tuning for document relevance classification"""
global current_model, current_tokenizer
# Clear GPU memory before training
logger.info("Clearing GPU memory before training...")
clear_gpu_memory()
# Load model and tokenizer if not already loaded
if current_model is None or current_tokenizer is None:
load_model_and_tokenizer()
# Limit training samples if specified (for memory constraints)
if max_samples and len(train_df) > max_samples:
logger.info(f"Limiting training data from {len(train_df)} to {max_samples} samples")
train_df = train_df.sample(n=max_samples, random_state=42)
val_df = val_df.head(min(len(val_df), max_samples // 5)) # Proportional validation set
# Convert to fine-tuning format
logger.info("Preparing fine-tuning dataset...")
ft_train_df = prepare_finetuning_dataset(train_df)
ft_val_df = prepare_finetuning_dataset(val_df)
# Create datasets
train_dataset = Dataset.from_pandas(ft_train_df)
val_dataset = Dataset.from_pandas(ft_val_df)
# Prepare model for training
if hasattr(current_model, 'is_loaded_in_4bit') and current_model.is_loaded_in_4bit:
# For 4-bit models
current_model = prepare_model_for_kbit_training(current_model)
else:
# For full precision models (including GPT-OSS with native quantization)
# Enable gradient checkpointing to save memory
if hasattr(current_model, 'gradient_checkpointing_enable'):
current_model.gradient_checkpointing_enable()
# For GPT-OSS, also enable some memory optimizations
if current_model_id == "openai/gpt-oss-20b" and hasattr(current_model.config, 'use_cache'):
current_model.config.use_cache = False
# Configure LoRA - can use larger rank on A100
target_modules = AVAILABLE_MODELS[current_model_id]["target_modules"]
if current_model_id == "openai/gpt-oss-20b":
# Conservative for large model
lora_r = 8
lora_alpha = 16
else:
# Can use larger rank for smaller models
lora_r = 16
lora_alpha = 32
peft_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=target_modules
)
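# LoRA keeps the base weights frozen and learns a low-rank update per target
# projection; the update is scaled by lora_alpha / r, so r=8/alpha=16 and
# r=16/alpha=32 both apply a scaling factor of 2 while differing in capacity.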
logger.info(f"Starting fine-tuning with {len(train_df)} train samples, {len(val_df)} val samples")
logger.info(f"Learning rate: {lr}, Effective batch size: {batch_size}, Epochs: {epochs}")
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Training configuration optimized for standard fine-tuning
target_batch_size = batch_size # Target effective batch size
if current_model_id == "openai/gpt-oss-20b":
# For GPT-OSS-20B: use smaller per-device batch with gradient accumulation
actual_batch_size = 2 # Per-device batch size
seq_length = 512 # Standard sequence length
grad_accum = target_batch_size // actual_batch_size # 16 gradient accumulation steps
else:
# For smaller models like Phi-3 - can use larger per-device batch
actual_batch_size = min(16, target_batch_size) # Up to 16 per device
grad_accum = max(1, target_batch_size // actual_batch_size) # Accumulate if needed
seq_length = 512
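# Effective batch size = per-device batch x gradient accumulation steps:
# at the default target of 32 this gives 2 x 16 for GPT-OSS-20B and 16 x 2 for Phi-3.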
# Tokenize the datasets
def tokenize_function(examples):
# Tokenize the full texts (prompt + answer)
model_inputs = current_tokenizer(
examples['text'],
truncation=True,
padding="max_length",
max_length=seq_length,
return_tensors=None
)
# For causal LM, labels are the same as input_ids
model_inputs["labels"] = model_inputs["input_ids"].copy()
# Store metadata for evaluation
model_inputs["original_labels"] = examples['original_label']
model_inputs["mapped_labels"] = examples['label']
return model_inputs
# Tokenize datasets
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_val = val_dataset.map(tokenize_function, batched=True, remove_columns=val_dataset.column_names)
# Standard training arguments
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=epochs,
per_device_train_batch_size=actual_batch_size,
per_device_eval_batch_size=actual_batch_size,
gradient_accumulation_steps=grad_accum,
gradient_checkpointing=True,
learning_rate=lr,
lr_scheduler_type="cosine",
warmup_steps=500, # More warmup for standard fine-tuning
logging_steps=10,
save_strategy="epoch",
eval_strategy="epoch",
bf16=True,
fp16=False,
weight_decay=0.01,
optim="adamw_torch",
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
report_to=[],
run_name="standard-ft-relevance",
dataloader_num_workers=2,
)
# Create data collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=current_tokenizer,
mlm=False, # Causal LM, not masked LM
pad_to_multiple_of=8
)
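# With mlm=False the collator builds causal-LM labels from input_ids (padding set
# to -100); pad_to_multiple_of=8 keeps sequence lengths aligned for tensor cores.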
# Apply LoRA to the model
current_model = get_peft_model(current_model, peft_config)
current_model.print_trainable_parameters()
# Create standard trainer
trainer = Trainer(
model=current_model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_val,
data_collator=data_collator,
tokenizer=current_tokenizer,
)
# Custom logging callback
def log_callback(logs):
if "loss" in logs:
step = logs.get("epoch", 0)
loss = logs["loss"]
with training_lock:
training_status["logs"].append(f"Step {step}, Loss: {loss:.4f}")
if "eval_loss" in logs:
training_status["logs"].append(f"Eval Loss: {logs['eval_loss']:.4f}")
# Custom callback for status updates
from transformers import TrainerCallback
import numpy as np
def compute_accuracy_metrics(trainer, eval_dataset, ft_val_df, num_samples=100):
"""Compute accuracy metrics and confusion matrix on a subset of eval data"""
logger.info("Computing accuracy metrics...")
try:
# Get the original dataframe for easier access to prompts and labels
eval_df = ft_val_df
# Sample subset for faster evaluation
sample_size = min(num_samples, len(eval_df))
sample_df = eval_df.sample(n=sample_size, random_state=42)
# Initialize confusion matrix counters
confusion_matrix = {
'easy_positive': {'yes': 0, 'no': 0},
'hard_positive': {'yes': 0, 'no': 0},
'easy_negative': {'yes': 0, 'no': 0},
'hard_negative': {'yes': 0, 'no': 0}
}
predictions_yes = 0
predictions_no = 0
correct = 0
for idx, row in sample_df.iterrows():
prompt = row['prompt']
true_label = row['label'] # This is the mapped label (yes/no)
original_label = row['original_label'] # Get original 4-category label
# Tokenize and run inference
inputs = current_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(trainer.model.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = trainer.model(**inputs)
logits = outputs.logits[0, -1, :]
# Get token IDs
yes_token_id = current_tokenizer.encode("yes", add_special_tokens=False)[0]
no_token_id = current_tokenizer.encode("no", add_special_tokens=False)[0]
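# Caveat: with BPE tokenizers the space in "ANSWER: " can merge into the answer
# token (" yes" / " no"), so comparing the bare "yes"/"no" token logits here is
# an approximation of the model's actual next-token preference.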
yes_logit = logits[yes_token_id].item()
no_logit = logits[no_token_id].item()
# Get prediction
prediction = "yes" if yes_logit > no_logit else "no"
if prediction == "yes":
predictions_yes += 1
else:
predictions_no += 1
if prediction == true_label:
correct += 1
# Update confusion matrix if we have original label
if original_label and original_label in confusion_matrix:
confusion_matrix[original_label][prediction] += 1
accuracy = correct / len(sample_df)
yes_ratio = predictions_yes / len(sample_df)
no_ratio = predictions_no / len(sample_df)
# Calculate per-category accuracies
category_accuracies = {}
for category in confusion_matrix:
total = confusion_matrix[category]['yes'] + confusion_matrix[category]['no']
if total > 0:
if category in ['easy_positive', 'hard_positive']:
# For positive categories, correct prediction is 'yes'
category_accuracies[category] = confusion_matrix[category]['yes'] / total
else:
# For negative categories, correct prediction is 'no'
category_accuracies[category] = confusion_matrix[category]['no'] / total
else:
category_accuracies[category] = 0.0
return {
'accuracy': accuracy,
'yes_ratio': yes_ratio,
'no_ratio': no_ratio,
'total_samples': len(sample_df),
'confusion_matrix': confusion_matrix,
'category_accuracies': category_accuracies
}
except Exception as e:
logger.error(f"Error computing accuracy metrics: {e}")
import traceback
traceback.print_exc()
return {
'accuracy': 0.0,
'yes_ratio': 0.0,
'no_ratio': 0.0,
'total_samples': 0,
'confusion_matrix': {},
'category_accuracies': {}
}
class StatusCallback(TrainerCallback):
def __init__(self, trainer, eval_dataset, ft_val_df):
self.trainer = trainer
self.eval_dataset = eval_dataset
self.ft_val_df = ft_val_df
self.eval_every_n_steps = 50 # Evaluate every 50 steps
self.last_eval_step = 0
def on_step_end(self, args, state, control, **kwargs):
"""Called at the end of each training step"""
if state.global_step > 0 and state.global_step % self.eval_every_n_steps == 0 and state.global_step != self.last_eval_step:
self.last_eval_step = state.global_step
logger.info(f"on_step_end: Step {state.global_step} - Computing metrics")
with training_lock:
training_status["logs"].append(f"[on_step_end] Computing accuracy at step {state.global_step}...")
try:
metrics = compute_accuracy_metrics(self.trainer, self.eval_dataset, self.ft_val_df)
training_status["logs"].append(
f"Step {state.global_step} Metrics: "
f"Accuracy={metrics['accuracy']:.2%}, "
f"Yes={metrics['yes_ratio']:.1%}, "
f"No={metrics['no_ratio']:.1%}"
)
# Add confusion matrix info if available
if 'confusion_matrix' in metrics and metrics['confusion_matrix']:
training_status["logs"].append("\n=== Confusion Matrix ===")
for category, preds in metrics['confusion_matrix'].items():
total = preds['yes'] + preds['no']
if total > 0:
acc = metrics['category_accuracies'][category]
training_status["logs"].append(
f"{category}: Yes={preds['yes']}, No={preds['no']} (Acc: {acc:.1%})"
)
except Exception as e:
logger.error(f"Error in on_step_end: {e}")
training_status["logs"].append(f"Error computing metrics: {e}")
def on_log(self, args, state, control, logs=None, **kwargs):
if logs:
with training_lock:
if "loss" in logs:
training_status["logs"].append(f"Step {state.global_step}: Loss = {logs['loss']:.4f}")
if "eval_loss" in logs:
training_status["logs"].append(f"Eval Loss = {logs['eval_loss']:.4f}")
# Compute accuracy metrics periodically
if state.global_step > 0 and state.global_step % self.eval_every_n_steps == 0:
logger.info(f"Step {state.global_step}: Computing accuracy metrics...")
training_status["logs"].append(f"Computing accuracy at step {state.global_step}...")
metrics = compute_accuracy_metrics(self.trainer, self.eval_dataset, self.ft_val_df)
training_status["logs"].append(
f"Step {state.global_step} Metrics: "
f"Accuracy={metrics['accuracy']:.2%}, "
f"Yes={metrics['yes_ratio']:.1%}, "
f"No={metrics['no_ratio']:.1%}"
)
# Add confusion matrix info if available
if 'confusion_matrix' in metrics:
training_status["logs"].append("\n=== Confusion Matrix ===")
for category, preds in metrics['confusion_matrix'].items():
total = preds['yes'] + preds['no']
if total > 0:
acc = metrics['category_accuracies'][category]
training_status["logs"].append(
f"{category}: Yes={preds['yes']}, No={preds['no']} (Acc: {acc:.1%})"
)
# Warn if model is biased
if metrics['yes_ratio'] < 0.2 or metrics['no_ratio'] < 0.2:
training_status["logs"].append(
f"⚠️ WARNING: Model is heavily biased! "
f"(Yes: {metrics['yes_ratio']:.1%}, No: {metrics['no_ratio']:.1%})"
)
# Update progress
if state.global_step > 0:
total_steps = len(train_dataset) // batch_size * epochs
training_status["progress"] = min(int((state.global_step / total_steps) * 100), 99)
# Add callback with trainer and eval dataset
status_callback = StatusCallback(trainer, val_dataset, ft_val_df)
trainer.add_callback(status_callback)
# Train
try:
logger.info("Starting fine-tuning...")
trainer.train()
# Save final model
save_path = os.path.join(OUTPUT_DIR, "final")
trainer.save_model(save_path)
current_tokenizer.save_pretrained(save_path)
logger.info(f"Model saved to {save_path}")
# Compute final metrics
logger.info("Computing final accuracy metrics...")
final_metrics = compute_accuracy_metrics(trainer, val_dataset, ft_val_df, num_samples=200)
logger.info(f"Final Accuracy: {final_metrics['accuracy']:.2%}")
logger.info(f"Final Prediction Distribution - Yes: {final_metrics['yes_ratio']:.1%}, No: {final_metrics['no_ratio']:.1%}")
with training_lock:
training_status["logs"].append(f"\n=== FINAL RESULTS ===")
training_status["logs"].append(f"Overall Accuracy: {final_metrics['accuracy']:.2%}")
training_status["logs"].append(f"Yes predictions: {final_metrics['yes_ratio']:.1%}")
training_status["logs"].append(f"No predictions: {final_metrics['no_ratio']:.1%}")
# Add final confusion matrix
if 'confusion_matrix' in final_metrics:
training_status["logs"].append("\n=== Final Confusion Matrix ===")
for category, preds in final_metrics['confusion_matrix'].items():
total = preds['yes'] + preds['no']
if total > 0:
acc = final_metrics['category_accuracies'][category]
training_status["logs"].append(
f"{category}: Yes={preds['yes']}, No={preds['no']} (Accuracy: {acc:.1%})"
)
# Update global model reference
current_model = trainer.model
current_model.eval()
# Push to hub if token available
if HF_TOKEN and HF_USERNAME:
try:
# Generate unique repo name with timestamp
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_short_name = current_model_id.split("/")[-1]
# Create descriptive repo name with training details
repo_name = f"{HF_USERNAME}/{model_short_name}-relevance-ft-{timestamp}"
# Create model card with training information
model_card_content = f"""---
tags:
- document-relevance
- lora
- {model_short_name}
datasets:
- custom-relevance-dataset
metrics:
- accuracy
model-index:
- name: {repo_name.split('/')[-1]}
results:
- task:
type: text-classification
name: Document Relevance Classification
metrics:
- type: accuracy
value: {final_metrics['accuracy']:.4f}
name: Validation Accuracy
- type: yes_ratio
value: {final_metrics['yes_ratio']:.4f}
name: Yes Prediction Ratio
- type: no_ratio
value: {final_metrics['no_ratio']:.4f}
name: No Prediction Ratio
---
# {model_short_name} Document Relevance Classifier
This model was trained using standard fine-tuning for document relevance classification.
## Training Configuration
- Base Model: {current_model_id}
- Training Type: Standard Fine-tuning
- Learning Rate: {training_args.learning_rate}
- Batch Size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}
- Epochs: {training_args.num_train_epochs}
- Training Samples: {len(train_df)}
- Validation Samples: {len(val_df)}
## Performance Metrics
- **Accuracy**: {final_metrics['accuracy']:.2%}
- **Yes Predictions**: {final_metrics['yes_ratio']:.1%}
- **No Predictions**: {final_metrics['no_ratio']:.1%}
## Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Load base model
model = AutoModelForCausalLM.from_pretrained("{current_model_id}")
tokenizer = AutoTokenizer.from_pretrained("{current_model_id}")
# Load adapter
model = PeftModel.from_pretrained(model, "{repo_name}")
```
## Training Date
{datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}
"""
# Push model with model card
current_model.push_to_hub(
repo_name,
token=HF_TOKEN,
commit_message=f"Standard fine-tuning with lr={training_args.learning_rate}, accuracy={final_metrics['accuracy']:.2%}"
)
current_tokenizer.push_to_hub(repo_name, token=HF_TOKEN)
# Save model card
try:
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj=model_card_content.encode(),
path_in_repo="README.md",
repo_id=repo_name,
repo_type="model",
token=HF_TOKEN
)
except Exception:
pass  # Model card upload is optional
logger.info(f"Model pushed to hub: {repo_name}")
# Save repo name to training status and trained models list
with training_lock:
training_status["logs"].append(f"Model saved to: https://huggingface.co/{repo_name}")
training_status["model_repo"] = repo_name
# Add to trained models list
trained_models.append({
"repo": repo_name,
"timestamp": timestamp,
"accuracy": final_metrics['accuracy'],
"yes_ratio": final_metrics['yes_ratio'],
"no_ratio": final_metrics['no_ratio'],
"lr": training_args.learning_rate,
"model_id": current_model_id
})
except Exception as e:
logger.error(f"Failed to push to hub: {e}")
except Exception as e:
logger.error(f"Training failed: {e}")
raise
def run_training(csv_path, shuffle_flag=False, split_ratio=0.8):
"""Run training from CSV file with balanced sampling"""
global train_df, test_df, training_status
try:
with training_lock:
training_status["status"] = "loading"
training_status["logs"] = ["Loading data..."]
# Load CSV
df = pd.read_csv(csv_path)
logger.info(f"Loaded {len(df)} samples from {csv_path}")
# Clean data - replace NaN values with empty strings
text_columns = ['query_text', 'title', 'text', 'query', 'content']
for col in text_columns:
if col in df.columns:
# Count NaN values before cleaning
nan_count = df[col].isna().sum()
if nan_count > 0:
logger.warning(f"Found {nan_count} NaN values in column '{col}' - replacing with empty strings")
df[col] = df[col].fillna('')
df[col] = df[col].astype(str)
# Check required columns for new format
new_format_cols = ['query_text', 'title', 'text', 'label']
old_format_cols = ['query', 'title', 'content', 'label']
if all(col in df.columns for col in new_format_cols):
# New format with 4 categories
logger.info("Using new CSV format with 4 categories")
# Validate labels
valid_labels = ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative']
if not all(label in valid_labels for label in df['label'].unique()):
raise ValueError(f"Labels must be one of: {valid_labels}")
# Create prompt column
if 'prompt' not in df.columns:
logger.info("Creating prompt column from query_text/title/text")
df['prompt'] = df.apply(
lambda row: format_prompt(row['query_text'], row['title'], row['text']),
axis=1
)
# Log label distribution (assuming data is already balanced)
logger.info(f"Label distribution: {df['label'].value_counts().to_dict()}")
elif all(col in df.columns for col in old_format_cols):
# Old format
logger.info("Using old CSV format")
if 'prompt' not in df.columns:
df['prompt'] = df.apply(
lambda row: format_prompt(row['query'], row['title'], row['content']),
axis=1
)
# Validate labels for old format
if not all(label in ['yes', 'no'] for label in df['label'].unique()):
raise ValueError("Labels must be 'yes' or 'no' for old format")
else:
raise ValueError(f"CSV must have columns: {new_format_cols} or {old_format_cols}")
# Shuffle if requested
if shuffle_flag:
df = df.sample(frac=1).reset_index(drop=True)
# Split data
split_idx = int(len(df) * split_ratio)
train_df = df.iloc[:split_idx].reset_index(drop=True)
test_df = df.iloc[split_idx:].reset_index(drop=True)
logger.info(f"Train: {len(train_df)}, Test: {len(test_df)}")
logger.info(f"Train label distribution: {train_df['label'].value_counts().to_dict()}")
# Start training
with training_lock:
training_status["status"] = "training"
training_status["logs"].append("Starting training...")
training_status["logs"].append(f"Train samples: {len(train_df)}")
training_status["logs"].append(f"Test samples: {len(test_df)}")
# Conservative sample size for GPT-OSS-20B
if current_model_id == "openai/gpt-oss-20b":
max_samples = 2000 # Start conservative
else:
max_samples = None
train_model(train_df, test_df, epochs=5, batch_size=32, lr=5e-6, max_samples=max_samples)
with training_lock:
training_status["status"] = "completed"
training_status["logs"].append("Training completed!")
training_status["progress"] = 100
except Exception as e:
logger.error(f"Training failed: {str(e)}")
with training_lock:
training_status["status"] = "failed"
training_status["logs"].append(f"Error: {str(e)}")
training_status["error"] = str(e)
def run_inference(query, document_title, document_content, checkpoint="latest"):
"""Run inference on a single example using logit comparison"""
global current_model, current_tokenizer
# Validate inputs
if not query or not str(query).strip():
return "Error: Query cannot be empty"
if not document_title or not str(document_title).strip():
return "Error: Document title cannot be empty"
if not document_content or not str(document_content).strip():
return "Error: Document content cannot be empty"
# Convert to strings to handle any data type
query = str(query)
document_title = str(document_title)
document_content = str(document_content)
# Load model if needed
if current_model is None:
if checkpoint == "latest":
# Find latest checkpoint
final_path = os.path.join(OUTPUT_DIR, "final")
if os.path.exists(final_path):
checkpoint_path = final_path
else:
checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint-*"))
if checkpoints:
checkpoint_path = max(checkpoints, key=os.path.getctime)
else:
checkpoint_path = None
load_model_and_tokenizer(checkpoint_path)
elif checkpoint == "current":
# Use current model without loading (for CSV inference)
pass
else:
load_model_and_tokenizer(checkpoint)
# Ensure model is in eval mode
current_model.eval()
# Create prompt
prompt = format_prompt(query, document_title, document_content)
# Tokenize
inputs = current_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(current_model.device) for k, v in inputs.items()}
# Get logits for yes/no decision
with torch.no_grad():
# Disable cache for inference to avoid DynamicCache issues
if hasattr(current_model, 'config'):
current_model.config.use_cache = False
try:
outputs = current_model(**inputs, use_cache=False)
except TypeError:
# If use_cache parameter not accepted, try without it
outputs = current_model(**inputs)
logits = outputs.logits[0, -1, :] # Last token logits
# Get token IDs for "yes" and "no"
yes_token_id = current_tokenizer.encode("yes", add_special_tokens=False)[0]
no_token_id = current_tokenizer.encode("no", add_special_tokens=False)[0]
# Get logits for yes/no
yes_logit = logits[yes_token_id].item()
no_logit = logits[no_token_id].item()
# Apply softmax to get probabilities
probs = torch.softmax(torch.tensor([no_logit, yes_logit]), dim=0)
no_prob, yes_prob = probs.tolist()
# Return prediction with confidence
prediction = "yes" if yes_prob > no_prob else "no"
confidence = max(yes_prob, no_prob)
return f"{prediction} (confidence: {confidence:.1%})"
def run_inference_by_row(split_choice, row_idx, checkpoint="latest"):
"""Run inference on a specific row from train/test split"""
global train_df, test_df
df = train_df if split_choice == "train" else test_df
if df is None or row_idx >= len(df):
return "Invalid selection. Please load data first."
row = df.iloc[int(row_idx)]
# Get data from row - handle both old and new formats
if all(col in row for col in ['query_text', 'title', 'text']):
query = str(row['query_text'])
title = str(row['title'])
content = str(row['text'])
elif all(col in row for col in ['query', 'title', 'content']):
query = str(row['query'])
title = str(row['title'])
content = str(row['content'])
else:
return "Row missing required columns"
prediction_with_confidence = run_inference(query, title, content, checkpoint)
# Check if inference returned an error
if prediction_with_confidence.startswith("Error:"):
return prediction_with_confidence
actual = row['label']
# Extract just the prediction (yes/no) from the result
prediction = prediction_with_confidence.split()[0] # Gets "yes" or "no" from "yes (confidence: X%)"
# Handle 4-category labels
if actual in ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative']:
# Map to yes/no for comparison
label_mapping = {
'easy_positive': 'yes',
'hard_positive': 'yes',
'easy_negative': 'no',
'hard_negative': 'no'
}
mapped_actual = label_mapping[actual]
is_correct = prediction == mapped_actual
return f"Prediction: {prediction_with_confidence}\nActual: {actual} (mapped to: {mapped_actual})\nCorrect: {is_correct}"
else:
# Old format with yes/no
is_correct = prediction == actual
return f"Prediction: {prediction_with_confidence}\nActual: {actual}\nCorrect: {is_correct}"
def run_csv_inference(csv_file, model_choice="finetuned", checkpoint_path=None, inference_batch_size=16, progress=gr.Progress()):
"""Run inference on an entire CSV file"""
global current_model, current_tokenizer
if csv_file is None:
return None, "Please upload a CSV file"
try:
# Clear CUDA cache before starting
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Load CSV
df = pd.read_csv(csv_file.name)
logger.info(f"Loaded {len(df)} samples from CSV for inference")
# Clean data - replace NaN values with empty strings
text_columns = ['query_text', 'title', 'text', 'query', 'content']
for col in text_columns:
if col in df.columns:
df[col] = df[col].fillna('')
df[col] = df[col].astype(str)
# Check required columns - support both old and new formats
new_format_cols = ['query_text', 'title', 'text']
old_format_cols = ['query', 'title', 'content']
if all(col in df.columns for col in new_format_cols):
# New format
query_col = 'query_text'
content_col = 'text'
elif all(col in df.columns for col in old_format_cols):
# Old format
query_col = 'query'
content_col = 'content'
else:
return None, f"CSV must have columns: {new_format_cols} or {old_format_cols}"
# Load appropriate model
if model_choice == "base":
# Load base model without any checkpoint
load_model_and_tokenizer(checkpoint_path=None)
logger.info("Using base model for inference")
else:
# Load fine-tuned model
if checkpoint_path and checkpoint_path.strip():
# Use custom checkpoint path (could be local or HF Hub)
load_model_and_tokenizer(checkpoint_path=checkpoint_path)
else:
# Try to find model - check local first, then HF Hub
final_path = os.path.join(OUTPUT_DIR, "final")
if os.path.exists(final_path):
load_model_and_tokenizer(checkpoint_path=final_path)
else:
checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint-*"))
if checkpoints:
latest_checkpoint = max(checkpoints, key=os.path.getctime)
load_model_and_tokenizer(checkpoint_path=latest_checkpoint)
else:
# Try to load from HF Hub using known model names
hub_attempts = []
if HF_USERNAME:
hub_attempts.append(f"{HF_USERNAME}/phi3-dpo-relevance")
# Also try the known pushed model
hub_attempts.extend(["amos1088/phi3-dpo-relevance"])
model_loaded = False
for hub_model_name in hub_attempts:
try:
logger.info(f"Trying to load from HF Hub: {hub_model_name}")
# Load base model first
load_model_and_tokenizer(checkpoint_path=None)
# Then load adapter from Hub
from peft import PeftModel
current_model = PeftModel.from_pretrained(current_model, hub_model_name)
current_model.eval()
logger.info(f"Successfully loaded fine-tuned model from HF Hub: {hub_model_name}")
model_loaded = True
break
except Exception as e:
logger.warning(f"Failed to load {hub_model_name}: {e}")
continue
if not model_loaded:
return None, "No fine-tuned model found locally or on HF Hub. Please train a model first."
logger.info("Using fine-tuned model for inference")
# Prepare results
predictions = []
confidences = []
# Batch inference for speed: inference needs far less memory than training
# (no gradients or optimizer states), so larger batches are feasible
if current_model_id == "openai/gpt-oss-20b":
batch_size = min(inference_batch_size, 64)  # cap for the 20B model
else:
batch_size = inference_batch_size
logger.info(f"Using batch size {batch_size} for inference")
total_batches = (len(df) + batch_size - 1) // batch_size
progress(0, desc="Starting batch inference...")
for batch_idx in range(0, len(df), batch_size):
batch_end = min(batch_idx + batch_size, len(df))
batch_df = df.iloc[batch_idx:batch_end]
# Update progress
progress(batch_end / len(df),
desc=f"Processing batch {batch_idx//batch_size + 1}/{total_batches}")
# Prepare batch prompts
prompts = []
for _, row in batch_df.iterrows():
prompt = format_prompt(
str(row[query_col]),
str(row['title']),
str(row[content_col])
)
prompts.append(prompt)
# Tokenize batch
inputs = current_tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=256
)
inputs = {k: v.to(current_model.device) for k, v in inputs.items()}
# Run batch inference
with torch.no_grad():
with torch.autocast("cuda"):  # Mixed precision for speed
outputs = current_model(**inputs)
logits = outputs.logits[:, -1, :] # Last token logits for each sample
# Get yes/no token IDs
yes_token_id = current_tokenizer.encode("yes", add_special_tokens=False)[0]
no_token_id = current_tokenizer.encode("no", add_special_tokens=False)[0]
# Process each sample in batch
for i in range(len(batch_df)):
yes_logit = logits[i, yes_token_id].item()
no_logit = logits[i, no_token_id].item()
# Apply softmax
probs = torch.softmax(torch.tensor([no_logit, yes_logit]), dim=0)
no_prob, yes_prob = probs.tolist()
# Get prediction
prediction = "yes" if yes_prob > no_prob else "no"
confidence = max(yes_prob, no_prob)
predictions.append(prediction)
confidences.append(confidence)
# Clear cache after each batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Clear CUDA cache after inference loop
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Add predictions to dataframe
df['prediction'] = predictions
df['confidence'] = confidences
# If labels exist, calculate accuracy and add is_right column
if 'label' in df.columns:
# Handle 4-category labels
if df['label'].iloc[0] in ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative']:
# Map 4 categories to yes/no for accuracy calculation
label_mapping = {
'easy_positive': 'yes',
'hard_positive': 'yes',
'easy_negative': 'no',
'hard_negative': 'no'
}
df['mapped_label'] = df['label'].map(label_mapping)
df['is_right'] = df['prediction'] == df['mapped_label']
# Calculate per-category accuracy
accuracy_text = "\n=== Overall Results ==="
accuracy = df['is_right'].mean()
accuracy_text += f"\nOverall Accuracy: {accuracy:.2%} ({df['is_right'].sum()}/{len(df)} correct)"
accuracy_text += "\n\n=== Per-Category Results ==="
for category in ['easy_positive', 'hard_positive', 'easy_negative', 'hard_negative']:
category_df = df[df['label'] == category]
if len(category_df) > 0:
category_acc = category_df['is_right'].mean()
category_count = len(category_df)
correct_count = category_df['is_right'].sum()
accuracy_text += f"\n{category}: {category_acc:.2%} ({correct_count}/{category_count} correct)"
else:
# Old format with yes/no labels
df['is_right'] = df['prediction'] == df['label']
accuracy = df['is_right'].mean()
accuracy_text = f"\nAccuracy: {accuracy:.2%} ({df['is_right'].sum()}/{len(df)} correct)"
else:
accuracy_text = ""
# Save results
output_filename = f"inference_results_{model_choice}_{'%Y%m%d_%H%M%S'}.csv"
import datetime
output_filename = f"inference_results_{model_choice}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
df.to_csv(output_filename, index=False)
# Create summary
summary = f"Completed inference on {len(df)} samples using {model_choice} model"
summary += accuracy_text
summary += f"\n\nPrediction distribution:\n"
for pred, count in df['prediction'].value_counts().items():
summary += f" {pred}: {count} ({count/len(df):.1%})\n"
summary += f"\nResults saved to: {output_filename}"
return output_filename, summary
except Exception as e:
logger.error(f"CSV inference failed: {str(e)}")
return None, f"Error during inference: {str(e)}"
# ==== GRADIO UI ====
with gr.Blocks(title="Document Relevance Classifier") as demo:
gr.Markdown("# Document Relevance Classifier")
gr.Markdown("Train and test a model to classify whether documents are relevant to queries")
with gr.Tab("Training"):
csv_input = gr.File(label="Upload Training CSV", file_types=[".csv"])
gr.Markdown("CSV should have columns: query_text, title, text, label (easy_positive/hard_positive/easy_negative/hard_negative)\n**Note: Data should be pre-balanced with equal samples per category**")
shuffle_flag = gr.Checkbox(label="Shuffle Dataset", value=True)
split_slider = gr.Slider(0.5, 0.9, value=0.8, step=0.05, label="Train Split %")
start_btn = gr.Button("Start Training", variant="primary")
status_text = gr.Textbox(label="Training Status", lines=10)
def start_training_ui(csv_file, shuffle, split_ratio):
if csv_file is None:
return "Please upload a CSV file first."
threading.Thread(
target=run_training,
args=(csv_file.name, shuffle, split_ratio)
).start()
return "Training started... Check status below."
def get_training_status():
with training_lock:
status = training_status["status"]
progress = training_status["progress"]
logs = "\n".join(training_status["logs"][-20:])
gpu_status = get_gpu_memory_status()
return f"Status: {status}\nProgress: {progress}%\n{gpu_status}\n\nLogs:\n{logs}"
start_btn.click(
start_training_ui,
inputs=[csv_input, shuffle_flag, split_slider],
outputs=status_text
)
# Auto-refresh status
status_text.change(
lambda: get_training_status(),
inputs=[],
outputs=status_text,
every=10
)
with gr.Tab("Inference"):
query_input = gr.Textbox(label="Query", placeholder="Enter your search query")
title_input = gr.Textbox(label="Document Title", placeholder="Enter document title")
content_input = gr.Textbox(
label="Document Content",
placeholder="Enter document content",
lines=5
)
checkpoint_dropdown = gr.Textbox(
label="Checkpoint",
value="latest",
placeholder="latest or path to checkpoint"
)
inference_btn = gr.Button("Classify", variant="primary")
output_text = gr.Textbox(label="Relevance", lines=2)
inference_btn.click(
run_inference,
inputs=[query_input, title_input, content_input, checkpoint_dropdown],
outputs=output_text
)
with gr.Tab("Test on Dataset"):
split_choice = gr.Radio(
choices=["train", "test"],
value="test",
label="Dataset Split"
)
row_idx_input = gr.Number(
label="Row Index (0-based)",
value=0,
precision=0
)
checkpoint_input2 = gr.Textbox(
label="Checkpoint",
value="latest"
)
row_inference_btn = gr.Button("Test Row", variant="primary")
row_output = gr.Textbox(label="Result", lines=4)
row_inference_btn.click(
run_inference_by_row,
inputs=[split_choice, row_idx_input, checkpoint_input2],
outputs=row_output
)
with gr.Tab("CSV Inference"):
gr.Markdown("### Batch Inference on CSV Files")
gr.Markdown("Run inference on multiple documents from a CSV file. Compare base model vs fine-tuned model performance.")
csv_inference_input = gr.File(
label="Upload CSV for Inference",
file_types=[".csv"],
file_count="single"
)
gr.Markdown("CSV must have columns: `query_text`, `title`, `text`. Optional: `label` (for accuracy calculation)")
with gr.Row():
model_choice_radio = gr.Radio(
choices=["base", "finetuned"],
value="finetuned",
label="Model Selection",
info="Choose between base model or fine-tuned model"
)
custom_checkpoint_path = gr.Textbox(
label="Custom Checkpoint Path (Optional)",
placeholder="e.g., amos1088/phi3-dpo-relevance or local path",
value="amos1088/phi3-dpo-relevance",
visible=True
)
inference_batch_size_slider = gr.Slider(
minimum=8,
maximum=128,
value=32,
step=8,
label="Inference Batch Size",
info="Default matches training batch size (32). Larger = faster but uses more memory"
)
csv_inference_btn = gr.Button("Run Batch Inference", variant="primary")
with gr.Row():
csv_output_file = gr.File(
label="Download Results",
visible=True
)
csv_results_text = gr.Textbox(
label="Inference Summary",
lines=10,
max_lines=20
)
# Handle inference
csv_inference_btn.click(
run_csv_inference,
inputs=[csv_inference_input, model_choice_radio, custom_checkpoint_path, inference_batch_size_slider],
outputs=[csv_output_file, csv_results_text]
)
with gr.Tab("Trained Models"):
gr.Markdown("### Model Training History")
gr.Markdown("View all models trained in this session and their performance metrics")
models_display = gr.Markdown(get_trained_models_list())
refresh_btn = gr.Button("Refresh Model List", variant="secondary")
# Update CSV inference dropdown with trained models
model_selector = gr.Dropdown(
label="Select Trained Model for Inference",
choices=["Latest"] + [m["repo"] for m in trained_models],
value="Latest",
interactive=True
)
def refresh_models_list():
models_text = get_trained_models_list()
choices = ["Latest", "amos1088/phi3-dpo-relevance"] + [m["repo"] for m in trained_models]
return models_text, gr.update(choices=choices)
refresh_btn.click(
refresh_models_list,
outputs=[models_display, model_selector]
)
# Auto-refresh after training
training_status_display = gr.Textbox(visible=False)
training_status_display.change(
lambda: (get_trained_models_list(), gr.update(choices=["Latest"] + [m["repo"] for m in trained_models])),
outputs=[models_display, model_selector]
)
with gr.Tab("Model Settings"):
gr.Markdown("### Model Selection")
gr.Markdown("Choose which model to use for training and inference")
# Current model display
current_model_display = gr.Textbox(
label="Current Model",
value=f"{AVAILABLE_MODELS[current_model_id]['name']} ({current_model_id})",
interactive=False
)
# Model selection dropdown
model_choices = list(AVAILABLE_MODELS.keys())
model_dropdown = gr.Dropdown(
choices=model_choices,
value=current_model_id,
label="Select Model",
interactive=True
)
# Switch model button
switch_btn = gr.Button("Switch Model", variant="primary")
switch_status = gr.Textbox(label="Status", lines=2)
# Model info
gr.Markdown("### Model Information")
model_info = gr.Textbox(
label="Model Details",
value=f"Target modules: {AVAILABLE_MODELS[current_model_id]['target_modules']}",
lines=3,
interactive=False
)
def handle_model_switch(selected_model):
try:
# Switch model
status = switch_model(selected_model)
# Update displays
model_name = AVAILABLE_MODELS[selected_model]['name']
current_display = f"{model_name} ({selected_model})"
info = f"Target modules: {AVAILABLE_MODELS[selected_model]['target_modules']}\nOutput directory: {OUTPUT_DIR}"
return status, current_display, info
except Exception as e:
return f"Error: {str(e)}", current_model_display.value, model_info.value
switch_btn.click(
handle_model_switch,
inputs=[model_dropdown],
outputs=[switch_status, current_model_display, model_info]
)
if __name__ == "__main__":
# Ensure the Gradio temp directory exists
os.makedirs("/tmp/gradio", exist_ok=True)
demo.launch(server_name="0.0.0.0", server_port=7860)