hesha commited on
Commit
2b9d722
1 Parent(s): 187b1c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -1,17 +1,29 @@
1
- import gradio as gr
2
  from transformers import pipeline
 
 
 
 
 
3
 
4
- model = 'MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33'
5
- pipe = pipeline('zero-shot-classification', model=model)
 
 
 
6
 
7
- def infer(text, classes, multi_label):
8
- output = pipe(text, classes, multi_label=multi_label)
9
- print(output)
10
- return dict(zip(output['labels'], output['scores']))
 
 
 
 
11
 
12
  text_input = gr.Textbox(lines=5, placeholder='Once upon a time...', label='Text Source', show_label=True)
13
- class_input = gr.Textbox(value='positive,negative', label='Class Label', show_label=True)
14
  allow_multi_label = gr.Checkbox(value=True, label='Multiple True Classes')
 
15
 
16
- app = gr.Interface(fn=infer, inputs=[text_input, class_input, allow_multi_label], outputs='label')
17
  app.launch()
 
 
1
  from transformers import pipeline
2
+ import gradio as gr
3
+ import logging
4
+
5
+ logging.basicConfig(level=logging.INFO)
6
+ logger = logging.getLogger(__name__)
7
 
8
+ pipelines = {
9
+ 'small': pipeline('zero-shot-classification', model='MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33'),
10
+ 'base': pipeline('zero-shot-classification', model='MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33'),
11
+ 'large': pipeline('zero-shot-classification', model='MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33')
12
+ }
13
 
14
+ def infer(text, classes, multi_label, model_size):
15
+ try:
16
+ output = pipelines[model_size](text, classes, multi_label=multi_label)
17
+ logger.info(f"Model size: {model_size}, Output: {output}")
18
+ return dict(zip(output['labels'], output['scores']))
19
+ except Exception as e:
20
+ logger.error(f"Error: {e}")
21
+ return {}
22
 
23
  text_input = gr.Textbox(lines=5, placeholder='Once upon a time...', label='Text Source', show_label=True)
24
+ class_input = gr.Textbox(value='positive, negative', label='Class Label', show_label=True, info='Use commas (,) to seperate classes')
25
  allow_multi_label = gr.Checkbox(value=True, label='Multiple True Classes')
26
+ model_sizes = gr.Radio(choices=['small', 'base', 'large'], value='base', label='Model Sizes', show_label=True)
27
 
28
+ app = gr.Interface(fn=infer, inputs=[text_input, class_input, allow_multi_label, model_sizes], outputs='label')
29
  app.launch()