AE-Shree committed on
Commit
58b68f2
·
1 Parent(s): 1537418

Deploy BioStack RLHF Medical Demo

Browse files
Files changed (1) hide show
  1. server.py +43 -151
server.py CHANGED
@@ -12,6 +12,17 @@ from torchvision import transforms
12
  from transformers import T5ForConditionalGeneration, T5Tokenizer
13
  from huggingface_hub import hf_hub_download
14
 
 
 
 
 
 
 
 
 
 
 
 
15
  # ─────────────────────────────────────────────────────────────────────────────
16
  # DEVICE
17
  # ─────────────────────────────────────────────────────────────────────────────
@@ -19,17 +30,44 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  print(f"🖥️ Using device: {device}")
20
 
21
  # ─────────────────────────────────────────────────────────────────────────────
22
- # SHARED TOKENIZER
23
  # ─────────────────────────────────────────────────────────────────────────────
24
- tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # ─────────────────────────────────────────────────────────────────────────────
27
  # ARCHITECTURE 1 β€” CoAtNet Encoder (shared by all three models)
28
  # Matches BOTH notebooks exactly.
29
  # ─────────────────────────────────────────────────────────────────────────────
30
  class CoAtNetEncoder(nn.Module):
31
- def __init__(self, model_name="coatnet_1_rw_224", pretrained=False, train_last_stages=2):
32
  super().__init__()
 
 
 
 
33
  # pretrained=False at inference time β€” weights come from .pt file
34
  self.backbone = timm.create_model(model_name, pretrained=pretrained)
35
 
@@ -337,38 +375,6 @@ def preprocess(file_bytes: bytes) -> torch.Tensor:
337
  return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
338
 
339
 
340
- # ─────────────────────────────────────────────────────────────────────────────
341
- # DEBUGGING TOOLS - Compare with Jupyter notebook results
342
- # ─────────────────────────────────────────────────────────────────────────────
343
- import hashlib
344
-
345
- def get_model_hash(model):
346
- """Get hash of model state dict for comparison"""
347
- model_str = str(model.state_dict())
348
- return hashlib.md5(model_str.encode()).hexdigest()
349
-
350
- def log_inference_details(model_name, image_tensor, generated_ids, decoded_report):
351
- """Detailed logging for debugging inference differences"""
352
- print(f"\n{'='*50}")
353
- print(f" {model_name} INFERENCE DEBUG")
354
- print(f"{'='*50}")
355
- print(f"Model hash: {get_model_hash(globals()[f'{model_name.lower()}_model'])}")
356
- print(f"Image tensor shape: {image_tensor.shape}")
357
- print(f"Image tensor mean: {image_tensor.mean():.6f}")
358
- print(f"Image tensor std: {image_tensor.std():.6f}")
359
- print(f"Model in eval mode: {not globals()[f'{model_name.lower()}_model'].training}")
360
- print(f"Generated IDs: {generated_ids}")
361
- print(f"Generated IDs shape: {generated_ids.shape}")
362
- print(f"Decoded report: '{decoded_report}'")
363
- print(f"Report length: {len(decoded_report)} chars")
364
- print(f"{'='*50}\n")
365
-
366
- # Set consistent random seeds for reproducible results
367
- torch.manual_seed(42)
368
- torch.cuda.manual_seed_all(42)
369
- print(" Random seeds set to 42 for reproducible results")
370
-
371
-
372
  # ─────────────────────────────────────────────────────────────────────────────
373
  # REWARD FEEDBACK GENERATOR
374
  # ─────────────────────────────────────────────────────────────────────────────
@@ -422,72 +428,8 @@ def health():
422
  async def sft_inference(file: UploadFile = File(...)):
423
  try:
424
  tensor = preprocess(await file.read())
