# -*- coding: utf-8 -*-
"""Fine-tune a T5 checkpoint on PubMed papers and serve it with Gradio.

Originally exported from a Colaboratory notebook (Untitled0.ipynb):
https://colab.research.google.com/drive/1aMkctyYgdHD61sv7-bJHFN1B5taCv6c2

The script:
1. loads the ``scientific_papers/pubmed`` dataset and keeps a random,
   reproducible subset of each split;
2. tokenizes article -> abstract pairs for a seq2seq model;
3. fine-tunes the checkpoint for one epoch, reporting ROUGE on validation;
4. launches a Gradio app that summarizes pasted text with the tuned model.
"""
import inspect

import evaluate
import gradio as gr
import nltk
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Sentence-splitter data used by nltk.sent_tokenize in compute_metrics.
# Newer NLTK releases look for "punkt_tab" rather than the legacy "punkt"
# resource, so fetch both (a failed download only prints an error).
nltk.download("punkt")
nltk.download("punkt_tab")

raw_dataset = load_dataset("scientific_papers", "pubmed")
metric = evaluate.load("rouge")

model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# The listed T5 checkpoints are prompted with a task prefix; other models
# get the raw article text.
if model_checkpoint in {"t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"}:
    prefix = "summarize: "
else:
    prefix = ""

max_input_length = 512   # article tokens kept (longer inputs are truncated)
max_target_length = 128  # abstract / summary tokens kept or generated

# Rows drawn from each split, and the seed that makes the draw reproducible.
sample_size = 1_000
sample_seed = 42


def preprocess_function(examples):
    """Tokenize a batch of articles (model inputs) and abstracts (labels).

    ``examples`` is a batched ``datasets`` slice with "article" and
    "abstract" text columns. Returns the tokenizer output for the prefixed
    articles, plus a "labels" entry holding the abstract token ids.
    """
    inputs = [prefix + doc for doc in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # text_target= tokenizes with target-side settings; it replaces the
    # older `with tokenizer.as_target_tokenizer():` context manager.
    labels = tokenizer(
        text_target=examples["abstract"],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Keep a random subset of each split so training stays fast. Sampling is
# seeded, covers every row index (randint's upper bound is exclusive, so the
# old `len - 1` skipped the last row) and is without replacement, so no row
# is duplicated.
rng = np.random.default_rng(sample_seed)
for split in ["train", "validation", "test"]:
    n_rows = len(raw_dataset[split])
    indices = rng.choice(n_rows, size=min(sample_size, n_rows), replace=False)
    raw_dataset[split] = raw_dataset[split].select(indices.tolist())

# Drop the raw text columns: the trainer only consumes the tokenizer outputs
# (input_ids, attention_mask) and labels.
tokenized_dataset = raw_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_dataset["train"].column_names,
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)


def _accepts_kwarg(func, name):
    """Return True if ``func``'s call signature has a parameter named ``name``."""
    return name in inspect.signature(func).parameters


# Newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` and the trainer's `tokenizer` argument to
# `processing_class` (eventually removing the old names). Use whichever
# spelling the installed version accepts.
eval_strategy_kw = (
    "eval_strategy"
    if _accepts_kwarg(Seq2SeqTrainingArguments, "eval_strategy")
    else "evaluation_strategy"
)
tokenizer_kw = (
    "processing_class"
    if _accepts_kwarg(Seq2SeqTrainer.__init__, "processing_class")
    else "tokenizer"
)

batch_size = 8
args = Seq2SeqTrainingArguments(
    f"{model_checkpoint}-scientific_papers",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    # Evaluation generation should be able to reach the reference length;
    # otherwise it uses the model's generation-config default max_length.
    generation_max_length=max_target_length,
    # fp16=True,  # optional: mixed precision on supporting GPUs
    push_to_hub=False,
    gradient_accumulation_steps=2,
    **{eval_strategy_kw: "epoch"},
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


def compute_metrics(eval_pred):
    """Compute ROUGE scores (x100) and mean generated length for an eval pass.

    ``eval_pred`` is the (predictions, label_ids) pair supplied by the
    trainer. With ``predict_with_generate=True``, the predictions are
    generated token ids. Returns a dict of metric name -> value rounded to
    4 decimals.
    """
    predictions, labels = eval_pred
    # Some models hand back a tuple whose first element holds the token ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is the trainer's padding value (for labels, and for predictions
    # when eval batches of different lengths are concatenated). It is not a
    # valid token id, so map it to the pad id before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence.
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    result = {key: value * 100 for key, value in result.items()}

    # Mean number of non-pad tokens per generated sequence.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{tokenizer_kw: tokenizer},
)

trainer.train()


def summarizer(input_text):
    """Summarize ``input_text`` with the fine-tuned model (Gradio callback).

    Blank input returns an empty string instead of generating from the bare
    task prefix.
    """
    if not input_text or not input_text.strip():
        return ""
    # Training may have moved the model to an accelerator, so put the inputs
    # on the model's device. eval() turns off dropout for generation.
    model.eval()
    model_inputs = tokenizer(
        [prefix + input_text],
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)
    summary_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        num_beams=4,
        length_penalty=2.0,
        # +2 from original because we start at step=1 and stop before max_length
        max_length=max_target_length + 2,
        repetition_penalty=2.0,
        early_stopping=True,
        use_cache=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


# gr.inputs / gr.outputs no longer exist in current Gradio; gr.Textbox serves
# both roles. theme="gray" is not one of Gradio's built-in theme names, so it
# is omitted and the default theme is used.
iface = gr.Interface(
    fn=summarizer,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Textbox(label="Summary"),
    title="Scientific Paper Summarizer",
    description="Summarizes scientific papers using a fine-tuned T5 model",
)
iface.launch()