import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Trainer, TrainingArguments, GenerationConfig, DataCollatorForSeq2Seq


def load_model(model_name="facebook/bart-large-cnn"):
    """
    Load a pre-trained summarization model and its tokenizer.
    Options: facebook/bart-large-cnn, google/pegasus-xsum, sshleifer/distilbart-cnn-12-6
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer
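
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): one way to prepare the
# dataset that fine_tune_model() expects. It assumes a datasets.DatasetDict
# with "document", "summary", and "style" columns; those column names and the
# prepare_dataset() helper itself are hypothetical.
# ---------------------------------------------------------------------------
def prepare_dataset(dataset, tokenizer, max_input_length=1024, max_target_length=150):
    """Prepend the style tag to each document and tokenize inputs and targets."""

    def preprocess(batch):
        # Match the "[STYLE] text" format used by generate_stylized_summary().
        inputs = [f"[{s.upper()}] {doc}" for s, doc in zip(batch["style"], batch["document"])]
        model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
        labels = tokenizer(text_target=batch["summary"], max_length=max_target_length, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    return dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
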

def fine_tune_model(model, tokenizer, dataset, output_dir="./summarization_model"):
    """Fine-tune the model on a prepared (tokenized) dataset."""
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=6,
        save_strategy="epoch",
        eval_strategy="epoch",
        load_best_model_at_end=True,
        report_to="none",
    )
    # Dynamically pad inputs and labels per batch, as expected for seq2seq training.
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    tokenizer.save_pretrained(output_dir)
    model.save_pretrained(output_dir)
    return model
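
# Illustrative usage (a sketch): raw_dataset is a hypothetical DatasetDict with
# train/validation splits, prepared with the prepare_dataset() sketch above.
#
#     model, tokenizer = load_model()
#     tokenized = prepare_dataset(raw_dataset, tokenizer)
#     model = fine_tune_model(model, tokenizer, tokenized)
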

def generate_stylized_summary(text, model, tokenizer, style="formal", max_length=150):
    """Generate a summary in the specified style."""
    # Prepend the style token to the input, matching the tags used at training time.
    styled_input = f"[{style.upper()}] {text}"
    inputs = tokenizer(
        styled_input, return_tensors="pt", max_length=1024, truncation=True
    )
    generation_config = GenerationConfig(
        max_length=max_length,
        min_length=56,
        early_stopping=True,
        num_beams=4,
        length_penalty=2.0,
        no_repeat_ngram_size=3,
        forced_bos_token_id=0,  # force BOS (id 0 for BART) as the first decoder token
    )
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        generation_config=generation_config,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
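
# Illustrative inference example (a sketch, not part of the original code):
# reload the fine-tuned checkpoint and generate summaries in different styles.
if __name__ == "__main__":
    model, tokenizer = load_model("./summarization_model")
    article = "..."  # placeholder: replace with the text to summarize
    for style in ("formal", "casual"):
        print(f"--- {style} ---")
        print(generate_stylized_summary(article, model, tokenizer, style=style))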