ClefChen commited on
Commit
dc8b598
1 Parent(s): 3441e15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -72,8 +72,12 @@ def get_prediction(inputs):
72
  outputs = model(**inputs)
73
  logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
74
  pred_prob = torch.softmax(logits, dim=1)
75
- pred = torch.argmax(pred_prob, dim=1)
76
- return class_names[pred.item()]
 
 
 
 
77
 
78
  # vectorizer= nltk_u.vectorizer()
79
  # vectorizer.fit(train_data.text)
 
72
  outputs = model(**inputs)
73
  logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
74
  pred_prob = torch.softmax(logits, dim=1)
75
+ pred = torch.argmax(pred_prob, dim=1).item()
76
+ if pred in class_names:
77
+ return class_names[pred]
78
+ else:
79
+ print(f"Warning: Prediction index {pred} not found in class_names.")
80
+ return "Unknown" # 或者其他默认的响应
81
 
82
  # vectorizer= nltk_u.vectorizer()
83
  # vectorizer.fit(train_data.text)