philipp-zettl committed
Commit 6e74145
1 Parent(s): 6b51ec6

Create app.py

Files changed (1)
  app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
+ import gradio as gr
+ import torch
+ import itertools
+ import ast  # used below for safe parsing instead of eval()
+ import pandas as pd
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+
+ model_name = 'philipp-zettl/t5-small-long-qa'
+ qa_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ model_name = 'philipp-zettl/t5-small-qg'
+ qg_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')
+
+ # Move both models to GPU if available
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ qa_model = qa_model.to(device)
+ qg_model = qg_model.to(device)
+
+ max_questions = 1
+ max_answers = 1
+
+
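+ # Tokenize each prompt and decode `num_return_sequences` beam-sampled
+ # outputs per prompt (do_sample=True combined with num_beams >= 2).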
+ def run_model(inputs, tokenizer, model, temperature=0.5, num_return_sequences=1):
+     all_outputs = []
+     for input_text in inputs:
+         model_inputs = tokenizer([input_text], max_length=512, padding=True, truncation=True, return_tensors='pt')
+         input_ids = model_inputs['input_ids'].to(device)
+         sample_outputs = []
+         with torch.no_grad():
+             outputs = model.generate(
+                 input_ids,
+                 max_length=85,
+                 temperature=temperature,
+                 do_sample=True,
+                 num_return_sequences=num_return_sequences,
+                 low_memory=True,
+                 num_beams=max(2, num_return_sequences),
+                 use_cache=True,
+             )
+         for output in outputs:
+             sample_outputs.append(tokenizer.decode(output, skip_special_tokens=True))
+         all_outputs.append(sample_outputs)
+     return all_outputs
+
+
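+ # Two-step pipeline: generate question(s) from the context, then generate
+ # answer(s) for each question conditioned on the same context.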
+ def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_qg=1, num_return_sequences_qa=1):
+     # gr.Number may deliver floats; generate() and list indexing need ints
+     num_return_sequences_qg = int(num_return_sequences_qg)
+     num_return_sequences_qa = int(num_return_sequences_qa)
+     inputs = [
+         f'context: {content}'
+     ]
+     question = run_model(inputs, tokenizer, qg_model, temperature_qg, num_return_sequences_qg)
+
+     inputs = list(itertools.chain.from_iterable([
+         [f'question: {q} {inputs[0]}' for q in q_set] for q_set in question
+     ]))
+     answer = run_model(inputs, tokenizer, qa_model, temperature_qa, num_return_sequences_qa)
+
+     questions = list(itertools.chain.from_iterable(question))
+     answers = list(itertools.chain.from_iterable(answer))
+
+     results = []
+     for idx, ans in enumerate(answers):
+         # Each question yields num_return_sequences_qa consecutive answers,
+         # so map the answer index back to its source question.
+         results.append({'question': questions[idx // num_return_sequences_qa], 'answer': ans})
+     return results
+
+
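+ # Helpers for a dynamic list of text outputs (show the first k, hide the
+ # rest); not wired into the UI below.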
+ def variable_outputs(k, max_elems=10):
+     k = int(k)
+     return [gr.Text(visible=True)] * k + [gr.Text(visible=False)] * (max(max_elems, 10) - k)
+
+
+ def set_outputs(content, max_elems=10):
+     # ast.literal_eval parses the serialized list without the risks of eval()
+     c = ast.literal_eval(content)
+     print('received content: ', c)
+     return [gr.Text(value=t, visible=True) for t in c] + [gr.Text(visible=False)] * (max(max_elems, 10) - len(c))
+
+
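+ # Serialize the Q&A pairs as a headerless TSV file for the download button.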
+ def create_file_download(qnas):
+     with open('qnas.tsv', 'w') as f:
+         for idx, qna in qnas.iterrows():
+             f.write(qna['Question'] + '\t' + qna['Answer'])
+             if idx < len(qnas) - 1:
+                 f.write('\n')
+     return 'qnas.tsv'
+
+
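+ # UI: content and sampling settings on top, generated Q&A pairs rendered below.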
+ with gr.Blocks() as demo:
+     with gr.Row(equal_height=True):
+         with gr.Group("Content"):
+             content = gr.Textbox(label='Content', lines=15, placeholder='Enter text here', max_lines=10_000)
+         with gr.Group("Settings"):
+             temperature_qg = gr.Slider(label='Temperature QG', value=0.5, minimum=0, maximum=1, step=0.01)
+             temperature_qa = gr.Slider(label='Temperature QA', value=0.75, minimum=0, maximum=1, step=0.01)
+             num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, 10))
+             num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_answers, 10))
+
+     with gr.Row():
+         gen_btn = gr.Button("Generate")
+
+     @gr.render(inputs=[content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa], triggers=[gen_btn.click])
+     def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa):
+         qnas = gen(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa)
+         df = gr.Dataframe(
+             value=[list(u.values()) for u in qnas],
+             headers=['Question', 'Answer'],
+             col_count=2,
+             wrap=True
+         )
+         pd_df = pd.DataFrame([list(u.values()) for u in qnas], columns=['Question', 'Answer'])
+
+         download = gr.DownloadButton(label='Download (without headers)', value=create_file_download(pd_df))
+
+
+ demo.launch()
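
For reference, a minimal sketch of exercising the pipeline without the UI (a hypothetical invocation; the example context string is made up, and the prompt formats follow gen() above):

    # Hypothetical direct call to gen() from app.py
    qnas = gen('The Eiffel Tower was completed in 1889 and stands in Paris.', num_return_sequences_qg=2)
    for qna in qnas:
        print(f"Q: {qna['question']}\nA: {qna['answer']}")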