import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import (
    DataCollatorForSeq2Seq,
    GenerationConfig,
    Trainer,
    TrainingArguments,
)


def load_model(model_name="facebook/bart-large-cnn"):
    """
    Load a pre-trained summarization model and its tokenizer.

    Options: facebook/bart-large-cnn, google/pegasus-xsum, sshleifer/distilbart-cnn-12-6
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer
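

# A minimal sketch (an assumption, not something this file requires) of one way to
# prepare data for fine_tune_model below: prepend the same [STYLE] token used by
# generate_stylized_summary and tokenize the articles and reference summaries.
# The "article" and "summary" column names are placeholders; adjust them to your data.
def preprocess_for_style_training(
    examples, tokenizer, style="formal", max_input_length=1024, max_target_length=150
):
    """Tokenize style-prefixed articles and their reference summaries."""
    styled_articles = [f"[{style.upper()}] {article}" for article in examples["article"]]
    model_inputs = tokenizer(styled_articles, max_length=max_input_length, truncation=True)
    # Tokenize the targets; DataCollatorForSeq2Seq pads these labels per batch.
    labels = tokenizer(
        text_target=examples["summary"], max_length=max_target_length, truncation=True
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs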


def fine_tune_model(model, tokenizer, dataset, output_dir="./summarization_model"):
    """Fine-tune the model on a tokenized dataset with 'train' and 'validation' splits."""
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,  # effective train batch size of 16 per device
        learning_rate=5e-5,
        num_train_epochs=6,
        save_strategy="epoch",
        eval_strategy="epoch",  # evaluate every epoch so the best checkpoint can be restored
        load_best_model_at_end=True,
        report_to="none",
    )
    # Dynamically pad inputs and labels per batch for seq2seq training.
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    # Save the fine-tuned weights and tokenizer together so they can be reloaded
    # later with load_model(output_dir).
    tokenizer.save_pretrained(output_dir)
    model.save_pretrained(output_dir)
    return model
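

# Illustrative wiring of the helpers above (a sketch, assuming a datasets.DatasetDict
# with "train"/"validation" splits and the column names used in
# preprocess_for_style_training; the function name is hypothetical).
def fine_tune_on_raw_dataset(raw_dataset, style="formal", output_dir="./summarization_model"):
    """Tokenize a raw DatasetDict and fine-tune the summarization model on it."""
    model, tokenizer = load_model()
    tokenized = raw_dataset.map(
        lambda batch: preprocess_for_style_training(batch, tokenizer, style=style),
        batched=True,
        remove_columns=raw_dataset["train"].column_names,
    )
    return fine_tune_model(model, tokenizer, tokenized, output_dir=output_dir)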


def generate_stylized_summary(text, model, tokenizer, style="formal", max_length=150):
    """Generate a summary in the specified style."""
    # Prepend the style token the model was fine-tuned with, e.g. "[FORMAL]".
    styled_input = f"[{style.upper()}] {text}"
    inputs = tokenizer(
        styled_input, return_tensors="pt", max_length=1024, truncation=True
    )
    # Move inputs to the model's device so generation works on CPU or GPU.
    inputs = inputs.to(model.device)
    generation_config = GenerationConfig(
        max_length=max_length,
        min_length=56,
        early_stopping=True,
        num_beams=4,
        length_penalty=2.0,
        no_repeat_ngram_size=3,
        forced_bos_token_id=0,  # BART uses token id 0 (<s>) as the forced BOS token
    )
    # Pass the attention mask along with the input ids so any padding is ignored.
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        generation_config=generation_config,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
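

if __name__ == "__main__":
    # Quick smoke test (illustrative): summarize a short sample passage with the
    # off-the-shelf checkpoint. Without fine-tuning, the [FORMAL] prefix is just
    # extra input text, so it has little effect on style.
    demo_model, demo_tokenizer = load_model()
    sample_text = (
        "The city council met on Tuesday to discuss the proposed transit plan. "
        "Members debated funding sources, construction timelines, and the impact "
        "on existing bus routes before voting to commission a feasibility study."
    )
    print(generate_stylized_summary(sample_text, demo_model, demo_tokenizer, style="formal"))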