Berbex commited on
Commit
89cadb8
β€’
1 Parent(s): 12f070b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -18
app.py CHANGED
@@ -11,24 +11,6 @@ from transformers import EvalPrediction
11
  import torch
12
  import gradio as gr
13
 
14
- # REMOVE THIS IN COLAB
15
-
16
- title = 'Text market sentiment'
17
- text_ = "Bitcoin to the moon"
18
- model = torch.load("./model.pt", map_location=torch.device('cpu'))
19
-
20
- inp = [gr.Textbox(label='API Key', placeholder="What is your API Key?"), gr.Textbox(label='Audio File URL', placeholder="Audio file URL?")]
21
- out = gr.Textbox(label='Output')
22
- text_button = gr.Button("Flip")
23
- text_button.click(audio_to_text, inputs=inp, outputs=out)
24
-
25
- interface = gr.Interface.load(input=inp,output=out,
26
- title = title,
27
- theme = "peach",
28
- examples = [[text_]]).launch()
29
-
30
- ###############
31
-
32
  console = Console()
33
 
34
  dataset = load_dataset("zeroshot/twitter-financial-news-sentiment", )
@@ -131,6 +113,52 @@ def compute_metrics(p: EvalPrediction):
131
  return result
132
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  trainer = Trainer(
135
  model,
136
  args,
 
11
  import torch
12
  import gradio as gr
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  console = Console()
15
 
16
  dataset = load_dataset("zeroshot/twitter-financial-news-sentiment", )
 
113
  return result
114
 
115
 
116
+
117
+
118
+
119
+ # REMOVE THIS IN COLAB #############
120
+
121
+ title = 'Text market sentiment'
122
+ text_ = "Bitcoin to the moon"
123
+ model = torch.load("./model.pt", map_location=torch.device('cpu'))
124
+
125
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
126
+
127
+ text = "Bitcoin to the moon"
128
+
129
+ encoding = tokenizer(text, return_tensors="pt")
130
+
131
+ # apply sigmoid + threshold
132
+ sigmoid = torch.nn.Sigmoid()
133
+ probs = sigmoid(logits.squeeze().cpu())
134
+ predictions = np.zeros(probs.shape)
135
+ predictions[np.where(probs >= 0.5)] = 1
136
+ # turn predicted id's into actual label names
137
+ predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
138
+ console.log("a")
139
+ console.log(predicted_labels)
140
+ console.log("a")
141
+
142
+
143
+
144
+
145
+
146
+ inp = [gr.Textbox(label='API Key', placeholder="What is your API Key?"), gr.Textbox(label='Audio File URL', placeholder="Audio file URL?")]
147
+ out = gr.Textbox(label='Output')
148
+ text_button = gr.Button("Flip")
149
+ text_button.click(audio_to_text, inputs=inp, outputs=out)
150
+
151
+ interface = gr.Interface.load(input=inp,output=out,
152
+ title = title,
153
+ theme = "peach",
154
+ examples = [[text_]]).launch()
155
+
156
+ ###############
157
+
158
+
159
+
160
+
161
+
162
  trainer = Trainer(
163
  model,
164
  args,