import re
import traceback

from fastapi import HTTPException

from classes import *
from helpers import *
from routes.attention import get_attention_matrices
from routes.tokenize import tokenize_text
# Single-character tokens that are replaced directly in the raw text instead of
# being mapped to a whitespace-delimited word.
_PUNCTUATION_TOKENS = frozenset({".", ",", "!", "?", ":", ";", "-", "'", "\""})
# Trailing marks that a replaced word hands on to its replacement ("cat," -> "dog,").
_TRAILING_PUNCTUATION = (".", ",", "!", "?", ":", ";")
# Special tokens of the BERT / DistilBERT tokenizers.
_BERT_SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]"})
# Special tokens of the RoBERTa tokenizer (ignored when estimating a word position).
_ROBERTA_SPECIAL_TOKENS = frozenset({"<s>", "</s>", "<pad>"})
# A "word" is a maximal run of non-whitespace, i.e. one element of str.split().
_WORD_RE = re.compile(r"\S+")


async def _attention_data(text, request):
    """Return the "attention_data" for *text* using the model/visualization settings of *request*."""
    attention_request = AttentionRequest(
        text=text,
        model_name=request.model_name,
        visualization_method=request.visualization_method,
    )
    return (await get_attention_matrices(attention_request))["attention_data"]


def _selected_token_text(tokens, masked_index):
    """Return tokens[masked_index]["text"], raising HTTPException(400) if the index is out of range."""
    if masked_index < 0 or masked_index >= len(tokens):
        raise HTTPException(status_code=400, detail=f"Invalid token index {masked_index}. Valid range: 0-{len(tokens)-1}")
    return tokens[masked_index]["text"]


def _word_spans(text):
    """Return the (start, end) offsets of every whitespace-delimited word in *text*.

    The i-th span covers exactly the i-th element of ``text.split()``.
    Raises HTTPException(400) when *text* contains no words at all.
    """
    spans = [match.span() for match in _WORD_RE.finditer(text)]
    if not spans:
        raise HTTPException(status_code=400, detail="Text contains no words to replace")
    return spans


def _punctuation_offset(text, tokens, token_index, mark):
    """Return the offset in *text* of the occurrence of *mark* chosen by tokens[token_index].

    The occurrence is identified by counting how many earlier tokens are exactly
    *mark* (clamped to the last occurrence, in case tokenization merged some marks
    into larger tokens). Returns None when *mark* does not occur in *text*.
    """
    offsets = [pos for pos, char in enumerate(text) if char == mark]
    if not offsets:
        return None
    occurrence = sum(1 for token in tokens[:token_index] if token["text"] == mark)
    return offsets[min(occurrence, len(offsets) - 1)]


def _replace_word(text, spans, word_idx, replacement):
    """Replace the word at spans[word_idx] in *text*, keeping its trailing punctuation.

    Only that word's characters change; all surrounding whitespace (newlines,
    repeated spaces) is preserved so the "after" text differs from the "before"
    text in the replaced word only. The trailing mark is not duplicated when
    *replacement* already ends with it.

    Returns:
        (new_text, original_word, replaced_word)
    """
    start, end = spans[word_idx]
    original_word = text[start:end]
    suffix = next((mark for mark in _TRAILING_PUNCTUATION if original_word.endswith(mark)), "")
    replaced_word = replacement
    if suffix and not replaced_word.endswith(suffix):
        replaced_word += suffix
        print(f"Preserving punctuation: {replacement} → {replaced_word}")
    return text[:start] + replaced_word + text[end:], original_word, replaced_word


def _bert_token_spans(text, tokens):
    """Best-effort (start, end) offsets of each BERT token in *text*.

    Tokens are matched left to right, case-insensitively, with WordPiece "##"
    markers removed. Special tokens and tokens that cannot be found get None.
    """
    lowered = text.lower()  # computed once instead of per token
    spans = []
    cursor = 0
    for token in tokens:
        if token["text"] in _BERT_SPECIAL_TOKENS:
            spans.append(None)
            continue
        piece = token["text"].replace("##", "").lower()
        start = lowered.find(piece, cursor)
        if start == -1:
            # e.g. [UNK] or tokenizer normalization that changed the surface form
            spans.append(None)
            continue
        spans.append((start, start + len(piece)))
        cursor = start + len(piece)
    return spans


