nppmatt commited on
Commit
6407600
1 Parent(s): 73571fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -95,7 +95,7 @@ option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT
95
  if option == "BERT":
96
  model = PretrainedBertClass()
97
  else:
98
- model = torch.load("pytorch_bert_toxic.bin")
99
 
100
  # Freeze model and input tokens
101
  def inference():
@@ -108,7 +108,7 @@ def inference():
108
  mask = data["mask"].to(device, dtype=torch.long)
109
  token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
110
  targets = data["targets"].to(device, dtype=torch.float)
111
- outputs = model(ids, mask, token_type_ids)
112
  final_targets.extend(targets.cpu().detach().numpy().tolist())
113
  final_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
114
  return final_outputs, final_targets
 
95
  if option == "BERT":
96
  model = PretrainedBertClass()
97
  else:
98
+ model = torch.load("pytorch_bert_toxic.bin", map_location=torch.device("cpu"))
99
 
100
  # Freeze model and input tokens
101
  def inference():
 
108
  mask = data["mask"].to(device, dtype=torch.long)
109
  token_type_ids = data["token_type_ids"].to(device, dtype=torch.long)
110
  targets = data["targets"].to(device, dtype=torch.float)
111
+ outputs = model(ids, mask, token_type_ids, return_dict=False)
112
  final_targets.extend(targets.cpu().detach().numpy().tolist())
113
  final_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
114
  return final_outputs, final_targets