ttj commited on
Commit
319c415
1 Parent(s): f2cbb60

Format code with black

Browse files
Files changed (1) hide show
  1. app.py +38 -27
app.py CHANGED
@@ -1,12 +1,15 @@
1
  from transformers import T5ForConditionalGeneration, T5TokenizerFast, pipeline
2
- from transformers.models.f_t5.modeling_t5 import \
3
- T5ForConditionalGeneration as FT5ForConditionalGeneration
4
- from transformers.models.f_t5.tokenization_t5_fast import \
5
- T5TokenizerFast as FT5TokenizerFast
 
 
6
 
7
  import json
8
- with open('examples.json') as f:
9
- examples = json.load(f)['article']
 
10
 
11
  model_name = "flax-community/ft5-cnn-dm"
12
  ft5_model = FT5ForConditionalGeneration.from_pretrained(model_name)
@@ -15,39 +18,47 @@ ft5_summarizer = pipeline(
15
  "summarization", model=ft5_model, tokenizer=ft5_tokenizer, framework="pt"
16
  )
17
 
18
- #model_name = 'flax-community/t5-base-cnn-dm'
19
- #t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
20
- #t5_tokenizer = T5TokenizerFast.from_pretrained(model_name)
21
- #predict_t5 = get_predict(t5_model, t5_tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- def fn(text, do_sample, min_length, max_length,temperature, top_p):
24
- out = ft5_summarizer(text, do_sample=do_sample, min_length=min_length,
25
- max_length=max_length, temperature=temperature, top_p=top_p,
26
- truncation=True)
27
- return out[0]['summary_text']
28
  import gradio as gr
29
 
30
  interface = gr.Interface(
31
  fn,
32
  inputs=[
33
- gr.inputs.Textbox(lines=10, label='text'),
34
- gr.inputs.Checkbox(label='do_sample'),
35
- gr.inputs.Slider(1, 128, step=1, default=64, label='min_length'),
36
- gr.inputs.Slider(1, 128, step=1, default=64, label='max_length'),
37
- gr.inputs.Slider(0.0, 1.0, step=0.1, default=1, label='temperature'),
38
- gr.inputs.Slider(0.0, 1.0, step=0.1, default=1, label='top_p'),
39
- ],
40
  outputs=gr.outputs.Textbox(),
41
- #server_port=8080,
42
- #server_name='0.0.0.0',
43
  examples=[[ex] for ex in examples],
44
- title='F-T5 News Summarization',
45
  description="""
46
  F-T5 is a hybrid encoder-decoder model based on T5 and FNet.
47
  The model architecture is based on T5, except the encoder self attention is replaced by fourier transform as in FNet.
48
  The model is pre-trained on openwebtext, fine-tuned on CNN/DM.
49
- """
50
-
51
  )
52
 
53
  interface.launch()
 
1
  from transformers import T5ForConditionalGeneration, T5TokenizerFast, pipeline
2
+ from transformers.models.f_t5.modeling_t5 import (
3
+ T5ForConditionalGeneration as FT5ForConditionalGeneration,
4
+ )
5
+ from transformers.models.f_t5.tokenization_t5_fast import (
6
+ T5TokenizerFast as FT5TokenizerFast,
7
+ )
8
 
9
  import json
10
+
11
+ with open("examples.json") as f:
12
+ examples = json.load(f)["article"]
13
 
14
  model_name = "flax-community/ft5-cnn-dm"
15
  ft5_model = FT5ForConditionalGeneration.from_pretrained(model_name)
 
18
  "summarization", model=ft5_model, tokenizer=ft5_tokenizer, framework="pt"
19
  )
20
 
21
+ # model_name = 'flax-community/t5-base-cnn-dm'
22
+ # t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
23
+ # t5_tokenizer = T5TokenizerFast.from_pretrained(model_name)
24
+ # predict_t5 = get_predict(t5_model, t5_tokenizer)
25
+
26
+
27
+ def fn(text, do_sample, min_length, max_length, temperature, top_p):
28
+ out = ft5_summarizer(
29
+ text,
30
+ do_sample=do_sample,
31
+ min_length=min_length,
32
+ max_length=max_length,
33
+ temperature=temperature,
34
+ top_p=top_p,
35
+ truncation=True,
36
+ )
37
+ return out[0]["summary_text"]
38
+
39
 
 
 
 
 
 
40
  import gradio as gr
41
 
42
  interface = gr.Interface(
43
  fn,
44
  inputs=[
45
+ gr.inputs.Textbox(lines=10, label="text"),
46
+ gr.inputs.Checkbox(label="do_sample"),
47
+ gr.inputs.Slider(1, 128, step=1, default=64, label="min_length"),
48
+ gr.inputs.Slider(1, 128, step=1, default=64, label="max_length"),
49
+ gr.inputs.Slider(0.0, 1.0, step=0.1, default=1, label="temperature"),
50
+ gr.inputs.Slider(0.0, 1.0, step=0.1, default=1, label="top_p"),
51
+ ],
52
  outputs=gr.outputs.Textbox(),
53
+ # server_port=8080,
54
+ # server_name='0.0.0.0',
55
  examples=[[ex] for ex in examples],
56
+ title="F-T5 News Summarization",
57
  description="""
58
  F-T5 is a hybrid encoder-decoder model based on T5 and FNet.
59
  The model architecture is based on T5, except the encoder self attention is replaced by fourier transform as in FNet.
60
  The model is pre-trained on openwebtext, fine-tuned on CNN/DM.
61
+ """,
 
62
  )
63
 
64
  interface.launch()