jjuarez committed
Commit
e86bff1
1 Parent(s): 4fce2e0

Create app.py

Files changed (1)
app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
+ import gradio as gr
+ from datasets import load_dataset
+ import evaluate
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
+ import numpy as np
+ import nltk
+
+ nltk.download("punkt")
+ raw_dataset = load_dataset("scientific_papers", "pubmed")
+ metric = evaluate.load("rouge")
+ model_checkpoint = "t5-base"
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+
+ # T5 checkpoints expect a task prefix on the input text
+ if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
+     prefix = "summarize: "
+ else:
+     prefix = ""
+
+ # Preprocessing function: truncate articles and abstracts to fixed lengths
+ max_input_length = 256
+ max_target_length = 64
+ def preprocess_function(examples):
+     inputs = [prefix + doc for doc in examples["article"]]
+     model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
+
+     # Tokenize the abstracts as targets
+     labels = tokenizer(text_target=examples["abstract"], max_length=max_target_length, truncation=True)
+
+     model_inputs["labels"] = labels["input_ids"]
+     return model_inputs
+
+ for split in ["train", "validation", "test"]:
+     # Sample 200 distinct random examples per split to keep training short
+     raw_dataset[split] = raw_dataset[split].select(np.random.choice(len(raw_dataset[split]), 200, replace=False))
+
+ tokenized_dataset = raw_dataset.map(preprocess_function, batched=True)
+
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
+
+ batch_size = 4
+
+ args = Seq2SeqTrainingArguments(
+     f"{model_checkpoint}-scientific_papers",
+     evaluation_strategy="epoch",
+     learning_rate=3e-5,
+     per_device_train_batch_size=batch_size,
+     per_device_eval_batch_size=batch_size,
+     weight_decay=0.01,
+     save_total_limit=3,
+     num_train_epochs=0.5,
+     predict_with_generate=True,
+     # fp16=True,
+     push_to_hub=False,
+     gradient_accumulation_steps=2,
+ )
+
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+
+ # Computing ROUGE metrics from the predictions
+ def compute_metrics(eval_pred):
+     predictions, labels = eval_pred
+     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+     # Replace -100 in the labels as we can't decode them
+     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+     # ROUGE expects a newline after each sentence
+     decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
+     decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
+     result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
+     # Scale scores to percentages
+     result = {key: value * 100 for key, value in result.items()}
+     # Add mean generated length
+     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
+     result["gen_len"] = np.mean(prediction_lens)
+     return {k: round(v, 4) for k, v in result.items()}
+
+
+ # Define the training and evaluation datasets
+ train_dataset = tokenized_dataset["train"]
+ eval_dataset = tokenized_dataset["validation"]
+
+ # Create the trainer object
+ trainer = Seq2SeqTrainer(
+     model=model,
+     args=args,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     data_collator=data_collator,
+     compute_metrics=compute_metrics,
+ )
+
+ # Train the model
+ trainer.train()
+
+ # Define the input and output interface of the app
+ def summarizer(input_text):
+     inputs = [prefix + input_text]
+     model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt")
+     summary_ids = model.generate(
+         input_ids=model_inputs["input_ids"],
+         attention_mask=model_inputs["attention_mask"],
+         num_beams=4,
+         length_penalty=2.0,
+         max_length=max_target_length + 2,  # +2 from original because we start at step=1 and stop before max_length
+         repetition_penalty=2.0,
+         early_stopping=True,
+         use_cache=True,
+     )
+     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+     return summary
+
+ # Interface creation and launching (gr.inputs/gr.outputs were removed from current Gradio)
+ iface = gr.Interface(
+     fn=summarizer,
+     inputs=gr.Textbox(label="Input Text"),
+     outputs=gr.Textbox(label="Summary"),
+     title="Scientific Paper Summarizer",
+     description="Summarizes scientific papers using a fine-tuned T5 model",
+ )
+ iface.launch()
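
Once the app is up, the same endpoint can also be exercised programmatically with gradio_client. A minimal sketch, assuming the app is reachable at Gradio's default local address; the URL and the input text below are placeholders:

from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # assumed default local address; use the URL printed by launch()
summary = client.predict(
    "Full text of a scientific article ...",  # placeholder input
    api_name="/predict",  # default endpoint name for a single gr.Interface
)
print(summary)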