Upload folder using huggingface_hub
- attention_comparison_helpers.py +45 -15
- attention_processing.py +240 -0
- classes.py +2 -0
- helpers.py +89 -0
- models.py +7 -1
- requirements.txt +2 -1
- routes/attention.py +30 -3
- routes/attention_comparison.py +3 -0
- routes/tokenize.py +1 -1
attention_comparison_helpers.py
CHANGED
@@ -8,13 +8,18 @@ from routes.tokenize import tokenize_text
 
 async def get_attention_comparison_bert(request: ComparisonRequest):
     """
-    BERT
+    BERT and DistilBERT implementation of attention comparison
     """
     try:
-
+        model_type = "DistilBERT" if "distilbert" in request.model_name.lower() else "BERT"
+        print(f"\n=== USING {model_type} ATTENTION COMPARISON IMPLEMENTATION ===")
 
         # 1. Get the "before" attention data
-        before_attention_request = AttentionRequest(
+        before_attention_request = AttentionRequest(
+            text=request.text,
+            model_name=request.model_name,
+            visualization_method=request.visualization_method
+        )
         before_data = (await get_attention_matrices(before_attention_request))["attention_data"]
 
         # 2. Tokenize the text
@@ -46,7 +51,7 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
 
         # HANDLE PUNCTUATION
         if is_punctuation:
-            print(f"\nUsing
+            print(f"\nUsing {model_type} punctuation replacement approach")
 
             # Find all occurrences of this punctuation in the original text
             punctuation_positions = [pos for pos, char in enumerate(original_text) if char == selected_token]
@@ -61,8 +66,9 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
             # We'll use a heuristic based on the token's position
 
             # Count non-special tokens before our selected token
+            special_tokens = ["[CLS]", "[SEP]"]  # Same special tokens for BERT and DistilBERT
             non_special_tokens_before = sum(1 for t in tokens[:request.masked_index]
-                                            if t["text"] not in
+                                            if t["text"] not in special_tokens)
 
             # Select the corresponding punctuation position (or last one if out of range)
             punct_idx = min(non_special_tokens_before, len(punctuation_positions) - 1)
@@ -76,14 +82,18 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
             print(f"Replaced text: '{replaced_text}'")
 
             # Get the after attention data
-            after_attention_request = AttentionRequest(
+            after_attention_request = AttentionRequest(
+                text=replaced_text,
+                model_name=request.model_name,
+                visualization_method=request.visualization_method
+            )
             after_data = (await get_attention_matrices(after_attention_request))["attention_data"]
 
             # Return comparison data
             return {"before_attention": before_data, "after_attention": after_data}
 
-        # HANDLE REGULAR WORDS FOR BERT
-        print(f"\nUsing
+        # HANDLE REGULAR WORDS FOR BERT/DistilBERT
+        print(f"\nUsing {model_type} word replacement approach")
         words = original_text.split()
         print(f"Words: {words}")
 
@@ -193,7 +203,7 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
             print(f"Could not determine token position, using simple word replacement")
             words = original_text.split()
 
-            # Adjust for special tokens in BERT ([CLS])
+            # Adjust for special tokens in BERT/DistilBERT ([CLS])
             adjusted_index = max(0, request.masked_index - 1)
             word_idx = min(adjusted_index, len(words) - 1)
 
@@ -218,14 +228,18 @@ async def get_attention_comparison_bert(request: ComparisonRequest):
         print(f"Replaced text: '{replaced_text}'")
 
         # Get the after attention data
-        after_attention_request = AttentionRequest(
+        after_attention_request = AttentionRequest(
+            text=replaced_text,
+            model_name=request.model_name,
+            visualization_method=request.visualization_method
+        )
         after_data = (await get_attention_matrices(after_attention_request))["attention_data"]
 
         # Return comparison data
         return {"before_attention": before_data, "after_attention": after_data}
 
     except Exception as e:
-        print(f"
+        print(f"{model_type} Attention comparison error: {str(e)}")
         import traceback
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=str(e))
@@ -243,7 +257,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
         print(f"Replacement word: '{request.replacement_word}'")
 
         # Get the "before" attention data
-        before_attention_request = AttentionRequest(
+        before_attention_request = AttentionRequest(
+            text=request.text,
+            model_name=request.model_name,
+            visualization_method=request.visualization_method
+        )
         before_data = (await get_attention_matrices(before_attention_request))["attention_data"]
 
         # Tokenize the text
@@ -302,7 +320,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
             print(f"Replaced text: '{replaced_text}'")
 
             # Get the "after" attention data and return
-            after_request = AttentionRequest(
+            after_request = AttentionRequest(
+                text=replaced_text,
+                model_name=request.model_name,
+                visualization_method=request.visualization_method
+            )
             after_data = (await get_attention_matrices(after_request))["attention_data"]
             return {"before_attention": before_data, "after_attention": after_data}
 
@@ -341,7 +363,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
             print(f"Replaced text: '{replaced_text}'")
 
             # Get the "after" attention data
-            after_request = AttentionRequest(
+            after_request = AttentionRequest(
+                text=replaced_text,
+                model_name=request.model_name,
+                visualization_method=request.visualization_method
+            )
             after_data = (await get_attention_matrices(after_request))["attention_data"]
 
             return {"before_attention": before_data, "after_attention": after_data}
@@ -411,7 +437,11 @@ async def get_attention_comparison_roberta(request: ComparisonRequest):
             print(f"Replaced text: '{replaced_text}'")
 
             # Get the "after" attention data
-            after_request = AttentionRequest(
+            after_request = AttentionRequest(
+                text=replaced_text,
+                model_name=request.model_name,
+                visualization_method=request.visualization_method
+            )
             after_data = (await get_attention_matrices(after_request))["attention_data"]
 
             return {"before_attention": before_data, "after_attention": after_data}
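Note: both comparison helpers now thread visualization_method into the before/after AttentionRequest objects. A minimal sketch (not part of the commit) of driving the BERT path directly, with illustrative field values; it assumes ComparisonRequest is importable from classes and that the checkpoint downloads on first use:

import asyncio
from classes import ComparisonRequest
from attention_comparison_helpers import get_attention_comparison_bert

request = ComparisonRequest(
    text="the cat sat on the mat",       # illustrative input
    masked_index=3,                       # illustrative: position of "sat" counting [CLS] as 0
    replacement_word="slept",
    model_name="bert-base-uncased",
    visualization_method="rollout",       # "raw", "rollout", or "flow"
)
result = asyncio.run(get_attention_comparison_bert(request))
print(result.keys())  # expect "before_attention" and "after_attention"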
attention_processing.py
ADDED
@@ -0,0 +1,240 @@
+import torch
+import numpy as np
+import networkx as nx
+from typing import List, Dict, Any, Optional, Tuple
+
+#############################################
+# Attention Rollout Calculation Function
+#############################################
+def compute_attention_rollout(attentions, add_identity: bool = True, debug: bool = False):
+    """
+    Compute attention rollout from raw attention matrices
+
+    Args:
+        attentions: List of attention tensors from the model
+        add_identity: Whether to add identity matrix to each attention layer
+        debug: Whether to print debug information
+
+    Returns:
+        The attention rollout matrix
+    """
+    num_layers = len(attentions)
+    seq_len = attentions[0].size(-1)
+    rollout = torch.eye(seq_len)
+    for i, att in enumerate(attentions):
+        att_avg = att.squeeze(0).mean(dim=0)
+        if add_identity:
+            att_aug = att_avg + torch.eye(seq_len)
+        else:
+            att_aug = att_avg
+        att_aug = att_aug / (att_aug.sum(dim=-1, keepdim=True) + 1e-8)
+        att_aug = torch.nan_to_num(att_aug, nan=0.0, posinf=0.0, neginf=0.0)
+        rollout = rollout @ att_aug
+        if debug:
+            print(f"[DEBUG] Rollout after layer {i+1}:")
+            print(att_aug)
+            print(rollout)
+
+    # Normalize rollout to ensure it sums to 1.0 exactly
+    rollout_sum = rollout.sum(dim=-1, keepdim=True)
+    # Handle zero sums to avoid division by zero
+    is_zero_sum = (rollout_sum == 0)
+    if is_zero_sum.any():
+        # For rows with zero sum, distribute evenly
+        seq_len = rollout.size(-1)
+        even_dist = torch.ones_like(rollout) / seq_len
+        rollout = torch.where(is_zero_sum, even_dist, rollout / rollout_sum)
+    else:
+        rollout = rollout / rollout_sum
+
+    return rollout
+
+#############################################
+# Capacity Graph Construction Function (for Flow)
+#############################################
+def build_graph(joint_attentions, input_tokens, remove_diag: bool = False, debug: bool = False):
+    """
+    Build a graph representation for attention flow computation
+
+    Args:
+        joint_attentions: Joint attention matrices
+        input_tokens: List of input token text
+        remove_diag: Whether to remove diagonal elements (self-attention)
+        debug: Whether to print debug information
+
+    Returns:
+        Tuple of (capacity matrix, node labels dictionary)
+    """
+    n_layers, seq_len, _ = joint_attentions.shape
+    total_nodes = (n_layers + 1) * seq_len
+    capacity = np.zeros((total_nodes, total_nodes))
+    labels = {}
+    for k in range(seq_len):
+        labels[k] = f"0_{k}_{input_tokens[k]}"
+    for i in range(1, n_layers + 1):
+        for k_to in range(seq_len):
+            node_to = i * seq_len + k_to
+            labels[node_to] = f"L{i}_{k_to}"
+            for k_from in range(seq_len):
+                if remove_diag and (k_from == k_to):
+                    continue
+                node_from = (i - 1) * seq_len + k_from
+                cap = joint_attentions[i - 1][k_from][k_to]
+                capacity[node_from][node_to] = cap
+                if debug:
+                    print(f"[DEBUG] Edge from {labels[node_from]} to {labels[node_to]} with capacity: {cap:.6f}")
+    return capacity, labels
+
+#############################################
+# Attention Flow Calculation Function (using networkx)
+#############################################
+def compute_attention_flow_networkx(attentions, add_identity: bool = True, debug: bool = False, mask_idx=None):
+    """
+    Compute attention flow using networkx max flow algorithm
+
+    Args:
+        attentions: List of attention tensors from the model
+        add_identity: Whether to add identity matrix to each attention layer
+        debug: Whether to print debug information
+        mask_idx: Index of token to compute flow from (if None, computes flow for all tokens)
+
+    Returns:
+        Flow matrix or vector depending on mask_idx
+    """
+    num_layers = len(attentions)
+    seq_len = attentions[0].size(-1)
+    joint_attentions = []
+    for att in attentions:
+        att_avg = att.squeeze(0).mean(dim=0)
+        if add_identity:
+            alpha = 0.5
+            att_aug = (att_avg + torch.eye(seq_len)) * alpha
+        else:
+            att_aug = att_avg
+        att_aug = att_aug / (att_aug.sum(dim=-1, keepdim=True) + 1e-8)
+        joint_attentions.append(att_aug.cpu().numpy())
+    joint_attentions = np.stack(joint_attentions, axis=0)
+    input_tokens = [str(i) for i in range(seq_len)]
+    capacity, labels = build_graph(joint_attentions, input_tokens, remove_diag=False, debug=debug)
+    total_nodes = capacity.shape[0]
+    G = nx.DiGraph()
+    for u in range(total_nodes):
+        for v in range(total_nodes):
+            cap = capacity[u][v]
+            if cap > 1e-8:
+                G.add_edge(u, v, capacity=float(cap))
+    if mask_idx is not None:
+        source = mask_idx
+        flow_vector = np.zeros(seq_len)
+        if debug:
+            print(f"[DEBUG] Networkx max flow from {labels[source]} to each output node:")
+        for sink in range(num_layers * seq_len, (num_layers + 1) * seq_len):
+            try:
+                flow_value, _ = nx.maximum_flow(G, source, sink, capacity='capacity')
+            except Exception:
+                flow_value = 0
+            flow_vector[sink - num_layers * seq_len] = flow_value
+        flow_vector = flow_vector / (flow_vector.sum() + 1e-8)
+        return flow_vector.reshape(1, seq_len)
+    else:
+        flow_matrix = np.zeros((seq_len, seq_len))
+        if debug:
+            print("[DEBUG] Networkx max flow for each input node to each output node:")
+        for i in range(seq_len):
+            source = i
+            flow_vector = np.zeros(seq_len)
+            for sink in range(num_layers * seq_len, (num_layers + 1) * seq_len):
+                try:
+                    flow_value, _ = nx.maximum_flow(G, source, sink, capacity='capacity')
+                except Exception:
+                    flow_value = 0
+                flow_vector[sink - num_layers * seq_len] = flow_value
+            # Normalize flow vector to ensure it sums to 1.0 exactly
+            flow_sum = flow_vector.sum()
+            if flow_sum > 0:
+                flow_vector = flow_vector / flow_sum
+            else:
+                # If there is no flow, distribute evenly
+                flow_vector = np.ones(seq_len) / seq_len
+            flow_matrix[i] = flow_vector
+        if debug:
+            print("[DEBUG] Final networkx flow matrix:")
+            print(flow_matrix)
+        return flow_matrix
+
+#############################################
+# Process Attention with Selected Method
+#############################################
+def process_attention_with_method(attention_matrices, method: str = "raw", debug: bool = False) -> List[Dict[str, Any]]:
+    """
+    Process attention matrices using the specified method
+
+    Args:
+        attention_matrices: List of attention tensors from the model
+        method: Method to use (raw, rollout, flow)
+        debug: Whether to print debug information
+
+    Returns:
+        Processed attention data in the same format as the original attention data
+    """
+    if method == "raw":
+        # Return raw attention as is
+        return attention_matrices
+
+    # Convert to list if it's a tuple
+    if isinstance(attention_matrices, tuple):
+        attention_matrices = list(attention_matrices)
+
+    # Get dimensions
+    num_layers = len(attention_matrices)
+
+    if method == "rollout":
+        # Calculate rollout attention
+        rollout_matrix = compute_attention_rollout(attention_matrices, add_identity=True, debug=debug)
+
+        # Create new attention matrices with rollout values
+        new_attention_matrices = []
+        for layer_idx in range(num_layers):
+            # For each layer, we'll use the same rollout matrix for all heads
+            heads = attention_matrices[layer_idx].shape[1]  # Number of heads
+            layer_data = {"layerIndex": layer_idx, "heads": []}
+
+            for head_idx in range(heads):
+                # Convert rollout matrix to list format for this head
+                attention_matrix = rollout_matrix.cpu().tolist()
+
+                layer_data["heads"].append({
+                    "headIndex": head_idx,
+                    "attention": attention_matrix
+                })
+
+            new_attention_matrices.append(layer_data)
+
+        return new_attention_matrices
+
+    elif method == "flow":
+        # Calculate flow attention
+        flow_matrix = compute_attention_flow_networkx(attention_matrices, add_identity=True, debug=debug)
+
+        # Create new attention matrices with flow values
+        new_attention_matrices = []
+        for layer_idx in range(num_layers):
+            # For each layer, we'll use the same flow matrix for all heads
+            heads = attention_matrices[layer_idx].shape[1]  # Number of heads
+            layer_data = {"layerIndex": layer_idx, "heads": []}
+
+            for head_idx in range(heads):
+                # Convert flow matrix to list format for this head
+                attention_matrix = flow_matrix.tolist()
+
+                layer_data["heads"].append({
+                    "headIndex": head_idx,
+                    "attention": attention_matrix
+                })
+
+            new_attention_matrices.append(layer_data)
+
+        return new_attention_matrices
+
+    else:
+        raise ValueError(f"Unknown attention processing method: {method}")
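The new module operates on the tuple of per-layer attention tensors that Hugging Face models return with output_attentions=True. A hedged usage sketch, assuming bert-base-uncased and an arbitrary sentence (neither is prescribed by this commit):

import torch
from transformers import AutoTokenizer, BertModel
from attention_processing import compute_attention_rollout, process_attention_with_method

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
model.eval()

encoding = tokenizer("the cat sat on the mat", return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoding, output_attentions=True)

# outputs.attentions is a tuple of [1, heads, seq, seq] tensors, one per layer
rollout = compute_attention_rollout(outputs.attentions)                 # [seq, seq] tensor
layers = process_attention_with_method(outputs.attentions, method="rollout")
print(rollout.shape, len(layers))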
classes.py
CHANGED
@@ -32,6 +32,7 @@ class MaskPredictionResponse(BaseModel):
 class AttentionRequest(BaseModel):
     text: str
     model_name: str = "bert-base-uncased"
+    visualization_method: str = "raw"  # Options: "raw", "rollout", "flow"
 
 class AttentionHead(BaseModel):
     headIndex: int
@@ -53,6 +54,7 @@ class ComparisonRequest(BaseModel):
     masked_index: int
     replacement_word: str
     model_name: str = "bert-base-uncased"
+    visualization_method: str = "raw"  # Options: "raw", "rollout", "flow"
 
 class AttentionComparisonResponse(BaseModel):
     before_attention: AttentionData
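The added visualization_method field defaults to "raw", so existing clients keep their current behavior; any other value opts into the rollout/flow post-processing. An illustrative construction (field values are made up, not from the commit):

from classes import AttentionRequest

req = AttentionRequest(
    text="the cat sat on the mat",          # illustrative input
    model_name="distilbert-base-uncased",
    visualization_method="rollout",          # "raw", "rollout", or "flow"
)
print(req.model_dump())  # pydantic v2 serialization, per requirements.txt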
helpers.py
CHANGED
@@ -112,6 +112,95 @@ def map_roberta_tokens_to_words(tokens, original_text):
 
     return token_to_word_map
 
+# Helper function to map BERT and DistilBERT tokens to word positions
+def map_bert_tokens_to_words(tokens, original_text):
+    """
+    Maps BERT/DistilBERT tokens to words in the original text.
+    Returns a dictionary mapping token indices to word indices.
+    """
+    # Get the words from the original text
+    words = original_text.split()
+    print(f"Original words: {words}")
+
+    # Filter out special tokens
+    content_tokens = []
+    for i, token in enumerate(tokens):
+        if token["text"] not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"]:
+            content_tokens.append((i, token["text"]))
+
+    print(f"Content tokens: {[t for _, t in content_tokens]}")
+
+    # Create the mapping
+    token_to_word_map = {}
+
+    # First approach: direct matching of tokens to words, handling WordPiece tokens
+    word_idx = 0
+    for token_idx, token_text in content_tokens:
+        clean_token = token_text.lower().strip("##")
+
+        # Check if this is a continuation token (starting with ##)
+        if token_text.startswith("##"):
+            # If it's a continuation, map it to the same word as the previous token
+            if token_idx > 0 and (token_idx - 1) in token_to_word_map:
+                token_to_word_map[token_idx] = token_to_word_map[token_idx - 1]
+                print(f"Continuation token: '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}'")
+            continue
+
+        # Try to find a match with words
+        while word_idx < len(words):
+            word_lower = words[word_idx].lower()
+            if clean_token in word_lower:
+                token_to_word_map[token_idx] = word_idx
+                print(f"Match: Token '{token_text}' -> Word '{words[word_idx]}' at index {word_idx}")
+                # Only advance to next word if this token is a complete word
+                if clean_token == word_lower:
+                    word_idx += 1
+                break
+            else:
+                word_idx += 1
+
+        # If we've gone through all words, break
+        if word_idx >= len(words):
+            break
+
+    # Second approach: Position-based matching for any remaining tokens
+    if len(token_to_word_map) < len(content_tokens):
+        print("Using position-based matching for remaining tokens")
+
+        # Assign unmapped tokens based on surrounding mapped tokens
+        for token_idx, token_text in content_tokens:
+            if token_idx not in token_to_word_map:
+                # Look for the nearest mapped token before this one
+                prev_idx = token_idx - 1
+                while prev_idx >= 0 and prev_idx not in token_to_word_map:
+                    prev_idx -= 1
+
+                # Look for the nearest mapped token after this one
+                next_idx = token_idx + 1
+                while next_idx < len(tokens) and next_idx not in token_to_word_map:
+                    next_idx += 1
+
+                # Assign to the closest mapped word
+                if prev_idx >= 0 and prev_idx in token_to_word_map:
+                    token_to_word_map[token_idx] = token_to_word_map[prev_idx]
+                    print(f"Position match: Token '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}' (based on previous)")
+                elif next_idx < len(tokens) and next_idx in token_to_word_map:
+                    token_to_word_map[token_idx] = token_to_word_map[next_idx]
+                    print(f"Position match: Token '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}' (based on next)")
+                elif word_idx > 0:
+                    # Fallback to the last word if no nearby tokens are mapped
+                    token_to_word_map[token_idx] = min(word_idx - 1, len(words) - 1)
+                    print(f"Fallback match: Token '{token_text}' -> Word '{words[token_to_word_map[token_idx]]}'")
+
+    # Print the final mapping
+    print("Final token-to-word mapping:")
+    for token_idx, word_idx in sorted(token_to_word_map.items()):
+        token_text = next((t["text"] for i, t in enumerate(tokens) if i == token_idx), "")
+        if word_idx < len(words):
+            print(f"  Token '{token_text}' (idx {token_idx}) -> Word {word_idx} '{words[word_idx]}'")
+
+    return token_to_word_map
+
 # Helper function to load models on demand
 def get_model_and_tokenizer(model_name):
     if model_name not in MODEL_CONFIGS:
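A small, hypothetical check of the new helper, assuming the token dicts carry a "text" key as they do elsewhere in this codebase; for "playing chess", WordPiece typically yields play / ##ing / chess, and the ##ing piece should map back to word 0:

from helpers import map_bert_tokens_to_words

tokens = [{"text": t} for t in ["[CLS]", "play", "##ing", "chess", "[SEP]"]]
mapping = map_bert_tokens_to_words(tokens, "playing chess")
# Expected result: token index -> word index, i.e. {1: 0, 2: 0, 3: 1}
print(mapping)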
models.py
CHANGED
@@ -1,4 +1,4 @@
-from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, BertModel, RobertaModel
+from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, BertModel, RobertaModel, DistilBertForMaskedLM, DistilBertModel
 import nltk
 
 
@@ -24,6 +24,12 @@ MODEL_CONFIGS = {
         "model_class": RobertaForMaskedLM,
         "tokenizer_class": AutoTokenizer,
         "base_model_class": RobertaModel
+    },
+    "distilbert-base-uncased": {
+        "name": "DistilBERT Base Uncased",
+        "model_class": DistilBertForMaskedLM,
+        "tokenizer_class": AutoTokenizer,
+        "base_model_class": DistilBertModel
     }
 }
 
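A quick sketch of resolving the new DistilBERT entry the way the existing loaders do; it assumes MODEL_CONFIGS is importable from models, a transformers version that accepts attn_implementation, and that the checkpoint downloads on first use:

from models import MODEL_CONFIGS

cfg = MODEL_CONFIGS["distilbert-base-uncased"]
tokenizer = cfg["tokenizer_class"].from_pretrained("distilbert-base-uncased")
# "eager" matches the change in routes/attention.py so that attentions are returned
base_model = cfg["base_model_class"].from_pretrained("distilbert-base-uncased", attn_implementation="eager")
print(cfg["name"], type(tokenizer).__name__, type(base_model).__name__)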
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ uvicorn>=0.21.1
 pydantic>=2.0.0
 python-multipart>=0.0.6
 numpy>=1.24.0
-nltk>=3.8.1
+nltk>=3.8.1
+networkx>=3.0
routes/attention.py
CHANGED
@@ -2,13 +2,14 @@ from fastapi import APIRouter, HTTPException
 from classes import *
 from helpers import *
 from routes.tokenize import tokenize_text
+from attention_processing import process_attention_with_method
 router = APIRouter()
 
 @router.post("", response_model=AttentionResponse)
 async def get_attention_matrices(request: AttentionRequest):
     """Get attention matrices for the input text using the specified model"""
     try:
-        print(f"Processing attention request: text='{request.text}', model={request.model_name}")
+        print(f"Processing attention request: text='{request.text}', model={request.model_name}, method={request.visualization_method}")
 
         # First tokenize the text using the same function that the /tokenize endpoint uses
         # to ensure consistency
@@ -29,7 +30,7 @@ async def get_attention_matrices(request: AttentionRequest):
         base_model_key = f"{model_name}_base"
         if base_model_key not in models:
             print(f"Loading base model {model_name}...")
-            models[base_model_key] = base_model_class.from_pretrained(model_name)
+            models[base_model_key] = base_model_class.from_pretrained(model_name, attn_implementation="eager")
             if torch.cuda.is_available():
                 models[base_model_key] = models[base_model_key].cuda()
             models[base_model_key].eval()
@@ -46,9 +47,16 @@ async def get_attention_matrices(request: AttentionRequest):
                 return_tensors="pt",
                 return_attention_mask=True
             )
+
+            # Map RoBERTa tokens to words for better visualization
+            token_to_word_map = map_roberta_tokens_to_words(tokens, request.text)
         else:
+            # For BERT and DistilBERT
             text = f"[CLS] {request.text} [SEP]"
             encoding = tokenizer(text, return_tensors="pt")
+
+            # Map BERT/DistilBERT tokens to words for better visualization
+            token_to_word_map = map_bert_tokens_to_words(tokens, request.text)
 
         if torch.cuda.is_available():
             encoding = {k: v.cuda() for k, v in encoding.items()}
@@ -73,6 +81,10 @@ async def get_attention_matrices(request: AttentionRequest):
         attention_matrices = outputs.attentions
         print(f"Got attention matrices for {len(attention_matrices)} layers")
 
+        # Process attention using the specified method
+        if request.visualization_method != "raw":
+            print(f"Processing attention with method: {request.visualization_method}")
+
         # Convert attention matrices to the expected response format
         layers = []
         for layer_idx, layer_attention in enumerate(attention_matrices):
@@ -96,8 +108,23 @@ async def get_attention_matrices(request: AttentionRequest):
                 "layerIndex": layer_idx,
                 "heads": heads
             })
-
+
         print(f"Processed {len(layers)} layers with {num_heads} heads each")
+
+        # Process with selected visualization method
+        if request.visualization_method != "raw":
+            processed_layers = process_attention_with_method(
+                attention_matrices,
+                method=request.visualization_method,
+                debug=False
+            )
+            # Replace the layers with the processed ones
+            layers = processed_layers
+
+        # Add token-to-word mapping to the response
+        for i, token in enumerate(tokens):
+            if i in token_to_word_map:
+                token["wordIndex"] = token_to_word_map[i]
 
         # Return complete attention data
         attention_data = {
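For manual testing, the endpoint can be hit with a JSON body that now carries visualization_method. A sketch using the requests package; the /attention mount point and local port 8000 are assumptions, not something this diff specifies:

import requests  # assumed to be available in the testing environment

payload = {
    "text": "the quick brown fox",              # illustrative input
    "model_name": "distilbert-base-uncased",
    "visualization_method": "flow",              # or "raw" / "rollout"
}
resp = requests.post("http://localhost:8000/attention", json=payload)  # assumed mount point
print(resp.status_code)
# the route returns its payload under "attention_data", as the comparison helpers rely on
print(resp.json()["attention_data"].keys())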
routes/attention_comparison.py
CHANGED
@@ -16,11 +16,14 @@ async def get_attention_comparison(request: ComparisonRequest):
     print(f"Masked index: {request.masked_index}")
     print(f"Replacement word: '{request.replacement_word}'")
     print(f"Model: {request.model_name}")
+    print(f"Visualization method: {request.visualization_method}")
 
     # Dispatch based on model type
     if "roberta" in request.model_name.lower():
         return await get_attention_comparison_roberta(request)
     else:
+        # Both BERT and DistilBERT use the same tokenization approach (WordPiece)
+        # and can use the same comparison implementation
         return await get_attention_comparison_bert(request)
 
 
routes/tokenize.py
CHANGED
@@ -27,7 +27,7 @@ async def tokenize_text(request: TokenizeRequest):
         # Clean the tokens to remove the leading 'Ġ' character from RoBERTa tokens
         tokens = [clean_roberta_token(token) for token in tokens]
     else:
-        # For BERT, add special tokens and tokenize
+        # For BERT and DistilBERT, add special tokens and tokenize
         text = f"[CLS] {request.text} [SEP]"
         tokens = tokenizer.tokenize(text)
 