carrie committed on
Commit
c8b4824
1 Parent(s): 36895c4

update model processing code

Browse files
Files changed (2) hide show
  1. app.py +52 -4
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,14 +1,62 @@
1
  import os
2
  import gradio as gr
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
-
 
 
5
 
6
  model = AutoModelForSeq2SeqLM.from_pretrained("fangyuan/lfqa_role_classification")
7
  tokenizer = AutoTokenizer.from_pretrained("fangyuan/lfqa_role_classification")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
9
 
10
- def predict(input):
11
- pass
12
 
13
 
14
  gr.Interface(
@@ -18,7 +66,7 @@ gr.Interface(
18
  gr.inputs.Textbox(lines=1, label="Answer:"),
19
  ],
20
  outputs=[
21
- gr.outputs.Textbox(label="Predicted functional role for each sentence"),
22
  ],
23
  theme="peach",
24
  title="Discourse structure of long-form answer",
 
1
  import os
2
  import gradio as gr
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ import stanza
5
+ import re
6
+ stanza.download('en', processors='tokenize')
7
 
8
  model = AutoModelForSeq2SeqLM.from_pretrained("fangyuan/lfqa_role_classification")
9
  tokenizer = AutoTokenizer.from_pretrained("fangyuan/lfqa_role_classification")
10
+ en_nlp = stanza.Pipeline('en', processors='tokenize')
11
+
12
def get_ans_sentence_with_stanza(answer_paragraph, pipeline,
                                 is_offset=False):
    """Segment a paragraph into sentences with a stanza pipeline.

    Args:
        answer_paragraph: raw answer text to segment.
        pipeline: a callable stanza Pipeline (with a 'tokenize' processor);
            the returned document must expose `.sentences`, each with
            `.tokens` carrying `start_char` / `end_char`.
        is_offset: if True, return (start_char, end_char) offset pairs
            instead of the sentence strings.

    Returns:
        A list of stripped sentence strings, or a list of
        (start_char, end_char) tuples when is_offset is True.
    """
    answer_paragraph_processed = pipeline(answer_paragraph)
    sentences = []
    for sent in answer_paragraph_processed.sentences:
        start = sent.tokens[0].start_char
        end = sent.tokens[-1].end_char
        if is_offset:
            sentences.append((start, end))
        else:
            # stanza's end_char is exclusive, so slice [start:end].
            # The original [start:end + 1] pulled in one extra character
            # (usually trailing whitespace, silently masked by strip()).
            sentences.append(answer_paragraph[start:end].strip())
    return sentences
24
+
25
+
26
def create_input_to_t5(question, answer):
    """Build the T5 input line: the question followed by indexed sentences.

    The answer is sentence-segmented with the module-level `en_nlp`
    pipeline; each sentence is preceded by a '[i]' marker so the model's
    output can refer back to it by index.
    """
    segments = [question]
    for i, sentence in enumerate(get_ans_sentence_with_stanza(answer, en_nlp)):
        segments.append('[{}]'.format(i))
        segments.append(sentence)
    return ' '.join(segments)
34
+
35
def process_t5_output(input_txt, output_txt):
    """Align the model's predicted roles with the input's answer sentences.

    Args:
        input_txt: T5 input of the form 'question [0] sent0 [1] sent1 ...'.
        output_txt: model output of the form '[0] role0 [1] role1 ...'.

    Returns:
        One predicted role per input sentence marker, newline-joined, in
        input order. Sentences the model did not label get a ' ' line.
    """
    # Raw strings for the patterns: '\[' is an invalid escape sequence in a
    # plain string literal and raises a SyntaxWarning on modern Python.
    answer_sentence = re.split(r'\[\d+\] ', input_txt)[1:]
    sentence_idx = re.findall(r'\[\d+\]', input_txt)
    pred_role = re.split(r'\[\d+\] ', output_txt)[1:]
    pred_idx = re.findall(r'\[\d+\]', output_txt)
    idx_to_role = {
        idx: role.strip() for idx, role in zip(pred_idx, pred_role)
    }
    # zip with the split sentences (not just the markers) to preserve the
    # original truncation behavior when the two lists disagree in length.
    pred_roles = [idx_to_role.get(idx, ' ')
                  for idx, _ in zip(sentence_idx, answer_sentence)]
    return '\n'.join(pred_roles)
50
+
51
+
52
 
53
def predict(question, answer):
    """Classify the functional role of each sentence in the answer.

    Builds the T5 input from the question/answer pair, generates with the
    module-level model and tokenizer, and maps the decoded output back to
    one role per answer sentence (newline-joined).
    """
    model_input = create_input_to_t5(question, answer)
    encoded = tokenizer(model_input, return_tensors='pt')
    generated = model.generate(encoded.input_ids, max_length=512)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return process_t5_output(model_input, decoded[0])
59
 
 
 
60
 
61
 
62
  gr.Interface(
 
66
  gr.inputs.Textbox(lines=1, label="Answer:"),
67
  ],
68
  outputs=[
69
+ gr.outputs.Textbox(label="Predicted sentence-level functional roles"),
70
  ],
71
  theme="peach",
72
  title="Discourse structure of long-form answer",
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio
2
  transformers
3
- torch
 
 
1
  gradio
2
  transformers
3
+ torch
4
+ stanza