import threading
import torch
import time
import json
import queue
import uuid
import matplotlib.pyplot as plt

from functools import partial
from typing import Generator, Optional, List, Dict, Any, Tuple

from datasets import Dataset, load_dataset
from trl import SFTConfig, SFTTrainer
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from huggingface_hub import HfApi, model_info, metadata_update

from config import AppConfig
from tools import DEFAULT_TOOLS
from utils import (
    authenticate_hf,
    load_model_and_tokenizer,
    create_conversation_format,
    parse_csv_dataset,
    zip_directory
)


class AbortCallback(TrainerCallback):
    """Stops training at the next step boundary once the shared stop event is set."""

    def __init__(self, stop_event: threading.Event):
        self.stop_event = stop_event

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if self.stop_event.is_set():
            control.should_training_stop = True


class LogStreamingCallback(TrainerCallback):
    """Formats trainer log events and pushes them onto a queue for streaming to the UI."""

    def __init__(self, log_queue: queue.Queue):
        self.log_queue = log_queue

    def _get_string(self, value):
        if isinstance(value, float):
            return f"{value:.4f}"
        return str(value)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return

        metrics_map = {
            "loss": "Loss",
            "eval_loss": "Eval Loss",
            "learning_rate": "LR",
            "epoch": "Epoch"
        }
        log_parts = [f"📊 [Step {state.global_step}]"]

        for key, label in metrics_map.items():
            if key in logs:
                val = logs[key]
                if isinstance(val, (float, int)):
                    val_str = f"{val:.4f}" if val > 1e-4 else f"{val:.2e}"
                else:
                    val_str = str(val)

                log_parts.append(f"{label}: {val_str}")

        log_payload = logs.copy()
        log_payload['step'] = state.global_step

        self.log_queue.put((" | ".join(log_parts), log_payload))


