AE-Shree commited on
Commit Β·
58b68f2
1
Parent(s): 1537418
Deploy BioStack RLHF Medical Demo
Browse files
server.py
CHANGED
|
@@ -12,6 +12,17 @@ from torchvision import transforms
|
|
| 12 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
# DEVICE
|
| 17 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -19,17 +30,44 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| 19 |
print(f"π₯οΈ Using device: {device}")
|
| 20 |
|
| 21 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
-
#
|
| 23 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
# ARCHITECTURE 1 β CoAtNet Encoder (shared by all three models)
|
| 28 |
# Matches BOTH notebooks exactly.
|
| 29 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
class CoAtNetEncoder(nn.Module):
|
| 31 |
-
def __init__(self, model_name=
|
| 32 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# pretrained=False at inference time β weights come from .pt file
|
| 34 |
self.backbone = timm.create_model(model_name, pretrained=pretrained)
|
| 35 |
|
|
@@ -337,38 +375,6 @@ def preprocess(file_bytes: bytes) -> torch.Tensor:
|
|
| 337 |
return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
|
| 338 |
|
| 339 |
|
| 340 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 341 |
-
# DEBUGGING TOOLS - Compare with Jupyter notebook results
|
| 342 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 343 |
-
import hashlib
|
| 344 |
-
|
| 345 |
-
def get_model_hash(model):
|
| 346 |
-
"""Get hash of model state dict for comparison"""
|
| 347 |
-
model_str = str(model.state_dict())
|
| 348 |
-
return hashlib.md5(model_str.encode()).hexdigest()
|
| 349 |
-
|
| 350 |
-
def log_inference_details(model_name, image_tensor, generated_ids, decoded_report):
|
| 351 |
-
"""Detailed logging for debugging inference differences"""
|
| 352 |
-
print(f"\n{'='*50}")
|
| 353 |
-
print(f" {model_name} INFERENCE DEBUG")
|
| 354 |
-
print(f"{'='*50}")
|
| 355 |
-
print(f"Model hash: {get_model_hash(globals()[f'{model_name.lower()}_model'])}")
|
| 356 |
-
print(f"Image tensor shape: {image_tensor.shape}")
|
| 357 |
-
print(f"Image tensor mean: {image_tensor.mean():.6f}")
|
| 358 |
-
print(f"Image tensor std: {image_tensor.std():.6f}")
|
| 359 |
-
print(f"Model in eval mode: {not globals()[f'{model_name.lower()}_model'].training}")
|
| 360 |
-
print(f"Generated IDs: {generated_ids}")
|
| 361 |
-
print(f"Generated IDs shape: {generated_ids.shape}")
|
| 362 |
-
print(f"Decoded report: '{decoded_report}'")
|
| 363 |
-
print(f"Report length: {len(decoded_report)} chars")
|
| 364 |
-
print(f"{'='*50}\n")
|
| 365 |
-
|
| 366 |
-
# Set consistent random seeds for reproducible results
|
| 367 |
-
torch.manual_seed(42)
|
| 368 |
-
torch.cuda.manual_seed_all(42)
|
| 369 |
-
print(" Random seeds set to 42 for reproducible results")
|
| 370 |
-
|
| 371 |
-
|
| 372 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 373 |
# REWARD FEEDBACK GENERATOR
|
| 374 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -422,72 +428,8 @@ def health():
|
|
| 422 |
async def sft_inference(file: UploadFile = File(...)):
|
| 423 |
try:
|
| 424 |
tensor = preprocess(await file.read())
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
print(f"\nπ [SFT] DETAILED INFERENCE ANALYSIS")
|
| 428 |
-
print(f"{'='*60}")
|
| 429 |
-
print(f"Model checkpoint: {SFT_MODEL_PATH}")
|
| 430 |
-
print(f"Image tensor shape: {tensor.shape}")
|
| 431 |
-
print(f"Image tensor device: {tensor.device}")
|
| 432 |
-
print(f"Image tensor mean: {tensor.mean():.6f}")
|
| 433 |
-
print(f"Image tensor std: {tensor.std():.6f}")
|
| 434 |
-
print(f"Model in eval mode: {not sft_model.training}")
|
| 435 |
-
print(f"Using torch.no_grad: True")
|
| 436 |
-
|
| 437 |
-
# Get raw generation output before decoding
|
| 438 |
-
with torch.no_grad():
|
| 439 |
-
img_features = sft_model.img_encoder(tensor)
|
| 440 |
-
img_emb = sft_model.img_proj(img_features).unsqueeze(1)
|
| 441 |
-
batch_size = tensor.size(0)
|
| 442 |
-
img_attn = torch.ones(batch_size, 1, device=tensor.device)
|
| 443 |
-
|
| 444 |
-
encoder_outputs = sft_model.txt_model.encoder(
|
| 445 |
-
inputs_embeds=img_emb,
|
| 446 |
-
attention_mask=img_attn
|
| 447 |
-
)
|
| 448 |
-
|
| 449 |
-
# Log generation parameters
|
| 450 |
-
print(f"Generation parameters:")
|
| 451 |
-
print(f" - max_length: 128")
|
| 452 |
-
print(f" - num_beams: 4")
|
| 453 |
-
print(f" - early_stopping: True")
|
| 454 |
-
print(f" - no_repeat_ngram_size: 3")
|
| 455 |
-
print(f" - repetition_penalty: 1.3")
|
| 456 |
-
print(f" - do_sample: False")
|
| 457 |
-
print(f" - temperature: N/A (deterministic)")
|
| 458 |
-
|
| 459 |
-
generated = sft_model.txt_model.generate(
|
| 460 |
-
encoder_outputs=encoder_outputs,
|
| 461 |
-
attention_mask=img_attn,
|
| 462 |
-
max_length=128,
|
| 463 |
-
num_beams=4,
|
| 464 |
-
early_stopping=True,
|
| 465 |
-
no_repeat_ngram_size=3,
|
| 466 |
-
repetition_penalty=1.3,
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
print(f"Raw generated IDs: {generated}")
|
| 470 |
-
print(f"Generated IDs shape: {generated.shape}")
|
| 471 |
-
|
| 472 |
-
# Decode with same parameters as notebook
|
| 473 |
-
reports = tokenizer.batch_decode(generated, skip_special_tokens=True)
|
| 474 |
-
|
| 475 |
-
# Apply same post-processing
|
| 476 |
-
cleaned_reports = []
|
| 477 |
-
for r in reports:
|
| 478 |
-
if r.lower().startswith("projection:"):
|
| 479 |
-
parts = r.split(".", 1)
|
| 480 |
-
r = parts[1].strip() if len(parts) > 1 else r
|
| 481 |
-
cleaned_reports.append(r)
|
| 482 |
-
|
| 483 |
-
report = cleaned_reports[0]
|
| 484 |
-
|
| 485 |
-
print(f"Decoded report: '{report}'")
|
| 486 |
-
print(f"Report length: {len(report)} chars")
|
| 487 |
-
print(f"Model hash: {get_model_hash(sft_model)}")
|
| 488 |
-
print(f"{'='*60}\n")
|
| 489 |
-
|
| 490 |
-
print(f"[SFT] Final Generated: {report}")
|
| 491 |
return {"report": report[:81]}
|
| 492 |
except Exception as e:
|
| 493 |
traceback.print_exc()
|
|
@@ -570,56 +512,6 @@ async def ppo_inference(file: UploadFile = File(...)):
|
|
| 570 |
# DIAGNOSTIC ENDPOINT β call GET /debug_keys to verify key names in your files
|
| 571 |
# e.g. curl http://localhost:8000/debug_keys
|
| 572 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 573 |
-
@app.get("/debug_compare")
|
| 574 |
-
def debug_compare():
|
| 575 |
-
"""
|
| 576 |
-
Special endpoint to debug inference differences.
|
| 577 |
-
Returns detailed comparison data for troubleshooting.
|
| 578 |
-
"""
|
| 579 |
-
import os
|
| 580 |
-
|
| 581 |
-
comparison_data = {
|
| 582 |
-
"server_info": {
|
| 583 |
-
"device": str(device),
|
| 584 |
-
"torch_version": torch.__version__,
|
| 585 |
-
"transformers_version": transformers.__version__,
|
| 586 |
-
"random_seed": 42,
|
| 587 |
-
"models_loaded": {
|
| 588 |
-
"SFT": os.path.basename(SFT_MODEL_PATH),
|
| 589 |
-
"Reward": os.path.basename(REWARD_MODEL_PATH),
|
| 590 |
-
"PPO": os.path.basename(PPO_MODEL_PATH)
|
| 591 |
-
}
|
| 592 |
-
},
|
| 593 |
-
"model_hashes": {
|
| 594 |
-
"SFT": get_model_hash(sft_model),
|
| 595 |
-
"Reward": get_model_hash(reward_model),
|
| 596 |
-
"PPO": get_model_hash(ppo_model)
|
| 597 |
-
},
|
| 598 |
-
"generation_params": {
|
| 599 |
-
"max_length": 128,
|
| 600 |
-
"num_beams": 4,
|
| 601 |
-
"early_stopping": True,
|
| 602 |
-
"no_repeat_ngram_size": 3,
|
| 603 |
-
"repetition_penalty": 1.3,
|
| 604 |
-
"do_sample": False,
|
| 605 |
-
"temperature": "N/A (deterministic)"
|
| 606 |
-
},
|
| 607 |
-
"preprocessing": {
|
| 608 |
-
"resize": [224, 224],
|
| 609 |
-
"normalize_mean": [0.485, 0.456, 0.406],
|
| 610 |
-
"normalize_std": [0.229, 0.224, 0.225],
|
| 611 |
-
"convert": "RGB"
|
| 612 |
-
},
|
| 613 |
-
"model_states": {
|
| 614 |
-
"SFT_eval": not sft_model.training,
|
| 615 |
-
"Reward_eval": not reward_model.training,
|
| 616 |
-
"PPO_eval": not ppo_model.training
|
| 617 |
-
}
|
| 618 |
-
}
|
| 619 |
-
|
| 620 |
-
return comparison_data
|
| 621 |
-
|
| 622 |
-
|
| 623 |
@app.get("/debug_keys")
|
| 624 |
def debug_keys():
|
| 625 |
import os
|
|
|
|
| 12 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
|
| 15 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
+
# CONFIGURATION
|
| 17 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 18 |
+
CONFIG = {
|
| 19 |
+
'coatnet_model': 'coatnet_1_rw_224',
|
| 20 |
+
't5_model': 't5-small',
|
| 21 |
+
'img_emb_dim': 768,
|
| 22 |
+
'train_last_stages': 2,
|
| 23 |
+
'image_size': 224,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
# DEVICE
|
| 28 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 30 |
print(f"π₯οΈ Using device: {device}")
|
| 31 |
|
| 32 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
# SECTION 7: Load Tokenizer and Image Transform
|
| 34 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
|
| 36 |
+
print("\n" + "="*80)
|
| 37 |
+
print("LOADING TOKENIZER AND IMAGE TRANSFORM")
|
| 38 |
+
print("="*80)
|
| 39 |
+
|
| 40 |
+
# Load tokenizer
|
| 41 |
+
tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
|
| 42 |
+
print(f"β Loaded tokenizer: {CONFIG['t5_model']}")
|
| 43 |
+
|
| 44 |
+
# Define image transform
|
| 45 |
+
transform = transforms.Compose([
|
| 46 |
+
transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
|
| 47 |
+
transforms.ToTensor(),
|
| 48 |
+
transforms.Normalize(
|
| 49 |
+
mean=[0.485, 0.456, 0.406],
|
| 50 |
+
std=[0.229, 0.224, 0.225]
|
| 51 |
+
)
|
| 52 |
+
])
|
| 53 |
+
print(f"β Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
|
| 54 |
+
|
| 55 |
+
def preprocess_image(image_path: str) -> torch.Tensor:
|
| 56 |
+
"""Load and preprocess image."""
|
| 57 |
+
image = Image.open(image_path).convert('RGB')
|
| 58 |
+
return transform(image)
|
| 59 |
|
| 60 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
# ARCHITECTURE 1 β CoAtNet Encoder (shared by all three models)
|
| 62 |
# Matches BOTH notebooks exactly.
|
| 63 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
class CoAtNetEncoder(nn.Module):
|
| 65 |
+
def __init__(self, model_name=None, pretrained=False, train_last_stages=None):
|
| 66 |
super().__init__()
|
| 67 |
+
# Use CONFIG defaults if not specified
|
| 68 |
+
model_name = model_name or CONFIG['coatnet_model']
|
| 69 |
+
train_last_stages = train_last_stages or CONFIG['train_last_stages']
|
| 70 |
+
|
| 71 |
# pretrained=False at inference time β weights come from .pt file
|
| 72 |
self.backbone = timm.create_model(model_name, pretrained=pretrained)
|
| 73 |
|
|
|
|
| 375 |
return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
|
| 376 |
|
| 377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 379 |
# REWARD FEEDBACK GENERATOR
|
| 380 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 428 |
async def sft_inference(file: UploadFile = File(...)):
|
| 429 |
try:
|
| 430 |
tensor = preprocess(await file.read())
|
| 431 |
+
report = sft_model.generate_reports(tensor)[0]
|
| 432 |
+
print(f"[SFT] Generated: {report}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
return {"report": report[:81]}
|
| 434 |
except Exception as e:
|
| 435 |
traceback.print_exc()
|
|
|
|
| 512 |
# DIAGNOSTIC ENDPOINT β call GET /debug_keys to verify key names in your files
|
| 513 |
# e.g. curl http://localhost:8000/debug_keys
|
| 514 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
@app.get("/debug_keys")
|
| 516 |
def debug_keys():
|
| 517 |
import os
|