ivyblossom commited on
Commit
1e5415d
1 Parent(s): f7d8474

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -11
app.py CHANGED
@@ -5,16 +5,6 @@ import torch
5
  # Set up the device (GPU or CPU)
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
- # Function to perform sentiment analysis
9
- def perform_sentiment_analysis(text):
10
- inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
11
- inputs = inputs.to(device)
12
- outputs = model(**inputs)
13
- logits = outputs.logits
14
- probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
15
- sentiment_label = "Positive" if probabilities[1] > probabilities[0] else "Negative"
16
- return sentiment_label, probabilities
17
-
18
  # Streamlit app
19
  def main():
20
  st.title("Sentiment Analysis App")
@@ -36,7 +26,13 @@ def main():
36
  tokenizer = AutoTokenizer.from_pretrained(model_name)
37
 
38
  if st.button("Analyze"):
39
- sentiment_label, probabilities = perform_sentiment_analysis(text)
 
 
 
 
 
 
40
  st.write(f"Sentiment: {sentiment_label}")
41
  st.write(f"Positive probability: {probabilities[1]}")
42
  st.write(f"Negative probability: {probabilities[0]}")
 
5
  # Set up the device (GPU or CPU)
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
 
 
 
 
 
 
 
 
 
 
8
  # Streamlit app
9
  def main():
10
  st.title("Sentiment Analysis App")
 
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
 
28
  if st.button("Analyze"):
29
+ # Perform sentiment analysis
30
+ inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
31
+ inputs = inputs.to(device)
32
+ outputs = model(**inputs)
33
+ logits = outputs.logits
34
+ probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
35
+ sentiment_label = "Positive" if probabilities[1] > probabilities[0] else "Negative"
36
  st.write(f"Sentiment: {sentiment_label}")
37
  st.write(f"Positive probability: {probabilities[1]}")
38
  st.write(f"Negative probability: {probabilities[0]}")