425
-
426
- # Enhanced debugging - capture generation details
427
- print(f"\nπŸ” [SFT] DETAILED INFERENCE ANALYSIS")
428
- print(f"{'='*60}")
429
- print(f"Model checkpoint: {SFT_MODEL_PATH}")
430
- print(f"Image tensor shape: {tensor.shape}")
431
- print(f"Image tensor device: {tensor.device}")
432
- print(f"Image tensor mean: {tensor.mean():.6f}")
433
- print(f"Image tensor std: {tensor.std():.6f}")
434
- print(f"Model in eval mode: {not sft_model.training}")
435
- print(f"Using torch.no_grad: True")
436
-
437
- # Get raw generation output before decoding
438
- with torch.no_grad():
439
- img_features = sft_model.img_encoder(tensor)
440
- img_emb = sft_model.img_proj(img_features).unsqueeze(1)
441
- batch_size = tensor.size(0)
442
- img_attn = torch.ones(batch_size, 1, device=tensor.device)
443
-
444
- encoder_outputs = sft_model.txt_model.encoder(
445
- inputs_embeds=img_emb,
446
- attention_mask=img_attn
447
- )
448
-
449
- # Log generation parameters
450
- print(f"Generation parameters:")
451
- print(f" - max_length: 128")
452
- print(f" - num_beams: 4")
453
- print(f" - early_stopping: True")
454
- print(f" - no_repeat_ngram_size: 3")
455
- print(f" - repetition_penalty: 1.3")
456
- print(f" - do_sample: False")
457
- print(f" - temperature: N/A (deterministic)")
458
-
459
- generated = sft_model.txt_model.generate(
460
- encoder_outputs=encoder_outputs,
461
- attention_mask=img_attn,
462
- max_length=128,
463
- num_beams=4,
464
- early_stopping=True,
465
- no_repeat_ngram_size=3,
466
- repetition_penalty=1.3,
467
- )
468
-
469
- print(f"Raw generated IDs: {generated}")
470
- print(f"Generated IDs shape: {generated.shape}")
471
-
472
- # Decode with same parameters as notebook
473
- reports = tokenizer.batch_decode(generated, skip_special_tokens=True)
474
-
475
- # Apply same post-processing
476
- cleaned_reports = []
477
- for r in reports:
478
- if r.lower().startswith("projection:"):
479
- parts = r.split(".", 1)
480
- r = parts[1].strip() if len(parts) > 1 else r
481
- cleaned_reports.append(r)
482
-
483
- report = cleaned_reports[0]
484
-
485
- print(f"Decoded report: '{report}'")
486
- print(f"Report length: {len(report)} chars")
487
- print(f"Model hash: {get_model_hash(sft_model)}")
488
- print(f"{'='*60}\n")
489
-
490
- print(f"[SFT] Final Generated: {report}")
491
  return {"report": report[:81]}
492
  except Exception as e:
493
  traceback.print_exc()
@@ -570,56 +512,6 @@ async def ppo_inference(file: UploadFile = File(...)):
570
  # DIAGNOSTIC ENDPOINT β€” call GET /debug_keys to verify key names in your files
571
  # e.g. curl http://localhost:8000/debug_keys
572
  # ─────────────────────────────────────────────────────────────────────────────
573
- @app.get("/debug_compare")
574
- def debug_compare():
575
- """
576
- Special endpoint to debug inference differences.
577
- Returns detailed comparison data for troubleshooting.
578
- """
579
- import os
580
-
581
- comparison_data = {
582
- "server_info": {
583
- "device": str(device),
584
- "torch_version": torch.__version__,
585
- "transformers_version": transformers.__version__,
586
- "random_seed": 42,
587
- "models_loaded": {
588
- "SFT": os.path.basename(SFT_MODEL_PATH),
589
- "Reward": os.path.basename(REWARD_MODEL_PATH),
590
- "PPO": os.path.basename(PPO_MODEL_PATH)
591
- }
592
- },
593
- "model_hashes": {
594
- "SFT": get_model_hash(sft_model),
595
- "Reward": get_model_hash(reward_model),
596
- "PPO": get_model_hash(ppo_model)
597
- },
598
- "generation_params": {
599
- "max_length": 128,
600
- "num_beams": 4,
601
- "early_stopping": True,
602
- "no_repeat_ngram_size": 3,
603
- "repetition_penalty": 1.3,
604
- "do_sample": False,
605
- "temperature": "N/A (deterministic)"
606
- },
607
- "preprocessing": {
608
- "resize": [224, 224],
609
- "normalize_mean": [0.485, 0.456, 0.406],
610
- "normalize_std": [0.229, 0.224, 0.225],
611
- "convert": "RGB"
612
- },
613
- "model_states": {
614
- "SFT_eval": not sft_model.training,
615
- "Reward_eval": not reward_model.training,
616
- "PPO_eval": not ppo_model.training
617
- }
618
- }
619
-
620
- return comparison_data
621
-
622
-
623
  @app.get("/debug_keys")
