Keshavp08 commited on
Commit
4b0f6e2
·
verified ·
1 Parent(s): 9955d36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -10
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import matplotlib.pyplot as plt
5
- import numpy # Importing numpy
6
 
7
  @st.cache_resource
8
  def load_model():
@@ -12,9 +12,11 @@ def load_model():
12
 
13
  tokenizer, model = load_model()
14
 
15
- st.title("Sentiment Analysis App")
16
 
17
  text = st.text_input("Enter text to analyze:")
 
 
18
  if st.button("Analyze") and text:
19
  encoding = tokenizer.encode_plus(text, return_tensors="pt", padding=True, truncation=True)
20
  input_ids = encoding["input_ids"]
@@ -23,19 +25,31 @@ if st.button("Analyze") and text:
23
  with torch.no_grad():
24
  output = model(input_ids, attention_mask)
25
  logits = output.logits.squeeze()
26
-
27
- # Determine the number of sentiment classes from the model output
28
  num_classes = logits.shape[0]
29
  sentiments = ["Very Negative", "Negative", "Neutral", "Positive", "Very Positive"][:num_classes]
30
 
 
 
 
31
  prediction = int(torch.argmax(logits))
32
  sentiment = sentiments[prediction]
33
- st.write(f"Sentiment: {sentiment}")
34
 
35
- values = logits.tolist()
 
36
 
37
  fig, ax = plt.subplots()
38
- ax.bar(sentiments, values, color=plt.cm.viridis_r(numpy.linspace(0.3, 0.7, num_classes)))
39
- ax.set_title("Sentiment Analysis Scores")
40
- ax.set_ylabel("Score")
41
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import matplotlib.pyplot as plt
5
+ import numpy as np
6
 
7
  @st.cache_resource
8
  def load_model():
 
12
 
13
  tokenizer, model = load_model()
14
 
15
+ st.title("Advanced Sentiment Analysis App")
16
 
17
  text = st.text_input("Enter text to analyze:")
18
+ threshold = st.slider("Set sentiment strength threshold:", 0.0, 1.0, 0.5, 0.01)
19
+
20
  if st.button("Analyze") and text:
21
  encoding = tokenizer.encode_plus(text, return_tensors="pt", padding=True, truncation=True)
22
  input_ids = encoding["input_ids"]
 
25
  with torch.no_grad():
26
  output = model(input_ids, attention_mask)
27
  logits = output.logits.squeeze()
28
+
 
29
  num_classes = logits.shape[0]
30
  sentiments = ["Very Negative", "Negative", "Neutral", "Positive", "Very Positive"][:num_classes]
31
 
32
+ softmax = torch.nn.Softmax(dim=0)
33
+ probabilities = softmax(logits).numpy()
34
+
35
  prediction = int(torch.argmax(logits))
36
  sentiment = sentiments[prediction]
37
+ st.write(f"Detected Sentiment: {sentiment}")
38
 
39
+ # Normalize scores for display
40
+ values = probabilities.tolist()
41
 
42
  fig, ax = plt.subplots()
43
+ colors = plt.cm.coolwarm(np.linspace(0, 1, num_classes))
44
+ bars = ax.bar(sentiments, values, color=colors)
45
+
46
+ # Highlight bars that pass the threshold
47
+ for bar, value in zip(bars, values):
48
+ if value > threshold:
49
+ bar.set_alpha(1.0) # Solid color for high confidence
50
+ else:
51
+ bar.set_alpha(0.5) # Faded color for low confidence
52
+
53
+ ax.set_title("Sentiment Analysis Scores with Confidence Threshold")
54
+ ax.set_ylabel("Confidence")
55
+ st.pyplot(fig)