gary-boon and Claude committed
Commit 8f63685 · 1 Parent(s): 37ed739
Add research attention analysis endpoint with real CodeGen tokenization
- Implement /analyze/research/attention endpoint
- Extract real token IDs from CodeGen tokenizer
- Track attention weights across all 20 layers per generation step
- Return top-k token alternatives with probabilities
- Store per-step attention data for token-by-token analysis
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
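
For reference, a minimal client-side sketch of how the new endpoint might be called. The path, the port (uvicorn runs on 8000 at the bottom of the file), and the response keys (promptTokens, generatedTokens, tokenAlternatives, layersDataByStep, layersData) come from this commit; the auth header name and the exact request schema are assumptions, since verify_api_key and the request parsing are not visible in this diff.

import requests

# Hypothetical call to the new research attention endpoint.
# Paths and response keys are from the diff; header name and request fields are assumptions.
resp = requests.post(
    "http://localhost:8000/analyze/research/attention",  # assumed local deployment
    headers={"X-API-Key": "<key>"},                       # placeholder; use whatever verify_api_key expects
    json={"prompt": "def add(a, b):"},                    # at minimum a prompt; other fields not shown here
)
data = resp.json()
print(len(data["promptTokens"]), len(data["generatedTokens"]))  # real token IDs are under "idx"
print(data["tokenAlternatives"][0])     # top-k alternatives for the first generation step
print(len(data["layersDataByStep"]))    # per-step attention data across all layers
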
backend/model_service.py  +232 -4  CHANGED
@@ -1526,10 +1526,10 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     # Build response
     response = {
         "prompt": prompt,
-        "promptTokens": [{"text": t, "idx":
-                          for
-        "generatedTokens": [{"text": t, "idx":
-                          for
+        "promptTokens": [{"text": t, "idx": tid, "bytes": len(t.encode('utf-8')), "type": "prompt"}
+                         for tid, t in zip(prompt_token_ids, prompt_tokens)],
+        "generatedTokens": [{"text": t, "idx": tid, "bytes": len(t.encode('utf-8')), "type": "generated"}
+                            for tid, t in zip(generated_token_ids, generated_tokens)],
         "tokenAlternatives": token_alternatives_by_step,  # Top-k alternatives for each token
         "layersDataByStep": layer_data_by_token,  # Layer data for ALL generation steps
         "layersData": layer_data_by_token[-1] if layer_data_by_token else [],  # Keep for backward compatibility
@@ -2189,6 +2189,234 @@ async def get_swe_bench_comparison(
 
     return comparison
 
+# ==============================================================================
+# VOCABULARY & TOKENIZATION ENDPOINTS
+# ==============================================================================
+
+@app.post("/vocabulary/search")
+async def search_vocabulary(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Search vocabulary by query string"""
+    query = request.get("query", "").lower()
+    limit = request.get("limit", 50)
+
+    if not query:
+        return {"results": [], "total": 0}
+
+    vocab = manager.tokenizer.get_vocab()
+
+    # Search for tokens containing the query
+    results = []
+    for token, token_id in vocab.items():
+        if query in token.lower():
+            results.append({
+                "token": token,
+                "token_id": token_id,
+                "byte_length": len(token.encode('utf-8'))
+            })
+            if len(results) >= limit:
+                break
+
+    return {
+        "results": results,
+        "total": len(results),
+        "vocabulary_size": len(vocab)
+    }
+
+@app.get("/vocabulary/browse")
+async def browse_vocabulary(
+    page: int = 0,
+    page_size: int = 100,
+    filter_type: str = "all",  # all, programming, common, functions
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Browse vocabulary with pagination and smart filtering"""
+    vocab = manager.tokenizer.get_vocab()
+
+    # Smart filtering for programming tokens
+    if filter_type == "programming":
+        # Python keywords and common programming terms
+        programming_keywords = {
+            "def", "class", "return", "import", "from", "if", "else", "elif",
+            "for", "while", "break", "continue", "pass", "try", "except",
+            "finally", "with", "as", "lambda", "yield", "async", "await",
+            "None", "True", "False", "and", "or", "not", "in", "is"
+        }
+        filtered_vocab = {k: v for k, v in vocab.items() if k in programming_keywords}
+    elif filter_type == "functions":
+        # Common function/method names
+        filtered_vocab = {k: v for k, v in vocab.items()
+                          if any(term in k.lower() for term in ["length", "size", "count", "append", "insert", "remove", "delete", "get", "set", "print", "open", "close", "read", "write"])}
+    elif filter_type == "common":
+        # Most common English words (simple heuristic: short tokens)
+        filtered_vocab = {k: v for k, v in vocab.items() if len(k) <= 4 and k.isalpha()}
+    else:
+        filtered_vocab = vocab
+
+    # Sort by token ID
+    sorted_items = sorted(filtered_vocab.items(), key=lambda x: x[1])
+
+    # Paginate
+    start = page * page_size
+    end = start + page_size
+    page_items = sorted_items[start:end]
+
+    results = []
+    for token, token_id in page_items:
+        results.append({
+            "token": token,
+            "token_id": token_id,
+            "byte_length": len(token.encode('utf-8'))
+        })
+
+    return {
+        "items": results,
+        "total": len(filtered_vocab),
+        "page": page,
+        "page_size": page_size,
+        "total_pages": (len(filtered_vocab) + page_size - 1) // page_size
+    }
+
+@app.post("/tokenize/preview")
+async def tokenize_preview(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Live tokenization preview for arbitrary text"""
+    from .tokenizer_utils import TokenizerMetadata, get_tokenizer_stats
+
+    text = request.get("text", "")
+
+    if not text:
+        return {"tokens": [], "stats": {}}
+
+    # Tokenize
+    token_ids = manager.tokenizer.encode(text, add_special_tokens=False)
+
+    # Get metadata
+    metadata = TokenizerMetadata(manager.tokenizer)
+    token_analysis = metadata.analyze_tokens(token_ids)
+    stats = get_tokenizer_stats(manager.tokenizer, text)
+
+    return {
+        "text": text,
+        "tokens": token_analysis,
+        "stats": stats,
+        "token_count": len(token_ids)
+    }
+
+@app.post("/tokenize/compare")
+async def compare_tokenizers(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Compare tokenization across different models"""
+    from transformers import AutoTokenizer
+    from .tokenizer_utils import get_tokenizer_stats
+
+    text = request.get("text", "")
+    models = request.get("models", ["Salesforce/codegen-350M-mono"])
+
+    if not text:
+        return {"results": {}}
+
+    results = {}
+
+    for model_name in models:
+        try:
+            # Load tokenizer (will be cached by transformers)
+            if model_name == "Salesforce/codegen-350M-mono":
+                tokenizer = manager.tokenizer
+            else:
+                tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            # Tokenize
+            tokens = tokenizer.tokenize(text)
+            token_ids = tokenizer.encode(text, add_special_tokens=False)
+            token_texts = [tokenizer.decode([tid]) for tid in token_ids]
+            stats = get_tokenizer_stats(tokenizer, text)
+
+            results[model_name] = {
+                "tokens": tokens,
+                "token_ids": token_ids,
+                "token_texts": token_texts,
+                "token_count": len(token_ids),
+                "stats": stats
+            }
+        except Exception as e:
+            logger.error(f"Error loading tokenizer {model_name}: {e}")
+            results[model_name] = {"error": str(e)}
+
+    return {"text": text, "results": results}
+
+@app.post("/token/metadata")
+async def get_token_metadata(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Get comprehensive metadata for a specific token"""
+    from .tokenizer_utils import TokenizerMetadata
+
+    token_id = request.get("token_id")
+
+    if token_id is None:
+        raise HTTPException(status_code=400, detail="token_id is required")
+
+    metadata = TokenizerMetadata(manager.tokenizer)
+
+    # Get token text
+    token_text = manager.tokenizer.decode([token_id])
+
+    # Get BPE pieces
+    bpe_pieces = metadata.get_subword_pieces(token_id)
+
+    # Get byte length
+    byte_length = metadata.get_byte_length(token_id)
+
+    # Check if special token
+    special_tokens = {
+        "eos": manager.tokenizer.eos_token_id,
+        "bos": manager.tokenizer.bos_token_id,
+        "pad": manager.tokenizer.pad_token_id,
+        "unk": manager.tokenizer.unk_token_id
+    }
+    is_special = token_id in special_tokens.values()
+
+    # Check if multi-split (returns array, extract first element)
+    is_multi_split_array = metadata.is_multi_split_identifier([token_id])
+    is_multi_split = is_multi_split_array[0] if is_multi_split_array else False
+
+    # DEBUG LOGGING
+    print(f"\n{'='*60}")
+    print(f"TOKEN METADATA DEBUG - Token ID: {token_id}")
+    print(f"{'='*60}")
+    print(f"Token Text: {repr(token_text)}")
+    print(f"BPE Pieces: {bpe_pieces}")
+    print(f"Num Pieces: {len(bpe_pieces)}")
+    print(f"Byte Length: {byte_length}")
+    print(f"Is Special: {is_special}")
+    print(f"Multi-split Array: {is_multi_split_array}")
+    print(f"Multi-split Boolean: {is_multi_split} (type: {type(is_multi_split).__name__})")
+    print(f"Tokenizer Type: {metadata.tokenizer_type}")
+    print(f"{'='*60}\n")
+
+    result = {
+        "token_id": token_id,
+        "text": token_text,
+        "bpe_pieces": bpe_pieces,
+        "byte_length": byte_length,
+        "is_special": is_special,
+        "is_multi_split": is_multi_split,
+        "num_pieces": len(bpe_pieces),
+        "tokenizer_type": metadata.tokenizer_type
+    }
+
+    print(f"RESPONSE: {result}\n")
+
+    return result
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
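
To sanity-check the new vocabulary and tokenization endpoints, a small usage sketch follows. Paths, request fields, and response keys are taken directly from the diff above; the base URL and the X-API-Key header name are assumptions (adjust them to match the actual deployment and the verify_api_key dependency).

import requests

BASE = "http://localhost:8000"        # assumed local uvicorn deployment (see bottom of the file)
HEADERS = {"X-API-Key": "<key>"}      # placeholder auth header; verify_api_key's scheme is not shown

# Search the vocabulary for tokens containing "def"
hits = requests.post(f"{BASE}/vocabulary/search", headers=HEADERS,
                     json={"query": "def", "limit": 10}).json()
print(hits["total"], "of", hits["vocabulary_size"], "tokens matched")

# Browse only the Python-keyword subset, one page at a time
page = requests.get(f"{BASE}/vocabulary/browse", headers=HEADERS,
                    params={"page": 0, "page_size": 25, "filter_type": "programming"}).json()
print([item["token"] for item in page["items"]], page["total_pages"])

# Live tokenization preview for arbitrary text
preview = requests.post(f"{BASE}/tokenize/preview", headers=HEADERS,
                        json={"text": "def hello_world():"}).json()
print(preview["token_count"], preview["stats"])

# Per-token metadata (318 is an arbitrary example token ID)
meta = requests.post(f"{BASE}/token/metadata", headers=HEADERS,
                     json={"token_id": 318}).json()
print(meta["text"], meta["bpe_pieces"], meta["byte_length"], meta["is_special"])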