def _word_containing(word_spans, token_start, token_end):
    """Return the index of the word overlapping the token's start or end, or None."""
    for idx, (word_start, word_end) in enumerate(word_spans):
        if word_start <= token_start < word_end or word_start < token_end <= word_end:
            return idx
    return None


def _bert_word_replacement(text, tokens, masked_index, replacement):
    """Replace the word of *text* that contains BERT token *masked_index*; return the new text."""
    spans = _word_spans(text)
    print(f"Words: {[text[start:end] for start, end in spans]}")
    token_spans = _bert_token_spans(text, tokens)
    print(f"Token positions: {token_spans}")

    token_span = token_spans[masked_index]
    word_idx = _word_containing(spans, *token_span) if token_span is not None else None
    if word_idx is None:
        # Fallback: assume a single leading [CLS] and roughly one token per word.
        print("Could not map token to a specific word, using position-based fallback")
        word_idx = min(max(0, masked_index - 1), len(spans) - 1)
    else:
        print(f"Selected token maps to word {word_idx}: '{text[spans[word_idx][0]:spans[word_idx][1]]}'")

    new_text, original_word, replaced_word = _replace_word(text, spans, word_idx, replacement)
    print(f"Original word: '{original_word}'")
    print(f"Replacement: '{replaced_word}'")
    return new_text


def _roberta_word_index(text, spans, tokens, masked_index, selected_token):
    """Choose the index (into ``text.split()``) of the word RoBERTa token *masked_index* belongs to.

    Tries, in order: the project's ``map_roberta_tokens_to_words`` mapping, a
    substring match of the token against each word, and finally a position
    estimate from the number of non-special tokens before the selection.
    """
    print("Mapping selected token to a word:")
    token_to_word_map = map_roberta_tokens_to_words(tokens, text)
    if masked_index in token_to_word_map:
        word_idx = token_to_word_map[masked_index]
        if word_idx < 0 or word_idx >= len(spans):
            print(f"Warning: word_idx {word_idx} is out of bounds, using nearest valid index")
            word_idx = max(0, min(word_idx, len(spans) - 1))
        return word_idx

    print("Selected token not found in mapping, using fallback approach")
    # NOTE(review): strips a leading byte-level BPE space marker "Ġ" in case the
    # tokenizer route returns raw RoBERTa tokens — harmless if it does not.
    clean_token = selected_token.lstrip("\u0120").strip().lower()
    if clean_token:  # an empty token would "match" every word
        for idx, (start, end) in enumerate(spans):
            word = text[start:end]
            if clean_token in word.lower().rstrip(".,!?;:"):
                print(f"Direct match: token '{selected_token}' → word '{word}'")
                return idx

    print("No word match found, using position-based fallback")
    non_special_count = sum(
        1 for token in tokens[:masked_index] if token["text"] not in _ROBERTA_SPECIAL_TOKENS
    )
    return min(non_special_count, len(spans) - 1)


