MatteoFasulo committed on
Commit
6c7f045
·
verified ·
1 Parent(s): da38f38

Dashboard-fixes (#1)

Browse files

- feat: enhanced GUI (63daa8ce6608610cb201ff8faeb918a65aa58d94)
- bug: removed share=True (ed1a96ef44b5680c941da2d9e14e98e4314f30a2)

Files changed (1) hide show
  1. app.py +114 -82
app.py CHANGED
@@ -16,109 +16,102 @@ examples = [
16
  ["Boxing Day ambush & flagship attack Putin has long tried to downplay the true losses his army has faced in the Black Sea."],
17
  ]
18
 
19
- # Custom model class for combining sentiment analysis with subjectivity detection
20
  class CustomModel(PreTrainedModel):
21
  config_class = DebertaV2Config
22
-
23
  def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
24
  super().__init__(config, *args, **kwargs)
25
  self.deberta = DebertaV2Model(config)
26
  self.pooler = ContextPooler(config)
27
  output_dim = self.pooler.output_dim
28
  self.dropout = nn.Dropout(0.1)
29
-
30
  self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
31
 
32
  def forward(self, input_ids, positive, neutral, negative, token_type_ids=None, attention_mask=None, labels=None):
33
  outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
34
-
35
  encoder_layer = outputs[0]
36
  pooled_output = self.pooler(encoder_layer)
37
-
38
- # Sentiment features as a single tensor
39
- sentiment_features = torch.stack((positive, neutral, negative), dim=1) # Shape: (batch_size, 3)
40
-
41
- # Combine CLS embedding with sentiment features
42
  combined_features = torch.cat((pooled_output, sentiment_features), dim=1)
43
-
44
- # Classification head
45
  logits = self.classifier(self.dropout(combined_features))
46
-
47
  return {'logits': logits}
48
 
49
- # Load the pre-trained tokenizer
50
  def load_tokenizer(model_name: str):
51
  return AutoTokenizer.from_pretrained(model_name)
52
 
53
- # Load the pre-trained model
54
  def load_model(model_name: str):
55
-
56
- if 'sentiment' in model_name:
57
- config = DebertaV2Config.from_pretrained(
58
- model_name,
59
- num_labels=2,
60
- id2label={0: 'OBJ', 1: 'SUBJ'},
61
- label2id={'OBJ': 0, 'SUBJ': 1},
62
- output_attentions=False,
63
- output_hidden_states=False
64
- )
65
-
66
- model = CustomModel(config=config, sentiment_dim=3, num_labels=2).from_pretrained(model_name)
67
-
68
- else:
69
- model = AutoModelForSequenceClassification.from_pretrained(
70
- model_name,
71
- num_labels=2,
72
- id2label={0: 'OBJ', 1: 'SUBJ'},
73
- label2id={'OBJ': 0, 'SUBJ': 1},
74
- output_attentions=False,
75
- output_hidden_states=False
76
  )
 
 
 
 
 
 
 
 
77
 
78
- return model
79
-
80
- # Get sentiment values using a pre-trained sentiment analysis model
81
  def get_sentiment_values(text: str):
82
- pipe = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment", tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment", top_k=None)
83
- sentiments = pipe(text)[0]
84
- return {k:v for k,v in [(list(sentiment.values())[0], list(sentiment.values())[1]) for sentiment in sentiments]}
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  def analyze(text):
87
- # Extract sentiment values
88
- sentiment_values = get_sentiment_values(text)
 
 
 
 
 
89
 
90
- # Load the tokenizer and model
91
  tokenizer = load_tokenizer(model_card)
92
  model_with_sentiment = load_model(sentiment_model)
93
  model_without_sentiment = load_model(subjectivity_only_model)
94
 
95
- # Tokenize
96
- inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
 
 
97
 
98
- # Get the subjectivity model outputs
99
- outputs_base = model_without_sentiment(**inputs)
100
  logits_base = outputs_base.get('logits')
101
- # Calculate probabilities using softmax
102
  prob_base = torch.nn.functional.softmax(logits_base, dim=1)[0]
 
 
 
 
 
103
 
104
- # Get the sentiment values
105
- positive = sentiment_values['positive']
106
- neutral = sentiment_values['neutral']
107
- negative = sentiment_values['negative']
108
-
109
- # Convert sentiment values to tensors
110
- inputs['positive'] = torch.tensor(positive).unsqueeze(0)
111
- inputs['neutral'] = torch.tensor(neutral).unsqueeze(0)
112
- inputs['negative'] = torch.tensor(negative).unsqueeze(0)
113
-
114
- # Get the sentiment model outputs
115
- outputs_sentiment = model_with_sentiment(**inputs)
116
- logits_sentiment = outputs_sentiment.get('logits')
117
 
118
- # Calculate probabilities using softmax
 
 
119
  prob_sentiment = torch.nn.functional.softmax(logits_sentiment, dim=1)[0]
120
 
