ttj committed
Commit 706179c
1 Parent(s): 319c415

Add T5, do_sample by default

Files changed (1):
  1. app.py +19 -12
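
The do_sample flag this commit turns on by default is passed straight through from the Gradio callback to the Hugging Face summarization pipeline as a generation argument. As a minimal sketch of the equivalent call, using the T5 checkpoint and the UI default values visible in the diff below (the article text is a placeholder, not from this commit):

```python
from transformers import pipeline

# Same checkpoint and generation arguments as in the updated app.py;
# do_sample=True enables sampling, with temperature and top_p controlling it.
summarizer = pipeline(
    "summarization", model="flax-community/t5-base-cnn-dm", framework="pt"
)
out = summarizer(
    "Some news article text ...",  # placeholder input
    do_sample=True,
    min_length=64,
    max_length=64,
    temperature=1.0,
    top_p=1.0,
)
print(out[0]["summary_text"])
```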
app.py CHANGED
@@ -18,14 +18,16 @@ ft5_summarizer = pipeline(
     "summarization", model=ft5_model, tokenizer=ft5_tokenizer, framework="pt"
 )
 
-# model_name = 'flax-community/t5-base-cnn-dm'
-# t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
-# t5_tokenizer = T5TokenizerFast.from_pretrained(model_name)
-# predict_t5 = get_predict(t5_model, t5_tokenizer)
+model_name = "flax-community/t5-base-cnn-dm"
+t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
+t5_tokenizer = T5TokenizerFast.from_pretrained(model_name)
+t5_summarizer = pipeline(
+    "summarization", model=t5_model, tokenizer=t5_tokenizer, framework="pt"
+)
 
 
-def fn(text, do_sample, min_length, max_length, temperature, top_p):
-    out = ft5_summarizer(
+def _fn(text, do_sample, min_length, max_length, temperature, top_p, summarizer):
+    out = summarizer(
         text,
         do_sample=do_sample,
         min_length=min_length,
@@ -37,23 +39,28 @@ def fn(text, do_sample, min_length, max_length, temperature, top_p):
     return out[0]["summary_text"]
 
 
+def fn(*args):
+    return [_fn(*args, summarizer=s) for s in (t5_summarizer, ft5_summarizer)]
+
+
 import gradio as gr
 
 interface = gr.Interface(
     fn,
     inputs=[
-        gr.inputs.Textbox(lines=10, label="text"),
-        gr.inputs.Checkbox(label="do_sample"),
+        gr.inputs.Textbox(lines=10, label="article"),
+        gr.inputs.Checkbox(label="do_sample", default=True),
         gr.inputs.Slider(1, 128, step=1, default=64, label="min_length"),
         gr.inputs.Slider(1, 128, step=1, default=64, label="max_length"),
         gr.inputs.Slider(0.0, 1.0, step=0.1, default=1, label="temperature"),
         gr.inputs.Slider(0.0, 1.0, step=0.1, default=1, label="top_p"),
     ],
-    outputs=gr.outputs.Textbox(),
-    # server_port=8080,
-    # server_name='0.0.0.0',
+    outputs=[
+        gr.outputs.Textbox(label="summary by T5"),
+        gr.outputs.Textbox(label="summary by F-T5"),
+    ],
     examples=[[ex] for ex in examples],
-    title="F-T5 News Summarization",
+    title="F-T5 News Summarizer",
     description="""
     F-T5 is a hybrid encoder-decoder model based on T5 and FNet.
     The model architecture is based on T5, except the encoder self attention is replaced by fourier transform as in FNet.
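
The description above says the F-T5 encoder replaces self-attention with a Fourier transform as in FNet. As an illustration only (not code from this Space; the function name and tensor shapes are assumptions), the FNet-style token mixing amounts to a parameter-free 2D FFT over the hidden and sequence dimensions, keeping the real part:

```python
import torch


def fourier_mixing(x: torch.Tensor) -> torch.Tensor:
    """FNet-style mixing sketch: x has shape (batch, seq_len, hidden)."""
    # FFT over the hidden dimension, then over the sequence dimension;
    # only the real part is kept, so no attention weights are learned.
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real


# Example: mix a small random activation tensor across tokens.
mixed = fourier_mixing(torch.randn(1, 4, 8))
```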