class FunctionGemmaEngine:
    """Session-scoped engine that loads the model, fine-tunes it, and evaluates tool calling."""

    def __init__(self, config: AppConfig):
        self.config = config

        self.session_id = str(uuid.uuid4())[:8]
        self.output_dir = self.config.ARTIFACTS_DIR.joinpath(f"session_{self.session_id}")
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.model = None
        self.tokenizer = None
        self.loaded_model_name = None
        self.imported_dataset = []
        self.stop_event = threading.Event()
        self.current_tools = DEFAULT_TOOLS
        self.has_model_tuned = False

        authenticate_hf(self.config.HF_TOKEN)
        try:
            self.refresh_model()
        except Exception as e:
            print(f"Initial load warning: {e}")

    def get_tools_json(self) -> str:
        return json.dumps(self.current_tools, indent=2)

    def update_tools(self, json_str: str) -> str:
        try:
            new_tools = json.loads(json_str)
            if not isinstance(new_tools, list):
                return "Error: Schema must be a list of tool definitions."
            self.current_tools = new_tools
            return "✅ Tool Schema Updated successfully."
        except json.JSONDecodeError as e:
            return f"❌ JSON Error: {e}"
        except Exception as e:
            return f"❌ Error: {e}"

    def _load_model_weights(self):
        print(f"[{self.session_id}] Loading model: {self.config.MODEL_NAME}...")
        self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
        self.loaded_model_name = self.config.MODEL_NAME

    def refresh_model(self) -> str:
        self.has_model_tuned = False
        try:
            self._load_model_weights()
            return f"Model loaded: {self.loaded_model_name}\nData cleared.\nReady (Session {self.session_id})."
        except Exception as e:
            self.model = None
            self.tokenizer = None
            self.loaded_model_name = None
            return f"CRITICAL ERROR: Model failed to load. {e}"

    def load_csv(self, file_path: str) -> str:
        try:
            new_data = parse_csv_dataset(file_path)
            if not new_data:
                return "Error: File empty or format invalid."
            self.imported_dataset = new_data
            return f"Successfully imported {len(new_data)} samples."
        except Exception as e:
            return f"Import failed: {e}"

    def trigger_stop(self):
        self.stop_event.set()

    def _ensure_model_consistency(self) -> Generator[str, None, bool]:
        """Checks if the requested model matches the loaded one. Reloads if necessary."""
        if self.config.MODEL_NAME != self.loaded_model_name:
            yield f"🔄 Model changed. Switching from '{self.loaded_model_name}' to '{self.config.MODEL_NAME}'...\n"
            try:
                self._load_model_weights()
                yield "✅ Model reloaded successfully.\n"
                return True
            except Exception as e:
                yield f"❌ Failed to load model '{self.config.MODEL_NAME}': {e}\n"
                return False
        if self.model is None:
            yield "❌ Error: No model loaded.\n"
            return False
        return True

    def run_evaluation(self, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
        self.stop_event.clear()
        output_buffer = ""

        try:
            gen = self._ensure_model_consistency()
            try:
                while True:
                    msg = next(gen)
                    output_buffer += msg
                    yield output_buffer
            except StopIteration as e:
                if not e.value: return

            output_buffer += f"⏳ Preparing Dataset for Eval (Test Split: {test_size})...\n"
            yield output_buffer

            dataset, log = self._prepare_dataset()
            output_buffer += log
            yield output_buffer

            if not dataset:
                output_buffer += "❌ Dataset creation failed.\n"
                yield output_buffer
                return

            if len(dataset) > 1:
                dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
            else:
                dataset = {"train": dataset, "test": dataset}

            output_buffer += "\n📊 Evaluating Model Success Rate on Test Split...\n"
            yield output_buffer

            for update in self._evaluate_model(dataset["test"]):
                yield f"{output_buffer}{update}"
                if self.stop_event.is_set():
                    yield f"{output_buffer}{update}\n\n🛑 Evaluation interrupted by user."
                    break
        finally:
            self.stop_event.set()

    def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[Tuple[str, Any], None, None]:
        self.stop_event.clear()
        output_buffer = ""
        last_plot = None

        try:
            gen = self._ensure_model_consistency()
            try:
                while True:
                    msg = next(gen)
                    output_buffer += f"{msg}"
                    yield output_buffer, None
            except StopIteration as e:
                if not e.value: return

            output_buffer += f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
            yield output_buffer, None

            dataset, log = self._prepare_dataset()
            if not dataset:
                yield "Dataset creation failed.", None
                return

            output_buffer += log
            yield output_buffer, None

            if len(dataset) > 1:
                dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
            else:
                dataset = {"train": dataset, "test": dataset}

            output_buffer += f"\n🚀 Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
            yield output_buffer, None

            log_queue = queue.Queue()
            training_error = None
            running_history = []

            def train_wrapper():
                nonlocal training_error
                try:
                    self._execute_trainer(dataset, log_queue, epochs, learning_rate)
                except Exception as e:
                    training_error = e

            train_thread = threading.Thread(target=train_wrapper)
            train_thread.start()

            while train_thread.is_alive():
                while not log_queue.empty():
                    payload = log_queue.get()
                    if isinstance(payload, tuple):
                        msg, log_data = payload
                        output_buffer += f"{msg}\n"
                        running_history.append(log_data)
                        try:
                            last_plot = self._generate_loss_plot(running_history)
                            yield output_buffer, last_plot
                        except Exception:
                            yield output_buffer, last_plot
                    else:
                        output_buffer += f"{payload}\n"
                        yield output_buffer, last_plot

                if self.stop_event.is_set():
                    yield f"{output_buffer}🛑 Stop signal sent. Waiting for trainer to wrap up...\n", last_plot

                time.sleep(0.1)

            train_thread.join()

            self.has_model_tuned = True

            while not log_queue.empty():
                payload = log_queue.get()
                if isinstance(payload, tuple):
                    msg, log_data = payload
                    output_buffer += f"{msg}\n"
                    running_history.append(log_data)
                    last_plot = self._generate_loss_plot(running_history)
                else:
                    output_buffer += f"{payload}\n"
                yield output_buffer, last_plot

            if training_error:
                output_buffer += f"❌ Error during training: {training_error}\n"
                yield output_buffer, last_plot
                return

            if self.stop_event.is_set():
                output_buffer += "🛑 Training manually stopped.\n"
                yield output_buffer, last_plot
                return

            output_buffer += "✅ Training finished.\n"
            yield output_buffer, last_plot

        finally:
            self.stop_event.set()

    def _prepare_dataset(self):
        formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)

        if not self.imported_dataset:
            ds = load_dataset(self.config.DEFAULT_DATASET, split="train").map(formatting_fn)
            log = f" `-> using default dataset (size:{len(ds)})\n"
        else:
            # Imported CSV rows are positional: [user_content, tool_name, tool_arguments].
            dataset_as_dicts = [
                {"user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
                for row in self.imported_dataset
            ]
            ds = Dataset.from_list(dataset_as_dicts).map(formatting_fn)
            log = f" `-> using custom dataset (size:{len(ds)})\n"
        return ds, log

    def _execute_trainer(self, dataset, log_queue: queue.Queue, epochs: int, learning_rate: float) -> List[Dict]:
        torch_dtype = self.model.dtype
        args = SFTConfig(
            output_dir=str(self.output_dir),
            max_length=512,
            packing=False,
            num_train_epochs=epochs,
            per_device_train_batch_size=4,
            logging_steps=1,
            save_strategy="no",
            eval_strategy="epoch",
            learning_rate=learning_rate,
            fp16=(torch_dtype == torch.float16),
            bf16=(torch_dtype == torch.bfloat16),
            report_to="none",
            dataset_kwargs={"add_special_tokens": False, "append_concat_token": True}
        )

        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=dataset['train'],
            eval_dataset=dataset['test'],
            processing_class=self.tokenizer,
            callbacks=[
                AbortCallback(self.stop_event),
                LogStreamingCallback(log_queue)
            ]
        )
        trainer.train()
        trainer.save_model()
        return trainer.state.log_history

    def _generate_loss_plot(self, history: list):
        if not history: return None
        plt.close('all')

        train_steps = [x['step'] for x in history if 'loss' in x]
        train_loss = [x['loss'] for x in history if 'loss' in x]
        eval_steps = [x['step'] for x in history if 'eval_loss' in x]
        eval_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]

        fig, ax = plt.subplots(figsize=(10, 5))
        if train_steps:
            ax.plot(train_steps, train_loss, label='Training Loss', linestyle='-', marker=None)
        if eval_steps:
            ax.plot(eval_steps, eval_loss, label='Validation Loss', linestyle='--', marker='o')

        ax.set_xlabel("Steps")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        ax.grid(True, linestyle=':', alpha=0.6)
        plt.tight_layout()
        return fig

    def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
        results = []
        success_count = 0
        for idx, item in enumerate(test_dataset):
            messages = item["messages"][:2]
            try:
                inputs = self.tokenizer.apply_chat_template(
                    messages, tools=self.current_tools, add_generation_prompt=True, return_dict=True, return_tensors="pt"
                )
                device = self.model.device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                out = self.model.generate(
                    **inputs,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=128
                )
                output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
                log_entry = f"{idx+1}. Prompt: {messages[1]['content']}\n Output: {output[:100]}..."
                expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
                if expected_tool in output:
                    log_entry += "\n -> ✅ Correct Tool"
                    success_count += 1
                else:
                    log_entry += f"\n -> ❌ Wrong Tool (Expected: {expected_tool})"
                results.append(log_entry)
                yield "\n".join(results) + f"\n\nRunning Success Rate: {success_count}/{idx+1}"
            except Exception as e:
                yield f"Error during inference: {e}"

    def get_zip_path(self) -> Optional[str]:
        if not self.output_dir.exists(): return None
        base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{self.session_id}"))
        return zip_directory(str(self.output_dir), base_name)

    def upload_model_to_hub(self, repo_name: str, oauth_token: str) -> str:
        """Uploads the trained model to Hugging Face Hub."""
        if not self.output_dir.exists() or not any(self.output_dir.iterdir()):
            return "❌ No trained model found in current session. Run training first."

        try:
            api = HfApi(token=oauth_token)

            user_info = api.whoami()
            username = user_info['name']

            repo_id = f"{username}/{repo_name}"
            print(f"Preparing to upload to: {repo_id}")

            api.create_repo(repo_id=repo_id, exist_ok=True)

            print(f"Uploading to {repo_id}...")
            repo_url = api.upload_folder(
                folder_path=str(self.output_dir),
                repo_id=repo_id,
                repo_type="model"
            )

            info = model_info(
                repo_id=repo_id,
                token=oauth_token
            )
            tags = ["functiongemma", "functiongemma-tuning-lab"]
            if info.card_data and info.card_data.tags:
                # Preserve existing card tags; guard against a missing or empty tag list.
                tags = info.card_data.tags
                if "functiongemma-tuning-lab" not in tags:
                    tags.append("functiongemma-tuning-lab")

            metadata_update(repo_id, {"tags": tags}, overwrite=True, token=oauth_token)

            return f"✅ Success! Model uploaded to: {repo_url}"
        except Exception as e:
            return f"❌ Upload failed: {str(e)}"