emielclopterop commited on
Commit
e03a30f
1 Parent(s): 25f8ce5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +208 -86
app.py CHANGED
@@ -1,91 +1,213 @@
1
 
2
-
3
  import gradio as gr
4
  from transformers import pipeline
5
 
6
- #pipelines init
7
- qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
8
-
9
- classification_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
10
-
11
- translation_pipeline = pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr")
12
-
13
- topic_classification_pipeline = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english") # Fine-tuned model for topic classification
14
-
15
- summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
16
-
17
- #functions defining
18
-
19
- def answer_question(context, question):
20
- return qa_pipeline(question=question, context=context)["answer"]
21
-
22
- def classify_text(text, labels):
23
- labels = labels.split(",")
24
- results = classification_pipeline(text, candidate_labels=labels)
25
- return {label: float(f"{prob:.4f}") for label, prob in zip(results["labels"], results["scores"])}
26
-
27
- def translate_text(text):
28
- return translation_pipeline(text)[0]['translation_text'] if text else "No translation"
29
-
30
- def classify_topic(text):
31
- results = topic_classification_pipeline(text)
32
- return ", ".join([f"{result['label']}: {result['score']:.4f}" for result in results])
33
-
34
- def summarize_text(text):
35
- result = summarization_pipeline(text, max_length=60)
36
- return result[0]['summary_text'] if result else "No summary available"
37
-
38
- def multi_model_interaction(text):
39
-
40
- summary = summarize_text(text)
41
- translated_summary = translate_text(summary)
42
-
43
- return {
44
- "Summary (English)": summary,
45
- "Summary (French)": translated_summary,
46
- }
47
-
48
-
49
- #Blocking interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with gr.Blocks() as demo:
51
- with gr.Tab("Single Models"):
52
- with gr.Column():
53
- gr.Markdown("### Question")
54
- context = gr.Textbox(label="Context")
55
- question = gr.Textbox(label="Question")
56
- answer_output = gr.Text(label="Answer")
57
- gr.Button("Answer").click(answer_question, inputs=[context, question], outputs=answer_output)
58
-
59
- with gr.Column():
60
- gr.Markdown("### Zero-Shot")
61
- text_zsc = gr.Textbox(label="Text")
62
- labels = gr.Textbox(label="Labels (comma separated)")
63
- classification_result = gr.JSON(label="Classification Results")
64
- gr.Button("Classify").click(classify_text, inputs=[text_zsc, labels], outputs=classification_result)
65
-
66
- with gr.Column():
67
- gr.Markdown("### Translation")
68
- text_to_translate = gr.Textbox(label="Text")
69
- translated_text = gr.Text(label="Translated Text")
70
- gr.Button("Translate").click(translate_text, inputs=[text_to_translate], outputs=translated_text)
71
-
72
- with gr.Column():
73
- gr.Markdown("### Summarization")
74
- text_to_summarize = gr.Textbox(label="Text")
75
- summary = gr.Text(label="Summary")
76
- gr.Button("Summarize").click(summarize_text, inputs=[text_to_summarize], outputs=summary)
77
-
78
- with gr.Column():
79
- gr.Markdown("### Sentiment Analysis")
80
- text_for_sentiment = gr.Textbox(label="Text for Sentiment Analysis")
81
- sentiment_result = gr.Text(label="Sentiment")
82
- gr.Button("Classify Sentiment").click(classify_topic, inputs=[text_for_sentiment], outputs=sentiment_result)
83
-
84
- with gr.Tab("Multi-Model"):
85
- gr.Markdown("### Multi-Model")
86
- input_text = gr.Textbox(label="Enter Text for Multi-Model Analysis")
87
- multi_output = gr.Text(label="Results")
88
- gr.Button("Process").click(multi_model_interaction, inputs=[input_text], outputs=multi_output)
89
-
90
- #Launching demo
91
- demo.launch(share=True, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
 
2
  import gradio as gr
3
  from transformers import pipeline
4
 
5
+ # Define the necessary pipelines
6
+ def load_qa_model():
7
+ return pipeline("question-answering", model="bert-large-uncased-whole-word-masking-finetuned-squad")
8
+
9
+ def load_classifier_model():
10
+ return pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33")
11
+
12
+ def load_translator_model(target_language):
13
+ try:
14
+ model_name = f"Helsinki-NLP/opus-mt-en-{target_language}"
15
+ return pipeline("translation", model=model_name)
16
+ except Exception as e:
17
+ print(f"Error loading translation model: {e}")
18
+ return None
19
+
20
+ def load_generator_model():
21
+ try:
22
+ return pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B", tokenizer="EleutherAI/gpt-neo-2.7B")
23
+ except Exception as e:
24
+ print(f"Error loading text generation model: {e}")
25
+ return None
26
+
27
+ def load_summarizer_model():
28
+ try:
29
+ return pipeline("summarization", model="facebook/bart-large-cnn")
30
+ except Exception as e:
31
+ print(f"Error loading summarization model: {e}")
32
+ return None
33
+
34
+ # Define the functions for processing
35
+ def process_qa(context, question):
36
+ qa_model = load_qa_model()
37
+ try:
38
+ return qa_model(context=context, question=question)["answer"]
39
+ except Exception as e:
40
+ print(f"Error during question answering: {e}")
41
+ return "Error during question answering"
42
+
43
+ def process_classifier(text, labels):
44
+ classifier_model = load_classifier_model()
45
+ try:
46
+ return classifier_model(text, labels)["labels"][0]
47
+ except Exception as e:
48
+ print(f"Error during classification: {e}")
49
+ return "Error during classification"
50
+
51
+ def process_translation(text, target_language):
52
+ translator_model = load_translator_model(target_language)
53
+ if translator_model:
54
+ try:
55
+ return translator_model(text)[0]["translation_text"]
56
+ except Exception as e:
57
+ print(f"Error during translation: {e}")
58
+ return "Translation error"
59
+ return "Model loading error"
60
+
61
+ def process_generation(prompt):
62
+ generator_model = load_generator_model()
63
+ if generator_model:
64
+ if prompt.strip():
65
+ try:
66
+ return generator_model(prompt, max_length=50)[0]["generated_text"]
67
+ except Exception as e:
68
+ print(f"Error during text generation: {e}")
69
+ return "Text generation error"
70
+ else:
71
+ return "Prompt is empty"
72
+ return "Model loading error"
73
+
74
+ def process_summarization(text):
75
+ summarizer_model = load_summarizer_model()
76
+ if summarizer_model:
77
+ if text.strip():
78
+ try:
79
+ return summarizer_model(text, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
80
+ except Exception as e:
81
+ print(f"Error during summarization: {e}")
82
+ return "Summarization error"
83
+ else:
84
+ return "Text is empty"
85
+ return "Model loading error"
86
+
87
+ # Gradio Interface
88
  with gr.Blocks() as demo:
89
+ gr.Markdown("Choose an NLP task and input the required text.")
90
+
91
+ with gr.Tab("Single-Models"):
92
+ gr.Markdown("This tab is for single models demonstration.")
93
+
94
+ task_select_single = gr.Dropdown(["Question Answering", "Zero-Shot Classification", "Translation", "Text Generation", "Summarization"], label="Select Task")
95
+ input_text_single = gr.Textbox(label="Input Text")
96
+
97
+ # Additional inputs for specific tasks
98
+ context_input_single = gr.Textbox(label="Context", visible=False)
99
+ label_input_single = gr.CheckboxGroup(["positive", "negative", "neutral"], label="Labels", visible=False)
100
+ target_language_input_single = gr.Dropdown(["nl", "fr", "es", "de"], label="Target Language", visible=False)
101
+
102
+ output_text_single = gr.Textbox(label="Output")
103
+ execute_button_single = gr.Button("Execute")
104
+
105
+ def update_inputs(task):
106
+ if task == "Question Answering":
107
+ return {
108
+ context_input_single: gr.update(visible=True),
109
+ label_input_single: gr.update(visible=False),
110
+ target_language_input_single: gr.update(visible=False)
111
+ }
112
+ elif task == "Zero-Shot Classification":
113
+ return {
114
+ context_input_single: gr.update(visible=False),
115
+ label_input_single: gr.update(visible=True),
116
+ target_language_input_single: gr.update(visible=False)
117
+ }
118
+ elif task == "Translation":
119
+ return {
120
+ context_input_single: gr.update(visible=False),
121
+ label_input_single: gr.update(visible=False),
122
+ target_language_input_single: gr.update(visible=True)
123
+ }
124
+ elif task == "Text Generation":
125
+ return {
126
+ context_input_single: gr.update(visible=False),
127
+ label_input_single: gr.update(visible=False),
128
+ target_language_input_single: gr.update(visible=False)
129
+ }
130
+ elif task == "Summarization":
131
+ return {
132
+ context_input_single: gr.update(visible=False),
133
+ label_input_single: gr.update(visible=False),
134
+ target_language_input_single: gr.update(visible=False)
135
+ }
136
+ else:
137
+ return {
138
+ context_input_single: gr.update(visible=False),
139
+ label_input_single: gr.update(visible=False),
140
+ target_language_input_single: gr.update(visible=False)
141
+ }
142
+
143
+ task_select_single.change(fn=update_inputs, inputs=task_select_single,
144
+ outputs=[context_input_single, label_input_single, target_language_input_single])
145
+
146
+ def execute_task_single(task, input_text, context, labels, target_language):
147
+ if task == "Question Answering":
148
+ return process_qa(context=context, question=input_text)
149
+ elif task == "Zero-Shot Classification":
150
+ if not labels:
151
+ return "Please provide labels for classification."
152
+ return process_classifier(text=input_text, labels=labels)
153
+ elif task == "Translation":
154
+ if not target_language:
155
+ return "Please select a target language for translation."
156
+ return process_translation(text=input_text, target_language=target_language)
157
+ elif task == "Text Generation":
158
+ return process_generation(prompt=input_text)
159
+ elif task == "Summarization":
160
+ return process_summarization(text=input_text)
161
+ else:
162
+ return "Invalid task selected."
163
+
164
+ execute_button_single.click(
165
+ execute_task_single,
166
+ inputs=[task_select_single, input_text_single, context_input_single, label_input_single, target_language_input_single],
167
+ outputs=output_text_single
168
+ )
169
+
170
+ with gr.Tab("Multi-Model Task"):
171
+ gr.Markdown("This tab allows you to execute all tasks sequentially.")
172
+
173
+ # Inputs for all tasks
174
+ input_text_multi = gr.Textbox(label="Input Text")
175
+ context_input_multi = gr.Textbox(label="Context (for QA)")
176
+ label_input_multi = gr.CheckboxGroup(["positive", "negative", "neutral"], label="Labels (for Classification)")
177
+ target_language_input_multi = gr.Dropdown(["nl", "fr", "es", "de"], label="Target Language (for Translation)")
178
+
179
+ # Outputs for all tasks
180
+ output_qa = gr.Textbox(label="QA Output")
181
+ output_classification = gr.Textbox(label="Classification Output")
182
+ output_translation = gr.Textbox(label="Translation Output")
183
+ output_generation = gr.Textbox(label="Text Generation Output")
184
+ output_summarization = gr.Textbox(label="Summarization Output")
185
+
186
+ execute_button_multi = gr.Button("Execute All Tasks")
187
+
188
+ def execute_all_tasks(input_text, context, labels, target_language):
189
+ # Process Question Answering
190
+ qa_output = process_qa(context=context, question=input_text)
191
+
192
+ # Process Classification
193
+ classification_output = process_classifier(text=input_text, labels=labels)
194
+
195
+ # Process Translation
196
+ translation_output = process_translation(text=input_text, target_language=target_language)
197
+
198
+ # Process Text Generation using QA output
199
+ generation_output = process_generation(prompt=qa_output)
200
+
201
+ # Process Summarization using Text Generation output
202
+ summarization_output = process_summarization(text=generation_output)
203
+
204
+ # Return all outputs
205
+ return qa_output, classification_output, translation_output, generation_output, summarization_output
206
+
207
+ execute_button_multi.click(
208
+ execute_all_tasks,
209
+ inputs=[input_text_multi, context_input_multi, label_input_multi, target_language_input_multi],
210
+ outputs=[output_qa, output_classification, output_translation, output_generation, output_summarization]
211
+ )
212
+
213
+ demo.launch()