624
  def debug_keys():
625
  import os
 
12
  from transformers import T5ForConditionalGeneration, T5Tokenizer
13
  from huggingface_hub import hf_hub_download
14
 
15
+ # ─────────────────────────────────────────────────────────────────────────────
16
+ # CONFIGURATION
17
+ # ─────────────────────────────────────────────────────────────────────────────
18
+ CONFIG = {
19
+ 'coatnet_model': 'coatnet_1_rw_224',
20
+ 't5_model': 't5-small',
21
+ 'img_emb_dim': 768,
22
+ 'train_last_stages': 2,
23
+ 'image_size': 224,
24
+ }
25
+
26
  # ─────────────────────────────────────────────────────────────────────────────
27
  # DEVICE
28
  # ─────────────────────────────────────────────────────────────────────────────
 
30
  print(f"🖥️ Using device: {device}")
31
 
32
  # ─────────────────────────────────────────────────────────────────────────────
33
+ # SECTION 7: Load Tokenizer and Image Transform
34
  # ─────────────────────────────────────────────────────────────────────────────
35
+
36
+ print("\n" + "="*80)
37
+ print("LOADING TOKENIZER AND IMAGE TRANSFORM")
38
+ print("="*80)
39
+
40
+ # Load tokenizer
41
+ tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
42
+ print(f"✓ Loaded tokenizer: {CONFIG['t5_model']}")
43
+
44
+ # Define image transform
45
+ transform = transforms.Compose([
46
+ transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize(
49
+ mean=[0.485, 0.456, 0.406],
50
+ std=[0.229, 0.224, 0.225]
51
+ )
52
+ ])
53
+ print(f"✓ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
54
+
55
+ def preprocess_image(image_path: str) -> torch.Tensor:
56
+ """Load and preprocess image."""
57
+ image = Image.open(image_path).convert('RGB')
58
+ return transform(image)
59
 
60
  # ─────────────────────────────────────────────────────────────────────────────
61
  # ARCHITECTURE 1 β€” CoAtNet Encoder (shared by all three models)
62
  # Matches BOTH notebooks exactly.
63
  # ─────────────────────────────────────────────────────────────────────────────
64
  class CoAtNetEncoder(nn.Module):
65
+ def __init__(self, model_name=None, pretrained=False, train_last_stages=None):
66
  super().__init__()
67
+ # Use CONFIG defaults if not specified
68
+ model_name = model_name or CONFIG['coatnet_model']
69
+ train_last_stages = train_last_stages or CONFIG['train_last_stages']
70
+
71
  # pretrained=False at inference time β€” weights come from .pt file
72
  self.backbone = timm.create_model(model_name, pretrained=pretrained)
73
 
 
375
  return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
376
 
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  # ─────────────────────────────────────────────────────────────────────────────
379
  # REWARD FEEDBACK GENERATOR
380
  # ─────────────────────────────────────────────────────────────────────────────
 
428
  async def sft_inference(file: UploadFile = File(...)):
429
  try:
430
  tensor = preprocess(await file.read())
431
+ report = sft_model.generate_reports(tensor)[0]
432
+ print(f"[SFT] Generated: {report}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  return {"report": report[:81]}
434
  except Exception as e:
435
  traceback.print_exc()
 
512
  # DIAGNOSTIC ENDPOINT β€” call GET /debug_keys to verify key names in your files
513
  # e.g. curl http://localhost:8000/debug_keys
514
  # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  @app.get("/debug_keys")
516
  def debug_keys():
517
  import os