121
- # Prepare data for the Dataframe (string values)
122
  table_data = [
123
  ["Positive", f"{positive:.2%}"],
124
  ["Neutral", f"{neutral:.2%}"],
@@ -128,31 +121,70 @@ def analyze(text):
128
  ["TextOnly OBJ", f"{prob_base[0]:.2%}"],
129
  ["TextOnly SUBJ", f"{prob_base[1]:.2%}"]
130
  ]
131
-
132
  return table_data
133
 
134
- # Update the Gradio interface
135
- with gr.Blocks(theme=gr.themes.Base()) as demo:
136
- gr.Markdown("🚀 Advanced Subjectivity & Sentiment Dashboard 🚀")
137
- with gr.Row():
138
- txt = gr.Textbox(label="Enter text to analyze", placeholder="Paste news sentence here...", lines=2)
139
- btn = gr.Button("Analyze πŸ”", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  with gr.Tabs():
141
  with gr.TabItem("Raw Scores 📋"):
142
- table = gr.Dataframe(headers=["Metric", "Value"], datatype=["str","str"], interactive=False)
 
 
 
 
143
  with gr.TabItem("About ℹ️"):
144
- gr.Markdown("This dashboard uses two DeBERTa-based models (with and without sentiment integration) to detect subjectivity, alongside sentiment scores from an XLM-RoBERTa model.")
 
 
 
 
145
  with gr.Row():
146
  gr.Markdown("### Examples:")
147
- gr.Examples(
148
- examples=examples,
149
- inputs=txt,
150
- outputs=[table],
151
- fn=analyze,
152
- label="Examples",
153
- cache_examples=True,
154
- )
155
- # Link inputs to outputs
 
 
156
  btn.click(fn=analyze, inputs=txt, outputs=[table])
157
 
 
 
 
 
 
 
 
158
  demo.queue().launch()
 
16
  ["Boxing Day ambush & flagship attack Putin has long tried to downplay the true losses his army has faced in the Black Sea."],
17
  ]
18
 
 
19
  class CustomModel(PreTrainedModel):
20
  config_class = DebertaV2Config
 
21
  def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
22
  super().__init__(config, *args, **kwargs)
23
  self.deberta = DebertaV2Model(config)
24
  self.pooler = ContextPooler(config)
25
  output_dim = self.pooler.output_dim
26
  self.dropout = nn.Dropout(0.1)
 
27
  self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
28
 
29
  def forward(self, input_ids, positive, neutral, negative, token_type_ids=None, attention_mask=None, labels=None):
30
  outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
 
31
  encoder_layer = outputs[0]
32
  pooled_output = self.pooler(encoder_layer)
33
+ sentiment_features = torch.stack((positive, neutral, negative), dim=1).to(pooled_output.dtype)
 
 
 
 
34
  combined_features = torch.cat((pooled_output, sentiment_features), dim=1)
 
 
35
  logits = self.classifier(self.dropout(combined_features))
 
36
  return {'logits': logits}
37
 
 
38
  def load_tokenizer(model_name: str):
39
  return AutoTokenizer.from_pretrained(model_name)
40
 
41
+ load_model_cache = {}
42
  def load_model(model_name: str):
43
+ if model_name not in load_model_cache:
44
+ print(f"Loading model: {model_name}")
45
+ if 'sentiment' in model_name:
46
+ config = DebertaV2Config.from_pretrained(
47
+ model_name, num_labels=2, id2label={0: 'OBJ', 1: 'SUBJ'}, label2id={'OBJ': 0, 'SUBJ': 1},
48
+ output_attentions=False, output_hidden_states=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  )
50
+ model_instance = CustomModel(config=config, sentiment_dim=3, num_labels=2).from_pretrained(model_name)
51
+ else:
52
+ model_instance = AutoModelForSequenceClassification.from_pretrained(
53
+ model_name, num_labels=2, id2label={0: 'OBJ', 1: 'SUBJ'}, label2id={'OBJ': 0, 'SUBJ': 1},
54
+ output_attentions=False, output_hidden_states=False
55
+ )
56
+ load_model_cache[model_name] = model_instance
57
+ return load_model_cache[model_name]
58
 
59
+ sentiment_pipeline_cache = None #
 
 
60
  def get_sentiment_values(text: str):
61
+ global sentiment_pipeline_cache
62
+ if sentiment_pipeline_cache is None:
63
+ print("Loading sentiment pipeline...")
64
+ sentiment_pipeline_cache = pipeline(
65
+ "sentiment-analysis",
66
+ model="cardiffnlp/twitter-xlm-roberta-base-sentiment",
67
+ tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment",
68
+ top_k=None
69
+ )
70
+ sentiments_output = sentiment_pipeline_cache(text)
71
+ if sentiments_output and isinstance(sentiments_output, list) and sentiments_output[0]:
72
+ sentiments = sentiments_output[0]
73
+ return {s['label'].lower(): s['score'] for s in sentiments}
74
+ return {}
75
+
76
 
