Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -95,7 +95,7 @@ option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT
|
|
95 |
if option == "BERT":
|
96 |
model = PretrainedBertClass()
|
97 |
else:
|
98 |
-
model = torch.load("pytorch_bert_toxic.bin")
|
99 |
|
100 |
# Freeze model and input tokens
|
101 |
def inference():
|
@@ -108,7 +108,7 @@ def inference():
|
|
108 |
mask = data["mask"].to(device, dtype=torch.long)
|
109 |
token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
|
110 |
targets = data["targets"].to(device, dtype=torch.float)
|
111 |
-
outputs = model(ids, mask, token_type_ids)
|
112 |
final_targets.extend(targets.cpu().detach().numpy().tolist())
|
113 |
final_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
|
114 |
return final_outputs, final_targets
|
|
|
95 |
if option == "BERT":
|
96 |
model = PretrainedBertClass()
|
97 |
else:
|
98 |
+
model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device("cpu"))
|
99 |
|
100 |
# Freeze model and input tokens
|
101 |
def inference():
|
|
|
108 |
mask = data["mask"].to(device, dtype=torch.long)
|
109 |
token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
|
110 |
targets = data["targets"].to(device, dtype=torch.float)
|
111 |
+
outputs = model(ids, mask, token_type_ids, return_dict=False)
|
112 |
final_targets.extend(targets.cpu().detach().numpy().tolist())
|
113 |
final_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
|
114 |
return final_outputs, final_targets
|