async def get_attention_comparison_bert(request: ComparisonRequest):
    """
    BERT and DistilBERT implementation of attention comparison.

    Replaces the token at ``request.masked_index`` with ``request.replacement_word``
    and returns attention data for the original and the modified text.

    Punctuation tokens are replaced character-for-character in the raw text; any
    other token causes the whole whitespace-delimited word containing it to be
    replaced (keeping the word's trailing punctuation).

    Returns:
        {"before_attention": ..., "after_attention": ...}

    Raises:
        HTTPException(400): the index is out of range or the text has no words.
        HTTPException(500): any other failure.
    """
    model_type = "BERT"  # defined before the try so the error handler can always use it
    try:
        model_type = "DistilBERT" if "distilbert" in request.model_name.lower() else "BERT"
        print(f"\n=== USING {model_type} ATTENTION COMPARISON IMPLEMENTATION ===")
        original_text = request.text

        tokenizer_response = await tokenize_text(TokenizeRequest(text=original_text, model_name=request.model_name))
        tokens = tokenizer_response["tokens"]
        print(f"\nTokens ({len(tokens)}):")
        for i, token in enumerate(tokens):
            print(f" {i}: '{token['text']}'")

        selected_token = _selected_token_text(tokens, request.masked_index)
        print(f"\nSelected token at index {request.masked_index}: '{selected_token}'")

        # Input is validated; only now run the (expensive) model on the original text.
        before_data = await _attention_data(original_text, request)

        replaced_text = None
        is_punctuation = selected_token in _PUNCTUATION_TOKENS
        print(f"Is punctuation token: {is_punctuation}")
        if is_punctuation:
            print(f"\nUsing {model_type} punctuation replacement approach")
            offset = _punctuation_offset(original_text, tokens, request.masked_index, selected_token)
            if offset is None:
                print(f"Warning: Could not find punctuation '{selected_token}' in text, using fallback")
            else:
                print(f"Selected punctuation '{selected_token}' at position {offset}")
                replaced_text = original_text[:offset] + request.replacement_word + original_text[offset + 1:]

        if replaced_text is None:
            print(f"\nUsing {model_type} word replacement approach")
            replaced_text = _bert_word_replacement(
                original_text, tokens, request.masked_index, request.replacement_word
            )

        print(f"Original text: '{original_text}'")
        print(f"Replaced text: '{replaced_text}'")
        after_data = await _attention_data(replaced_text, request)
        return {"before_attention": before_data, "after_attention": after_data}
    except HTTPException:
        # Keep deliberate client errors (e.g. 400 for a bad index) as they are.
        raise
    except Exception as e:
        print(f"{model_type} Attention comparison error: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e


async def get_attention_comparison_roberta(request: ComparisonRequest):
    """
    RoBERTa-specific implementation of attention comparison.

    Replaces the token at ``request.masked_index`` with ``request.replacement_word``
    and returns attention data for the original and the modified text.

    Punctuation tokens are replaced character-for-character in the raw text; any
    other token is mapped to a whitespace-delimited word, which is replaced while
    keeping its trailing punctuation.

    Returns:
        {"before_attention": ..., "after_attention": ...}

    Raises:
        HTTPException(400): the index is out of range or the text has no words.
        HTTPException(500): any other failure.
    """
    try:
        print(f"\n=== ROBERTA ATTENTION COMPARISON ===")
        print(f"Text: '{request.text}'")
        print(f"Selected token index: {request.masked_index}")
        print(f"Replacement word: '{request.replacement_word}'")
        original_text = request.text

        tokenizer_response = await tokenize_text(TokenizeRequest(text=original_text, model_name=request.model_name))
        tokens = tokenizer_response["tokens"]
        print("\nTokens:")
        for i, token in enumerate(tokens):
            print(f" {i}: '{token['text']}'")

        selected_token = _selected_token_text(tokens, request.masked_index)
        print(f"Selected token: '{selected_token}' at index {request.masked_index}")

        # Input is validated; only now run the (expensive) model on the original text.
        before_data = await _attention_data(original_text, request)

        replaced_text = None
        if selected_token in _PUNCTUATION_TOKENS:
            print(f"Handling punctuation token: '{selected_token}'")
            offset = _punctuation_offset(original_text, tokens, request.masked_index, selected_token)
            if offset is not None:
                print(f"Replacing punctuation at position {offset}")
                replaced_text = original_text[:offset] + request.replacement_word + original_text[offset + 1:]

        if replaced_text is None:
            spans = _word_spans(original_text)
            word_idx = _roberta_word_index(original_text, spans, tokens, request.masked_index, selected_token)
            replaced_text, original_word, replaced_word = _replace_word(
                original_text, spans, word_idx, request.replacement_word
            )
            print(f"Replacing '{original_word}' with '{replaced_word}'")

        print(f"Replaced text: '{replaced_text}'")
        after_data = await _attention_data(replaced_text, request)
        return {"before_attention": before_data, "after_attention": after_data}
    except HTTPException:
        # Keep deliberate client errors (e.g. 400 for a bad index) as they are.
        raise
    except Exception as e:
        print(f"RoBERTa Attention comparison error: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e