77
  def analyze(text):
78
+ if not text or not text.strip():
79
+ empty_data = [
80
+ ["Positive", ""], ["Neutral", ""], ["Negative", ""],
81
+ ["Sent-Subj OBJ", ""], ["Sent-Subj SUBJ", ""],
82
+ ["TextOnly OBJ", ""], ["TextOnly SUBJ", ""]
83
+ ]
84
+ return empty_data
85
 
86
+ sentiment_values = get_sentiment_values(text)
87
  tokenizer = load_tokenizer(model_card)
88
  model_with_sentiment = load_model(sentiment_model)
89
  model_without_sentiment = load_model(subjectivity_only_model)
90
 
91
+ inputs_dict = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
92
+
93
+ device = next(model_without_sentiment.parameters()).device
94
+ inputs_dict_on_device = {k: v.to(device) for k, v in inputs_dict.items()}
95
 
96
+ outputs_base = model_without_sentiment(**inputs_dict_on_device)
 
97
  logits_base = outputs_base.get('logits')
 
98
  prob_base = torch.nn.functional.softmax(logits_base, dim=1)[0]
99
+
100
+ positive = sentiment_values.get('positive', 0.0)
101
+ neutral = sentiment_values.get('neutral', 0.0)
102
+ negative = sentiment_values.get('negative', 0.0)
103
+
104
 
105
+ current_inputs_for_sentiment_model = inputs_dict_on_device.copy()
106
+ current_inputs_for_sentiment_model['positive'] = torch.tensor(positive, device=device).unsqueeze(0).float()
107
+ current_inputs_for_sentiment_model['neutral'] = torch.tensor(neutral, device=device).unsqueeze(0).float()
108
+ current_inputs_for_sentiment_model['negative'] = torch.tensor(negative, device=device).unsqueeze(0).float()
 
 
 
 
 
 
 
 
 
109
 
110
+
111
+ outputs_sentiment = model_with_sentiment(**current_inputs_for_sentiment_model)
112
+ logits_sentiment = outputs_sentiment.get('logits')
113
  prob_sentiment = torch.nn.functional.softmax(logits_sentiment, dim=1)[0]
114
 
 
115
  table_data = [
116
  ["Positive", f"{positive:.2%}"],
117
  ["Neutral", f"{neutral:.2%}"],
 
121
  ["TextOnly OBJ", f"{prob_base[0]:.2%}"],
122
  ["TextOnly SUBJ", f"{prob_base[1]:.2%}"]
123
  ]
 
124
  return table_data
125
 
126
+ def load_default_example_on_startup():
127
+ print("Loading default example on startup...")
128
+ if examples and examples[0] and isinstance(examples[0], list) and examples[0]:
129
+ default_text = examples[0][0]
130
+ default_analysis_results = analyze(default_text)
131
+ return default_text, default_analysis_results
132
+ print("Warning: No valid default example found. Loading empty.")
133
+ empty_text = ""
134
+ empty_results = analyze(empty_text)
135
+ return empty_text, empty_results
136
+
137
+ with gr.Blocks(theme=gr.themes.Ocean(), title="Subjectivity & Sentiment Dashboard") as demo:
138
+ gr.Markdown("# 🚀 Subjectivity & Sentiment Analysis Dashboard 🚀")
139
+
140
+ with gr.Column():
141
+ txt = gr.Textbox(
142
+ label="Enter text to analyze",
143
+ placeholder="Paste news sentence here...",
144
+ lines=2,
145
+ )
146
+ with gr.Row():
147
+ gr.Column(scale=1, min_width=0)
148
+ btn = gr.Button(
149
+ "Analyze πŸ”",
150
+ variant="primary",
151
+ size="md",
152
+ scale=0
153
+ )
154
+
155
  with gr.Tabs():
156
  with gr.TabItem("Raw Scores 📋"):
157
+ table = gr.Dataframe(
158
+ headers=["Metric", "Value"],
159
+ datatype=["str", "str"],
160
+ interactive=False
161
+ )
162
  with gr.TabItem("About ℹ️"):
163
+ gr.Markdown(
164
+ "This dashboard uses two DeBERTa-based models (with and without sentiment integration) "
165
+ "to detect subjectivity, alongside sentiment scores from an XLM-RoBERTa model."
166
+ )
167
+
168
  with gr.Row():
169
  gr.Markdown("### Examples:")
170
+
171
+
172
+ gr.Examples(
173
+ examples=examples,
174
+ inputs=txt,
175
+ outputs=[table],
176
+ fn=analyze,
177
+ label="Click an example to analyze",
178
+ cache_examples=True,
179
+ )
180
+
181
  btn.click(fn=analyze, inputs=txt, outputs=[table])
182
 
183
+
184
+ demo.load(
185
+ fn=load_default_example_on_startup,
186
+ inputs=None,
187
+ outputs=[txt, table]
188
+ )
189
+
190
  demo.queue().launch()