gary-boon Claude committed
Commit 8f63685 · 1 Parent(s): 37ed739

Add research attention analysis endpoint with real CodeGen tokenization


- Implement /analyze/research/attention endpoint
- Extract real token IDs from CodeGen tokenizer
- Track attention weights across all 20 layers per generation step
- Return top-k token alternatives with probabilities
- Store per-step attention data for token-by-token analysis

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1)
  1. backend/model_service.py +232 -4
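
For orientation before the diff: a minimal, hypothetical client sketch of the /analyze/research/attention endpoint this commit adds. The response keys used below (promptTokens, generatedTokens, layersDataByStep) and the per-token fields (text, idx, bytes) come from the diff; the request payload, base URL, and the X-API-Key header name are assumptions, not part of the commit.

```python
# Hypothetical usage sketch -- request schema, URL, and auth header are assumed.
import requests

BASE_URL = "http://localhost:8000"      # assumed local deployment (uvicorn runs on port 8000)
HEADERS = {"X-API-Key": "dev-key"}      # header name is a guess; match what verify_api_key expects

resp = requests.post(
    f"{BASE_URL}/analyze/research/attention",
    json={"prompt": "def fibonacci(n):"},  # request field assumed
    headers=HEADERS,
)
data = resp.json()

# These keys match the response dict built in the diff below.
for tok in data["generatedTokens"]:
    print(tok["idx"], repr(tok["text"]), tok["bytes"])
print("generation steps with layer data:", len(data["layersDataByStep"]))
```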
backend/model_service.py CHANGED
@@ -1526,10 +1526,10 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     # Build response
     response = {
         "prompt": prompt,
-        "promptTokens": [{"text": t, "idx": i, "bytes": len(t.encode('utf-8')), "type": "prompt"}
-                         for i, t in enumerate(prompt_tokens)],
-        "generatedTokens": [{"text": t, "idx": i, "bytes": len(t.encode('utf-8')), "type": "generated"}
-                            for i, t in enumerate(generated_tokens)],
+        "promptTokens": [{"text": t, "idx": tid, "bytes": len(t.encode('utf-8')), "type": "prompt"}
+                         for tid, t in zip(prompt_token_ids, prompt_tokens)],
+        "generatedTokens": [{"text": t, "idx": tid, "bytes": len(t.encode('utf-8')), "type": "generated"}
+                            for tid, t in zip(generated_token_ids, generated_tokens)],
         "tokenAlternatives": token_alternatives_by_step,  # Top-k alternatives for each token
         "layersDataByStep": layer_data_by_token,  # Layer data for ALL generation steps
         "layersData": layer_data_by_token[-1] if layer_data_by_token else [],  # Keep for backward compatibility
@@ -2189,6 +2189,234 @@ async def get_swe_bench_comparison(

     return comparison

+# ==============================================================================
+# VOCABULARY & TOKENIZATION ENDPOINTS
+# ==============================================================================
+
+@app.post("/vocabulary/search")
+async def search_vocabulary(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Search vocabulary by query string"""
+    query = request.get("query", "").lower()
+    limit = request.get("limit", 50)
+
+    if not query:
+        return {"results": [], "total": 0}
+
+    vocab = manager.tokenizer.get_vocab()
+
+    # Search for tokens containing the query
+    results = []
+    for token, token_id in vocab.items():
+        if query in token.lower():
+            results.append({
+                "token": token,
+                "token_id": token_id,
+                "byte_length": len(token.encode('utf-8'))
+            })
+            if len(results) >= limit:
+                break
+
+    return {
+        "results": results,
+        "total": len(results),
+        "vocabulary_size": len(vocab)
+    }
+
+@app.get("/vocabulary/browse")
+async def browse_vocabulary(
+    page: int = 0,
+    page_size: int = 100,
+    filter_type: str = "all",  # all, programming, common, functions
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Browse vocabulary with pagination and smart filtering"""
+    vocab = manager.tokenizer.get_vocab()
+
+    # Smart filtering for programming tokens
+    if filter_type == "programming":
+        # Python keywords and common programming terms
+        programming_keywords = {
+            "def", "class", "return", "import", "from", "if", "else", "elif",
+            "for", "while", "break", "continue", "pass", "try", "except",
+            "finally", "with", "as", "lambda", "yield", "async", "await",
+            "None", "True", "False", "and", "or", "not", "in", "is"
+        }
+        filtered_vocab = {k: v for k, v in vocab.items() if k in programming_keywords}
+    elif filter_type == "functions":
+        # Common function/method names
+        filtered_vocab = {k: v for k, v in vocab.items()
+                         if any(term in k.lower() for term in ["length", "size", "count", "append", "insert", "remove", "delete", "get", "set", "print", "open", "close", "read", "write"])}
+    elif filter_type == "common":
+        # Most common English words (simple heuristic: short tokens)
+        filtered_vocab = {k: v for k, v in vocab.items() if len(k) <= 4 and k.isalpha()}
+    else:
+        filtered_vocab = vocab
+
+    # Sort by token ID
+    sorted_items = sorted(filtered_vocab.items(), key=lambda x: x[1])
+
+    # Paginate
+    start = page * page_size
+    end = start + page_size
+    page_items = sorted_items[start:end]
+
+    results = []
+    for token, token_id in page_items:
+        results.append({
+            "token": token,
+            "token_id": token_id,
+            "byte_length": len(token.encode('utf-8'))
+        })
+
+    return {
+        "items": results,
+        "total": len(filtered_vocab),
+        "page": page,
+        "page_size": page_size,
+        "total_pages": (len(filtered_vocab) + page_size - 1) // page_size
+    }
+
+@app.post("/tokenize/preview")
+async def tokenize_preview(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Live tokenization preview for arbitrary text"""
+    from .tokenizer_utils import TokenizerMetadata, get_tokenizer_stats
+
+    text = request.get("text", "")
+
+    if not text:
+        return {"tokens": [], "stats": {}}
+
+    # Tokenize
+    token_ids = manager.tokenizer.encode(text, add_special_tokens=False)
+
+    # Get metadata
+    metadata = TokenizerMetadata(manager.tokenizer)
+    token_analysis = metadata.analyze_tokens(token_ids)
+    stats = get_tokenizer_stats(manager.tokenizer, text)
+
+    return {
+        "text": text,
+        "tokens": token_analysis,
+        "stats": stats,
+        "token_count": len(token_ids)
+    }
+
+@app.post("/tokenize/compare")
+async def compare_tokenizers(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Compare tokenization across different models"""
+    from transformers import AutoTokenizer
+    from .tokenizer_utils import get_tokenizer_stats
+
+    text = request.get("text", "")
+    models = request.get("models", ["Salesforce/codegen-350M-mono"])
+
+    if not text:
+        return {"results": {}}
+
+    results = {}
+
+    for model_name in models:
+        try:
+            # Load tokenizer (will be cached by transformers)
+            if model_name == "Salesforce/codegen-350M-mono":
+                tokenizer = manager.tokenizer
+            else:
+                tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            # Tokenize
+            tokens = tokenizer.tokenize(text)
+            token_ids = tokenizer.encode(text, add_special_tokens=False)
+            token_texts = [tokenizer.decode([tid]) for tid in token_ids]
+            stats = get_tokenizer_stats(tokenizer, text)
+
+            results[model_name] = {
+                "tokens": tokens,
+                "token_ids": token_ids,
+                "token_texts": token_texts,
+                "token_count": len(token_ids),
+                "stats": stats
+            }
+        except Exception as e:
+            logger.error(f"Error loading tokenizer {model_name}: {e}")
+            results[model_name] = {"error": str(e)}
+
+    return {"text": text, "results": results}
+
+@app.post("/token/metadata")
+async def get_token_metadata(
+    request: Dict[str, Any],
+    authenticated: bool = Depends(verify_api_key)
+):
+    """Get comprehensive metadata for a specific token"""
+    from .tokenizer_utils import TokenizerMetadata
+
+    token_id = request.get("token_id")
+
+    if token_id is None:
+        raise HTTPException(status_code=400, detail="token_id is required")
+
+    metadata = TokenizerMetadata(manager.tokenizer)
+
+    # Get token text
+    token_text = manager.tokenizer.decode([token_id])
+
+    # Get BPE pieces
+    bpe_pieces = metadata.get_subword_pieces(token_id)
+
+    # Get byte length
+    byte_length = metadata.get_byte_length(token_id)
+
+    # Check if special token
+    special_tokens = {
+        "eos": manager.tokenizer.eos_token_id,
+        "bos": manager.tokenizer.bos_token_id,
+        "pad": manager.tokenizer.pad_token_id,
+        "unk": manager.tokenizer.unk_token_id
+    }
+    is_special = token_id in special_tokens.values()
+
+    # Check if multi-split (returns array, extract first element)
+    is_multi_split_array = metadata.is_multi_split_identifier([token_id])
+    is_multi_split = is_multi_split_array[0] if is_multi_split_array else False
+
+    # DEBUG LOGGING
+    print(f"\n{'='*60}")
+    print(f"TOKEN METADATA DEBUG - Token ID: {token_id}")
+    print(f"{'='*60}")
+    print(f"Token Text: {repr(token_text)}")
+    print(f"BPE Pieces: {bpe_pieces}")
+    print(f"Num Pieces: {len(bpe_pieces)}")
+    print(f"Byte Length: {byte_length}")
+    print(f"Is Special: {is_special}")
+    print(f"Multi-split Array: {is_multi_split_array}")
+    print(f"Multi-split Boolean: {is_multi_split} (type: {type(is_multi_split).__name__})")
+    print(f"Tokenizer Type: {metadata.tokenizer_type}")
+    print(f"{'='*60}\n")
+
+    result = {
+        "token_id": token_id,
+        "text": token_text,
+        "bpe_pieces": bpe_pieces,
+        "byte_length": byte_length,
+        "is_special": is_special,
+        "is_multi_split": is_multi_split,
+        "num_pieces": len(bpe_pieces),
+        "tokenizer_type": metadata.tokenizer_type
+    }
+
+    print(f"RESPONSE: {result}\n")
+
+    return result
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
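
Not part of the commit: a small, hypothetical smoke test against the vocabulary and tokenization endpoints added above. Endpoint paths, request keys, and response keys are taken from the diff; the base URL and the X-API-Key header name are assumptions (use whatever verify_api_key actually expects).

```python
# Hypothetical smoke test for the new endpoints -- base URL and auth header are assumed.
import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"X-API-Key": "dev-key"}  # header name is a guess

# Search the vocabulary for tokens containing "def"
hits = requests.post(f"{BASE_URL}/vocabulary/search",
                     json={"query": "def", "limit": 10}, headers=HEADERS).json()
print(hits["total"], "matches out of", hits["vocabulary_size"], "vocabulary entries")

# Preview how a snippet tokenizes with the loaded CodeGen tokenizer
preview = requests.post(f"{BASE_URL}/tokenize/preview",
                        json={"text": "def add(a, b): return a + b"}, headers=HEADERS).json()
print("token_count:", preview["token_count"])

# Inspect metadata for the first search hit
if hits["results"]:
    tid = hits["results"][0]["token_id"]
    meta = requests.post(f"{BASE_URL}/token/metadata",
                         json={"token_id": tid}, headers=HEADERS).json()
    print(meta["text"], meta["bpe_pieces"], meta["is_special"])
```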