import torch
import wandb
import yaml
from transformers import Trainer, TrainingArguments
from data.datasets import load_and_tokenize_data
from models.full_finetune_model import get_full_finetune_model
from models.student_model import get_student_model
# Load the configuration
with open('config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)
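# For reference, the keys accessed in this script imply a config/config.yaml
# shaped roughly like the sketch below (values are illustrative, not from the
# original project):
#
#     wandb:
#       project: my-distillation-project
#       entity: my-team
#     training:
#       num_epochs: 3
#       batch_size: 16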
# Initialize wandb for experiment tracking
wandb.init(project=config['wandb']['project'], entity=config['wandb']['entity'])
# Load the data
train_dataset, test_dataset = load_and_tokenize_data(config)
# Load the teacher and the student model
teacher_model = get_full_finetune_model()
teacher_model.eval()  # the teacher is used for inference only during distillation
student_model = get_student_model(config)
# Define the training arguments for distillation
training_args = TrainingArguments(
    output_dir='./results_student',
    num_train_epochs=config['training']['num_epochs'],
    per_device_train_batch_size=config['training']['batch_size'],
    per_device_eval_batch_size=config['training']['batch_size'],
    evaluation_strategy='epoch',
    save_steps=10_000,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=10,
    report_to='wandb',  # send training logs to the wandb run initialized above
)
# Define the distillation trainer: the student is trained to match the
# teacher's output distribution via KL divergence
class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Forward pass of the teacher model (no gradients needed);
        # teacher_model is assumed to live on the same device as the inputs
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)
        # Forward pass of the student model
        student_outputs = model(**inputs)
        # Distillation loss: KL divergence between the student's
        # log-probabilities and the teacher's probabilities
        loss = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(student_outputs.logits, dim=-1),
            torch.nn.functional.softmax(teacher_outputs.logits, dim=-1),
            reduction='batchmean'
        )
        return (loss, student_outputs) if return_outputs else loss
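# Note: a common refinement (not in the original script) is to soften both
# distributions with a temperature T and blend the KL term with the student's
# own task loss. A hedged sketch of that variant, assuming the batch carries
# labels so that student_outputs.loss is populated:
#
#     T, alpha = 2.0, 0.5
#     kl = torch.nn.functional.kl_div(
#         torch.nn.functional.log_softmax(student_outputs.logits / T, dim=-1),
#         torch.nn.functional.softmax(teacher_outputs.logits / T, dim=-1),
#         reduction='batchmean'
#     ) * (T ** 2)
#     loss = alpha * kl + (1 - alpha) * student_outputs.loss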
# Create the Trainer for distillation
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
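# The original script calls measure_resources without defining or importing it;
# presumably it lives in a project utility module. A minimal, hypothetical
# sketch of such a helper (assumed behavior: time the training run and log
# peak GPU memory to wandb):
import time

def measure_resources(trainer, run_name):
    # Reset the CUDA peak-memory counter so the measurement covers only this run
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    start = time.time()
    trainer.train()
    elapsed = time.time() - start
    peak_mem = torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0
    wandb.log({
        f"{run_name}/train_time_seconds": elapsed,
        f"{run_name}/peak_gpu_memory_bytes": peak_mem,
    })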
# Measure resources and train the student model
measure_resources(trainer, "Distillation")