cooelf commited on
Commit
b104a94
1 Parent(s): c0f9bdd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import string
2
+ import gradio as gr
3
+ import requests
4
+ import torch
5
+ from transformers import (
6
+ AutoConfig,
7
+ AutoModelForSequenceClassification,
8
+ AutoTokenizer,
9
+ )
10
+
11
+ model_dir = "my-bert-model"
12
+
13
+ config = AutoConfig.from_pretrained(model_dir, num_labels=2, finetuning_task="text-classification")
14
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
15
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)
16
+
17
+ def inference(input_text):
18
+ inputs = tokenizer.batch_encode_plus(
19
+ [input_text],
20
+ max_length=512,
21
+ pad_to_max_length=True,
22
+ truncation=True,
23
+ padding="max_length",
24
+ return_tensors="pt",
25
+ )
26
+
27
+ with torch.no_grad():
28
+ logits = model(**inputs).logits
29
+
30
+ predicted_class_id = logits.argmax().item()
31
+ output = model.config.id2label[predicted_class_id]
32
+ return output
33
+
34
+ with gr.Blocks(css="""
35
+ .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
36
+ #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
37
+ """) as demo:
38
+ with gr.Column(elem_id="container"):
39
+ with gr.Row():
40
+ with gr.Row():
41
+ input_text = gr.Textbox(
42
+ placeholder="Insert your prompt here:", scale=5, container=False
43
+ )
44
+ answer = gr.Textbox(lines=0, label="Answer")
45
+ generate_bt = gr.Button("Generate", scale=1)
46
+ inputs = [input_text]
47
+ outputs = [answer]
48
+ generate_bt.click(
49
+ fn=inference, inputs=inputs, outputs=outputs, show_progress=False
50
+ )
51
+ demo.queue()
52
+ demo.launch()