Update to use new and improved bert model
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ def predict(input_query):
|
|
50 |
|
51 |
|
52 |
textbox = gr.Textbox(label="Query",
|
53 |
-
placeholder="
|
54 |
label = gr.Label(label="Result", num_top_classes=5)
|
55 |
|
56 |
gradio_app = gr.Interface(
|
@@ -59,7 +59,7 @@ gradio_app = gr.Interface(
|
|
59 |
outputs=[label],
|
60 |
title="Query Classification",
|
61 |
allow_flagging="manual",
|
62 |
-
flagging_options=["
|
63 |
flagging_callback=flag_callback,
|
64 |
)
|
65 |
|
|
|
50 |
|
51 |
|
52 |
textbox = gr.Textbox(label="Query",
|
53 |
+
placeholder="Quick bite to eat near me")
|
54 |
label = gr.Label(label="Result", num_top_classes=5)
|
55 |
|
56 |
gradio_app = gr.Interface(
|
|
|
59 |
outputs=[label],
|
60 |
title="Query Classification",
|
61 |
allow_flagging="manual",
|
62 |
+
flagging_options=["correct classification", "incorrect classification"],
|
63 |
flagging_callback=flag_callback,
|
64 |
)
|
65 |
|
hydra.py
CHANGED
@@ -35,7 +35,7 @@ class Hydra(BertModel):
|
|
35 |
super().__init__(config)
|
36 |
self.config = config
|
37 |
self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
|
38 |
-
self.
|
39 |
[len(group) for group in config.label_groups]))
|
40 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
41 |
|
@@ -70,7 +70,7 @@ class Hydra(BertModel):
|
|
70 |
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
71 |
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
|
72 |
pooled_output = self.dropout(pooled_output) # (bs, dim)
|
73 |
-
logits = self.
|
74 |
|
75 |
loss = None
|
76 |
if labels is not None:
|
@@ -107,6 +107,6 @@ class Hydra(BertModel):
|
|
107 |
def to(self, device):
|
108 |
super().to(device)
|
109 |
self.pre_classifier.to(device)
|
110 |
-
self.
|
111 |
self.dropout.to(device)
|
112 |
return self
|
|
|
35 |
super().__init__(config)
|
36 |
self.config = config
|
37 |
self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
|
38 |
+
self.classifier = nn.Linear(config.hidden_size, sum(
|
39 |
[len(group) for group in config.label_groups]))
|
40 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
41 |
|
|
|
70 |
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
71 |
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
|
72 |
pooled_output = self.dropout(pooled_output) # (bs, dim)
|
73 |
+
logits = self.classifier(pooled_output) # (bs, num_labels)
|
74 |
|
75 |
loss = None
|
76 |
if labels is not None:
|
|
|
107 |
def to(self, device):
|
108 |
super().to(device)
|
109 |
self.pre_classifier.to(device)
|
110 |
+
self.classifier.to(device)
|
111 |
self.dropout.to(device)
|
112 |
return self
|