antenmanuuel committed
Commit 03de5c7 · verified · 1 parent: efa23c6

Upload folder using huggingface_hub

attention_comparison_helpers.py CHANGED
@@ -8,13 +8,18 @@ from routes.tokenize import tokenize_text
 
 async def get_attention_comparison_bert(request: ComparisonRequest):
     """
-    BERT-specific implementation of attention comparison
+    BERT and DistilBERT implementation of attention comparison
     """
     try:
-        print(f"\n=== USING BERT ATTENTION COMPARISON IMPLEMENTATION ===")
+        model_type = "DistilBERT" if "distilbert" in request.model_name.lower() else "BERT"
+        print(f"\n=== USING {model_type} ATTENTION COMPARISON IMPLEMENTATION ===")
 
         # 1. Get the "before" attention data
-        before_attention_request = AttentionRequest(text=request.text, model_name=request.model_name)
+        before_attention_request = AttentionRequest(
+            text=request.text,
+            model_name=request.model_name,
+            visualization_method=request.visualization_method
+        )
         before_data = (await get_attention_matrices(before_attention_request))["attention_data"]
 
         # 2. Tokenize the text
@@ -46,7 +51,7 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
 
         # HANDLE PUNCTUATION
         if is_punctuation:
-            print(f"\nUsing BERT punctuation replacement approach")
+            print(f"\nUsing {model_type} punctuation replacement approach")
 
             # Find all occurrences of this punctuation in the original text
             punctuation_positions = [pos for pos, char in enumerate(original_text) if char == selected_token]
@@ -61,8 +66,9 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
            # We'll use a heuristic based on the token's position
 
            # Count non-special tokens before our selected token
+           special_tokens = ["[CLS]", "[SEP]"]  # Same special tokens for BERT and DistilBERT
            non_special_tokens_before = sum(1 for t in tokens[:request.masked_index]
-                                           if t["text"] not in ["[CLS]", "[SEP]"])
+                                           if t["text"] not in special_tokens)
 
            # Select the corresponding punctuation position (or last one if out of range)
            punct_idx = min(non_special_tokens_before, len(punctuation_positions) - 1)
@@ -76,14 +82,18 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
            print(f"Replaced text: '{replaced_text}'")
 
            # Get the after attention data
-           after_attention_request = AttentionRequest(text=replaced_text, model_name=request.model_name)
+           after_attention_request = AttentionRequest(
+               text=replaced_text,
+               model_name=request.model_name,
+               visualization_method=request.visualization_method
+           )
            after_data = (await get_attention_matrices(after_attention_request))["attention_data"]
 
            # Return comparison data
            return {"before_attention": before_data, "after_attention": after_data}
 
-        # HANDLE REGULAR WORDS FOR BERT
-        print(f"\nUsing BERT word replacement approach")
+        # HANDLE REGULAR WORDS FOR BERT/DistilBERT
+        print(f"\nUsing {model_type} word replacement approach")
         words = original_text.split()
         print(f"Words: {words}")
 
@@ -193,7 +203,7 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
            print(f"Could not determine token position, using simple word replacement")
            words = original_text.split()
 
-           # Adjust for special tokens in BERT ([CLS])
+           # Adjust for special tokens in BERT/DistilBERT ([CLS])
            adjusted_index = max(0, request.masked_index - 1)
            word_idx = min(adjusted_index, len(words) - 1)
 
@@ -218,14 +228,18 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
         print(f"Replaced text: '{replaced_text}'")
 
         # Get the after attention data
-        after_attention_request = AttentionRequest(text=replaced_text, model_name=request.model_name)
+        after_attention_request = AttentionRequest(
+            text=replaced_text,
+            model_name=request.model_name,
+            visualization_method=request.visualization_method
+        )
         after_data = (await get_attention_matrices(after_attention_request))["attention_data"]
 
         # Return comparison data
         return {"before_attention": before_data, "after_attention": after_data}
 
     except Exception as e:
-        print(f"BERT Attention comparison error: {str(e)}")
+        print(f"{model_type} Attention comparison error: {str(e)}")
         import traceback
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=str(e))
@@ -243,7 +257,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
         print(f"Replacement word: '{request.replacement_word}'")
 
         # Get the "before" attention data
-        before_attention_request = AttentionRequest(text=request.text, model_name=request.model_name)
+        before_attention_request = AttentionRequest(
+            text=request.text,
+            model_name=request.model_name,
+            visualization_method=request.visualization_method
+        )
         before_data = (await get_attention_matrices(before_attention_request))["attention_data"]
 
         # Tokenize the text
@@ -302,7 +320,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
            print(f"Replaced text: '{replaced_text}'")
 
            # Get the "after" attention data and return
-           after_request = AttentionRequest(text=replaced_text, model_name=request.model_name)
+           after_request = AttentionRequest(
+               text=replaced_text,
+               model_name=request.model_name,
+               visualization_method=request.visualization_method
+           )
            after_data = (await get_attention_matrices(after_request))["attention_data"]
            return {"before_attention": before_data, "after_attention": after_data}
 
@@ -341,7 +363,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
            print(f"Replaced text: '{replaced_text}'")
 
            # Get the "after" attention data
-           after_request = AttentionRequest(text=replaced_text, model_name=request.model_name)
+           after_request = AttentionRequest(
+               text=replaced_text,
+               model_name=request.model_name,
+               visualization_method=request.visualization_method
+           )
            after_data = (await get_attention_matrices(after_request))["attention_data"]
 
            return {"before_attention": before_data, "after_attention": after_data}
@@ -411,7 +437,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
         print(f"Replaced text: '{replaced_text}'")
 
         # Get the "after" attention data
-        after_request = AttentionRequest(text=replaced_text, model_name=request.model_name)
+        after_request = AttentionRequest(
+            text=replaced_text,
+            model_name=request.model_name,
+            visualization_method=request.visualization_method
+        )
         after_data = (await get_attention_matrices(after_request))["attention_data"]
 
         return {"before_attention": before_data, "after_attention": after_data}
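
With these changes, both comparison paths forward the caller's visualization_method into the before/after AttentionRequest objects. A minimal sketch of what a comparison payload now looks like (field names follow ComparisonRequest in classes.py below; the example text and indices are illustrative):

    from classes import ComparisonRequest

    request = ComparisonRequest(
        text="The cat sat on the mat.",
        masked_index=2,                  # token index of "cat", counting [CLS] at 0
        replacement_word="dog",
        model_name="distilbert-base-uncased",
        visualization_method="rollout",  # defaults to "raw" when omitted
    )
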
attention_processing.py ADDED
@@ -0,0 +1,240 @@
+import torch
+import numpy as np
+import networkx as nx
+from typing import List, Dict, Any, Optional, Tuple
+
+#############################################
+# Attention Rollout Calculation Function
+#############################################
+def compute_attention_rollout(attentions, add_identity: bool = True, debug: bool = False):
+    """
+    Compute attention rollout from raw attention matrices
+
+    Args:
+        attentions: List of attention tensors from the model
+        add_identity: Whether to add identity matrix to each attention layer
+        debug: Whether to print debug information
+
+    Returns:
+        The attention rollout matrix
+    """
+    num_layers = len(attentions)
+    seq_len = attentions[0].size(-1)
+    rollout = torch.eye(seq_len)
+    for i, att in enumerate(attentions):
+        att_avg = att.squeeze(0).mean(dim=0)
+        if add_identity:
+            att_aug = att_avg + torch.eye(seq_len)
+        else:
+            att_aug = att_avg
+        att_aug = att_aug / (att_aug.sum(dim=-1, keepdim=True) + 1e-8)
+        att_aug = torch.nan_to_num(att_aug, nan=0.0, posinf=0.0, neginf=0.0)
+        rollout = rollout @ att_aug
+        if debug:
+            print(f"[DEBUG] Rollout after layer {i+1}:")
+            print(att_aug)
+            print(rollout)
+
+    # Normalize rollout to ensure it sums to 1.0 exactly
+    rollout_sum = rollout.sum(dim=-1, keepdim=True)
+    # Handle zero sums to avoid division by zero
+    is_zero_sum = (rollout_sum == 0)
+    if is_zero_sum.any():
+        # For rows with zero sum, distribute evenly
+        seq_len = rollout.size(-1)
+        even_dist = torch.ones_like(rollout) / seq_len
+        rollout = torch.where(is_zero_sum, even_dist, rollout / rollout_sum)
+    else:
+        rollout = rollout / rollout_sum
+
+    return rollout
+
+#############################################
+# Capacity Graph Construction Function (for Flow)
+#############################################
+def build_graph(joint_attentions, input_tokens, remove_diag: bool = False, debug: bool = False):
+    """
+    Build a graph representation for attention flow computation
+
+    Args:
+        joint_attentions: Joint attention matrices
+        input_tokens: List of input token text
+        remove_diag: Whether to remove diagonal elements (self-attention)
+        debug: Whether to print debug information
+
+    Returns:
+        Tuple of (capacity matrix, node labels dictionary)
+    """
+    n_layers, seq_len, _ = joint_attentions.shape
+    total_nodes = (n_layers + 1) * seq_len
+    capacity = np.zeros((total_nodes, total_nodes))
+    labels = {}
+    for k in range(seq_len):
+        labels[k] = f"0_{k}_{input_tokens[k]}"
+    for i in range(1, n_layers + 1):
+        for k_to in range(seq_len):
+            node_to = i * seq_len + k_to
+            labels[node_to] = f"L{i}_{k_to}"
+            for k_from in range(seq_len):
+                if remove_diag and (k_from == k_to):
+                    continue
+                node_from = (i - 1) * seq_len + k_from
+                cap = joint_attentions[i - 1][k_from][k_to]
+                capacity[node_from][node_to] = cap
+                if debug:
+                    print(f"[DEBUG] Edge from {labels[node_from]} to {labels[node_to]} with capacity: {cap:.6f}")
+    return capacity, labels
+
+#############################################
+# Attention Flow Calculation Function (using networkx)
+#############################################
+def compute_attention_flow_networkx(attentions, add_identity: bool = True, debug: bool = False, mask_idx=None):
+    """
+    Compute attention flow using networkx max flow algorithm
+
+    Args:
+        attentions: List of attention tensors from the model
+        add_identity: Whether to add identity matrix to each attention layer
+        debug: Whether to print debug information
+        mask_idx: Index of token to compute flow from (if None, computes flow for all tokens)
+
+    Returns:
+        Flow matrix or vector depending on mask_idx
+    """
+    num_layers = len(attentions)
+    seq_len = attentions[0].size(-1)
+    joint_attentions = []
+    for att in attentions:
+        att_avg = att.squeeze(0).mean(dim=0)
+        if add_identity:
+            alpha = 0.5
+            att_aug = (att_avg + torch.eye(seq_len)) * alpha
+        else:
+            att_aug = att_avg
+        att_aug = att_aug / (att_aug.sum(dim=-1, keepdim=True) + 1e-8)
+        joint_attentions.append(att_aug.cpu().numpy())
+    joint_attentions = np.stack(joint_attentions, axis=0)
+    input_tokens = [str(i) for i in range(seq_len)]
+    capacity, labels = build_graph(joint_attentions, input_tokens, remove_diag=False, debug=debug)
+    total_nodes = capacity.shape[0]
+    G = nx.DiGraph()
+    for u in range(total_nodes):
+        for v in range(total_nodes):
+            cap = capacity[u][v]
+            if cap > 1e-8:
+                G.add_edge(u, v, capacity=float(cap))
+    if mask_idx is not None:
+        source = mask_idx
+        flow_vector = np.zeros(seq_len)
+        if debug:
+            print(f"[DEBUG] Networkx max flow from {labels[source]} to each output node:")
+        for sink in range(num_layers * seq_len, (num_layers + 1) * seq_len):
+            try:
+                flow_value, _ = nx.maximum_flow(G, source, sink, capacity='capacity')
+            except Exception:
+                flow_value = 0
+            flow_vector[sink - num_layers * seq_len] = flow_value
+        flow_vector = flow_vector / (flow_vector.sum() + 1e-8)
+        return flow_vector.reshape(1, seq_len)
+    else:
+        flow_matrix = np.zeros((seq_len, seq_len))
+        if debug:
+            print("[DEBUG] Networkx max flow for each input node to each output node:")
+        for i in range(seq_len):
+            source = i
+            flow_vector = np.zeros(seq_len)
+            for sink in range(num_layers * seq_len, (num_layers + 1) * seq_len):
+                try:
+                    flow_value, _ = nx.maximum_flow(G, source, sink, capacity='capacity')
+                except Exception:
+                    flow_value = 0
+                flow_vector[sink - num_layers * seq_len] = flow_value
+            # Normalize flow vector to ensure it sums to 1.0 exactly
+            flow_sum = flow_vector.sum()
+            if flow_sum > 0:
+                flow_vector = flow_vector / flow_sum
+            else:
+                # If there is no flow, distribute evenly
+                flow_vector = np.ones(seq_len) / seq_len
+            flow_matrix[i] = flow_vector
+        if debug:
+            print("[DEBUG] Final networkx flow matrix:")
+            print(flow_matrix)
+        return flow_matrix
+
+#############################################
+# Process Attention with Selected Method
+#############################################
+def process_attention_with_method(attention_matrices, method: str = "raw", debug: bool = False) -> List[Dict[str, Any]]:
+    """
+    Process attention matrices using the specified method
+
+    Args:
+        attention_matrices: List of attention tensors from the model
+        method: Method to use (raw, rollout, flow)
+        debug: Whether to print debug information
+
+    Returns:
+        Processed attention data in the same format as the original attention data
+    """
+    if method == "raw":
+        # Return raw attention as is
+        return attention_matrices
+
+    # Convert to list if it's a tuple
+    if isinstance(attention_matrices, tuple):
+        attention_matrices = list(attention_matrices)
+
+    # Get dimensions
+    num_layers = len(attention_matrices)
+
+    if method == "rollout":
+        # Calculate rollout attention
+        rollout_matrix = compute_attention_rollout(attention_matrices, add_identity=True, debug=debug)
+
+        # Create new attention matrices with rollout values
+        new_attention_matrices = []
+        for layer_idx in range(num_layers):
+            # For each layer, we'll use the same rollout matrix for all heads
+            heads = attention_matrices[layer_idx].shape[1]  # Number of heads
+            layer_data = {"layerIndex": layer_idx, "heads": []}
+
+            for head_idx in range(heads):
+                # Convert rollout matrix to list format for this head
+                attention_matrix = rollout_matrix.cpu().tolist()
+
+                layer_data["heads"].append({
+                    "headIndex": head_idx,
+                    "attention": attention_matrix
+                })
+
+            new_attention_matrices.append(layer_data)
+
+        return new_attention_matrices
+
+    elif method == "flow":
+        # Calculate flow attention
+        flow_matrix = compute_attention_flow_networkx(attention_matrices, add_identity=True, debug=debug)
+
+        # Create new attention matrices with flow values
+        new_attention_matrices = []
+        for layer_idx in range(num_layers):
+            # For each layer, we'll use the same flow matrix for all heads
+            heads = attention_matrices[layer_idx].shape[1]  # Number of heads
+            layer_data = {"layerIndex": layer_idx, "heads": []}
+
+            for head_idx in range(heads):
+                # Convert flow matrix to list format for this head
+                attention_matrix = flow_matrix.tolist()
+
+                layer_data["heads"].append({
+                    "headIndex": head_idx,
+                    "attention": attention_matrix
+                })
+
+            new_attention_matrices.append(layer_data)
+
+        return new_attention_matrices
+
+    else:
+        raise ValueError(f"Unknown attention processing method: {method}")
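
The rollout here follows the standard attention-rollout recipe (Abnar & Zuidema, 2020): average the heads per layer, optionally add the identity to model the residual connection, row-normalize, and multiply across layers. A minimal sketch of driving compute_attention_rollout directly, assuming transformers is installed (the model name and sentence are illustrative):

    import torch
    from transformers import AutoTokenizer, BertModel
    from attention_processing import compute_attention_rollout

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
    model.eval()

    encoding = tokenizer("The cat sat on the mat.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding, output_attentions=True)

    # outputs.attentions: one [1, num_heads, seq_len, seq_len] tensor per layer
    rollout = compute_attention_rollout(outputs.attentions, add_identity=True)
    print(rollout.shape)        # torch.Size([seq_len, seq_len])
    print(rollout.sum(dim=-1))  # each row sums to ~1.0 after the final normalization
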
classes.py CHANGED
@@ -32,6 +32,7 @@ class MaskPredictionResponse(BaseModel):
 class AttentionRequest(BaseModel):
     text: str
     model_name: str = "bert-base-uncased"
+    visualization_method: str = "raw"  # Options: "raw", "rollout", "flow"
 
 class AttentionHead(BaseModel):
     headIndex: int
@@ -53,6 +54,7 @@ class ComparisonRequest(BaseModel):
     masked_index: int
     replacement_word: str
     model_name: str = "bert-base-uncased"
+    visualization_method: str = "raw"  # Options: "raw", "rollout", "flow"
 
 class AttentionComparisonResponse(BaseModel):
     before_attention: AttentionData
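
Because the new field has a default, existing clients that omit it keep getting raw attention. A quick sketch of the resulting pydantic behavior:

    from classes import AttentionRequest

    req = AttentionRequest(text="Hello world")
    print(req.visualization_method)  # "raw" — the default keeps old clients working

    req = AttentionRequest(text="Hello world", visualization_method="flow")
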
helpers.py CHANGED
@@ -112,6 +112,95 @@ def map_roberta_tokens_to_words(tokens, original_text):
 
     return token_to_word_map
 
+# Helper function to map BERT and DistilBERT tokens to word positions
+def map_bert_tokens_to_words(tokens, original_text):
+    """
+    Maps BERT/DistilBERT tokens to words in the original text.
+    Returns a dictionary mapping token indices to word indices.
+    """
+    # Get the words from the original text
+    words = original_text.split()
+    print(f"Original words: {words}")
+
+    # Filter out special tokens
+    content_tokens = []
+    for i, token in enumerate(tokens):
+        if token["text"] not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"]:
+            content_tokens.append((i, token["text"]))
+
+    print(f"Content tokens: {[t for _, t in content_tokens]}")
+
+    # Create the mapping
+    token_to_word_map = {}
+
+    # First approach: direct matching of tokens to words, handling WordPiece tokens
+    word_idx = 0
+    for token_idx, token_text in content_tokens:
+        clean_token = token_text.lower().strip("##")
+
+        # Check if this is a continuation token (starting with ##)
+        if token_text.startswith("##"):
+            # If it's a continuation, map it to the same word as the previous token
+            if token_idx > 0 and (token_idx - 1) in token_to_word_map:
+                token_to_word_map[token_idx] = token_to_word_map[token_idx - 1]
+                print(f"Continuation token: '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}'")
+            continue
+
+        # Try to find a match with words
+        while word_idx < len(words):
+            word_lower = words[word_idx].lower()
+            if clean_token in word_lower:
+                token_to_word_map[token_idx] = word_idx
+                print(f"Match: Token '{token_text}' -> Word '{words[word_idx]}' at index {word_idx}")
+                # Only advance to next word if this token is a complete word
+                if clean_token == word_lower:
+                    word_idx += 1
+                break
+            else:
+                word_idx += 1
+
+        # If we've gone through all words, break
+        if word_idx >= len(words):
+            break
+
+    # Second approach: Position-based matching for any remaining tokens
+    if len(token_to_word_map) < len(content_tokens):
+        print("Using position-based matching for remaining tokens")
+
+        # Assign unmapped tokens based on surrounding mapped tokens
+        for token_idx, token_text in content_tokens:
+            if token_idx not in token_to_word_map:
+                # Look for the nearest mapped token before this one
+                prev_idx = token_idx - 1
+                while prev_idx >= 0 and prev_idx not in token_to_word_map:
+                    prev_idx -= 1
+
+                # Look for the nearest mapped token after this one
+                next_idx = token_idx + 1
+                while next_idx < len(tokens) and next_idx not in token_to_word_map:
+                    next_idx += 1
+
+                # Assign to the closest mapped word
+                if prev_idx >= 0 and prev_idx in token_to_word_map:
+                    token_to_word_map[token_idx] = token_to_word_map[prev_idx]
+                    print(f"Position match: Token '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}' (based on previous)")
+                elif next_idx < len(tokens) and next_idx in token_to_word_map:
+                    token_to_word_map[token_idx] = token_to_word_map[next_idx]
+                    print(f"Position match: Token '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}' (based on next)")
+                elif word_idx > 0:
+                    # Fallback to the last word if no nearby tokens are mapped
+                    token_to_word_map[token_idx] = min(word_idx - 1, len(words) - 1)
+                    print(f"Fallback match: Token '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}'")
+
+    # Print the final mapping
+    print("Final token-to-word mapping:")
+    for token_idx, word_idx in sorted(token_to_word_map.items()):
+        token_text = next((t["text"] for i, t in enumerate(tokens) if i == token_idx), "")
+        if word_idx < len(words):
+            print(f"    Token '{token_text}' (idx {token_idx}) -> Word {word_idx} '{words[word_idx]}'")
+
+    return token_to_word_map
+
 # Helper function to load models on demand
 def get_model_and_tokenizer(model_name):
     if model_name not in MODEL_CONFIGS:
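
A worked sketch of what map_bert_tokens_to_words produces for a WordPiece-split word (the token dicts follow the {"text": ...} shape the helper expects; the exact pieces depend on the vocabulary):

    from helpers import map_bert_tokens_to_words

    tokens = [
        {"text": "[CLS]"},
        {"text": "un"}, {"text": "##break"}, {"text": "##able"},
        {"text": "glass"},
        {"text": "[SEP]"},
    ]
    mapping = map_bert_tokens_to_words(tokens, "unbreakable glass")
    # Continuation tokens inherit the word of the piece before them:
    # {1: 0, 2: 0, 3: 0, 4: 1} — tokens 1-3 map to "unbreakable", token 4 to "glass"
    print(mapping)
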
models.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, BertModel, RobertaModel
+from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, BertModel, RobertaModel, DistilBertForMaskedLM, DistilBertModel
 import nltk
 
 
@@ -24,6 +24,12 @@ MODEL_CONFIGS = {
         "model_class": RobertaForMaskedLM,
         "tokenizer_class": AutoTokenizer,
         "base_model_class": RobertaModel
+    },
+    "distilbert-base-uncased": {
+        "name": "DistilBERT Base Uncased",
+        "model_class": DistilBertForMaskedLM,
+        "tokenizer_class": AutoTokenizer,
+        "base_model_class": DistilBertModel
     }
 }
 
requirements.txt CHANGED
@@ -5,4 +5,5 @@ uvicorn>=0.21.1
 pydantic>=2.0.0
 python-multipart>=0.0.6
 numpy>=1.24.0
-nltk>=3.8.1
+nltk>=3.8.1
+networkx>=3.0
routes/attention.py CHANGED
@@ -2,13 +2,14 @@ from fastapi import APIRouter, HTTPException
 from classes import *
 from helpers import *
 from routes.tokenize import tokenize_text
+from attention_processing import process_attention_with_method
 router = APIRouter()
 
 @router.post("", response_model=AttentionResponse)
 async def get_attention_matrices(request: AttentionRequest):
     """Get attention matrices for the input text using the specified model"""
     try:
-        print(f"Processing attention request: text='{request.text}', model={request.model_name}")
+        print(f"Processing attention request: text='{request.text}', model={request.model_name}, method={request.visualization_method}")
 
         # First tokenize the text using the same function that the /tokenize endpoint uses
         # to ensure consistency
@@ -29,7 +30,7 @@ async def get_attention_matrices(request: AttentionRequest):
         base_model_key = f"{model_name}_base"
         if base_model_key not in models:
             print(f"Loading base model {model_name}...")
-            models[base_model_key] = base_model_class.from_pretrained(model_name)
+            models[base_model_key] = base_model_class.from_pretrained(model_name, attn_implementation="eager")
             if torch.cuda.is_available():
                 models[base_model_key] = models[base_model_key].cuda()
             models[base_model_key].eval()
@@ -46,9 +47,16 @@ async def get_attention_matrices(request: AttentionRequest):
                 return_tensors="pt",
                 return_attention_mask=True
             )
+
+            # Map RoBERTa tokens to words for better visualization
+            token_to_word_map = map_roberta_tokens_to_words(tokens, request.text)
         else:
+            # For BERT and DistilBERT
             text = f"[CLS] {request.text} [SEP]"
             encoding = tokenizer(text, return_tensors="pt")
+
+            # Map BERT/DistilBERT tokens to words for better visualization
+            token_to_word_map = map_bert_tokens_to_words(tokens, request.text)
 
         if torch.cuda.is_available():
             encoding = {k: v.cuda() for k, v in encoding.items()}
@@ -73,6 +81,10 @@ async def get_attention_matrices(request: AttentionRequest):
         attention_matrices = outputs.attentions
         print(f"Got attention matrices for {len(attention_matrices)} layers")
 
+        # Process attention using the specified method
+        if request.visualization_method != "raw":
+            print(f"Processing attention with method: {request.visualization_method}")
+
         # Convert attention matrices to the expected response format
         layers = []
         for layer_idx, layer_attention in enumerate(attention_matrices):
@@ -96,8 +108,23 @@ async def get_attention_matrices(request: AttentionRequest):
                 "layerIndex": layer_idx,
                 "heads": heads
             })
-
+
         print(f"Processed {len(layers)} layers with {num_heads} heads each")
+
+        # Process with selected visualization method
+        if request.visualization_method != "raw":
+            processed_layers = process_attention_with_method(
+                attention_matrices,
+                method=request.visualization_method,
+                debug=False
+            )
+            # Replace the layers with the processed ones
+            layers = processed_layers
+
+        # Add token-to-word mapping to the response
+        for i, token in enumerate(tokens):
+            if i in token_to_word_map:
+                token["wordIndex"] = token_to_word_map[i]
 
         # Return complete attention data
         attention_data = {
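
End to end, a client now selects the visualization method per request. A sketch using requests (the host, port, and /attention route prefix are assumptions — the router itself posts to its mount point; the response's "attention_data" key follows the usage in attention_comparison_helpers.py):

    import requests

    resp = requests.post(
        "http://localhost:8000/attention",  # prefix depends on how the router is mounted
        json={
            "text": "The cat sat on the mat.",
            "model_name": "distilbert-base-uncased",
            "visualization_method": "rollout",  # or "raw" / "flow"
        },
    )
    resp.raise_for_status()
    data = resp.json()["attention_data"]  # shape follows AttentionResponse in classes.py
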
routes/attention_comparison.py CHANGED
@@ -16,11 +16,14 @@ async def get_attention_comparison(request: ComparisonRequest):
     print(f"Masked index: {request.masked_index}")
     print(f"Replacement word: '{request.replacement_word}'")
     print(f"Model: {request.model_name}")
+    print(f"Visualization method: {request.visualization_method}")
 
     # Dispatch based on model type
     if "roberta" in request.model_name.lower():
         return await get_attention_comparison_roberta(request)
     else:
+        # Both BERT and DistilBERT use the same tokenization approach (WordPiece)
+        # and can use the same comparison implementation
         return await get_attention_comparison_bert(request)
 
 
routes/tokenize.py CHANGED
@@ -27,7 +27,7 @@ async def tokenize_text(request: TokenizeRequest):
         # Clean the tokens to remove the leading 'Ġ' character from RoBERTa tokens
         tokens = [clean_roberta_token(token) for token in tokens]
     else:
-        # For BERT, add special tokens and tokenize
+        # For BERT and DistilBERT, add special tokens and tokenize
         text = f"[CLS] {request.text} [SEP]"
         tokens = tokenizer.tokenize(text)
 
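
For reference, a sketch of what this BERT/DistilBERT branch yields (assumes transformers; the exact word pieces depend on the vocabulary):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    tokens = tokenizer.tokenize("[CLS] unbreakable glass [SEP]")
    # Special tokens survive intact; sub-word pieces carry the "##" prefix,
    # e.g. ['[CLS]', 'un', '##break', '##able', 'glass', '[SEP]']
    print(tokens)
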