efeperro commited on
Commit
56f8481
1 Parent(s): 9a45018

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -28,22 +28,24 @@ def load_cnn():
28
  model.eval()
29
 
30
  return model
31
-
32
- def predict_sentiment(text, model, torch=False):
33
  if torch == True:
34
- processed_text = processor.transform(text)
35
- with torch.no_grad(): # Ensure no gradients are computed
36
- prediction = model(processed_text) # Get raw model output
37
- # Convert output to probabilities
38
- probs = torch.softmax(prediction, dim=1)
39
- # Get the predicted class
40
- pred_class = torch.argmax(probs, dim=1)
41
- return pred_class.item() # Return the predicted class as a Python int
 
 
 
 
42
  else:
43
  processor.transform(text)
44
  prediction = model.predict([text])
45
  return prediction
46
-
47
 
48
 
49
  model_1 = load_model()
 
28
  model.eval()
29
 
30
  return model
31
+ def predict_sentiment(text, model, vocab, tokenizer):
 
32
  if torch == True:
33
+ processor.transform(text)
34
+ tokens = tokenizer(text)
35
+ encoded = [vocab[token] for token in tokens]
36
+ input_tensor = torch.tensor(encoded).unsqueeze(0).to(device)
37
+
38
+ with torch.no_grad(): # No gradient needed
39
+ model.eval() # Evaluation mode
40
+ outputs = model(input_tensor)
41
+ probs = torch.softmax(outputs, dim=1)
42
+ pred_class = torch.argmax(probs, dim=1).item()
43
+
44
+ return pred_class # Return the predicted class index
45
  else:
46
  processor.transform(text)
47
  prediction = model.predict([text])
48
  return prediction
 
49
 
50
 
51
  model_1 = load_model()