liamvbetts commited on
Commit
e673826
1 Parent(s): d2894f7

random button

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -1,24 +1,37 @@
1
  import gradio as gr
2
-
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
4
 
5
  tokenizer = AutoTokenizer.from_pretrained("liamvbetts/bart-large-cnn-v4")
6
  model = AutoModelForSeq2SeqLM.from_pretrained("liamvbetts/bart-large-cnn-v4")
7
 
 
 
8
  def summarize(article):
9
  inputs = tokenizer(article, return_tensors="pt").input_ids
10
  outputs = model.generate(inputs, max_new_tokens=128, do_sample=False)
11
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
12
  return summary
13
 
 
 
 
 
 
 
14
  # Create Gradio interface
15
  input_text = gr.Textbox(lines=10, label="Input Text")
16
  output_text = gr.Textbox(label="Summary")
 
 
 
 
17
 
18
  gr.Interface(
19
  fn=summarize,
20
- inputs=input_text,
21
  outputs=output_text,
22
  title="News Summary App",
23
- description="Enter a news text and get its summary."
24
  ).launch()
 
1
  import gradio as gr
2
+ import random
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ from datasets import load_dataset
5
 
6
  tokenizer = AutoTokenizer.from_pretrained("liamvbetts/bart-large-cnn-v4")
7
  model = AutoModelForSeq2SeqLM.from_pretrained("liamvbetts/bart-large-cnn-v4")
8
 
9
+ dataset = load_dataset("cnn_dailymail", "3.0.0")
10
+
11
  def summarize(article):
12
  inputs = tokenizer(article, return_tensors="pt").input_ids
13
  outputs = model.generate(inputs, max_new_tokens=128, do_sample=False)
14
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
  return summary
16
 
17
+ def get_random_article():
18
+ random.seed()
19
+ val_example = dataset["validation"].shuffle().select(range(1))
20
+ val_article = val_example['article'][0][:512]
21
+ return val_article
22
+
23
  # Create Gradio interface
24
  input_text = gr.Textbox(lines=10, label="Input Text")
25
  output_text = gr.Textbox(label="Summary")
26
+ random_article_button = gr.Button("Load Random Article")
27
+
28
+ def update_input_text():
29
+ return get_random_article()
30
 
31
  gr.Interface(
32
  fn=summarize,
33
+ inputs=[input_text, gr.components.Button("Load Random Article").click(update_input_text, [], input_text)],
34
  outputs=output_text,
35
  title="News Summary App",
36
+ description="Enter a news text and get its summary, or load a random article from the validation set."
37
  ).launch()