AE-Shree committed on
Commit ·
28046a7
1
Parent(s): 58b68f2
Deploy BioStack RLHF Medical Demo
Browse files
server.py
CHANGED
|
@@ -141,16 +141,7 @@ class SFTVisionT5Model(nn.Module):
|
|
| 141 |
repetition_penalty=1.3,
|
| 142 |
)
|
| 143 |
|
| 144 |
-
|
| 145 |
-
# Strip any leading "Projection: X." prefix that leaked from training data
|
| 146 |
-
cleaned = []
|
| 147 |
-
for r in reports:
|
| 148 |
-
if r.lower().startswith("projection:"):
|
| 149 |
-
# Remove the first "Projection: X." segment
|
| 150 |
-
parts = r.split(".", 1)
|
| 151 |
-
r = parts[1].strip() if len(parts) > 1 else r
|
| 152 |
-
cleaned.append(r)
|
| 153 |
-
return cleaned
|
| 154 |
|
| 155 |
|
| 156 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
@@ -191,16 +182,7 @@ class PPOVisionT5Model(nn.Module):
|
|
| 191 |
repetition_penalty=1.3,
|
| 192 |
)
|
| 193 |
|
| 194 |
-
|
| 195 |
-
# Strip any leading "Projection: X." prefix that leaked from training data
|
| 196 |
-
cleaned = []
|
| 197 |
-
for r in reports:
|
| 198 |
-
if r.lower().startswith("projection:"):
|
| 199 |
-
# Remove the first "Projection: X." segment
|
| 200 |
-
parts = r.split(".", 1)
|
| 201 |
-
r = parts[1].strip() if len(parts) > 1 else r
|
| 202 |
-
cleaned.append(r)
|
| 203 |
-
return cleaned
|
| 204 |
|
| 205 |
|
| 206 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
@@ -428,7 +410,12 @@ def health():
|
|
| 428 |
async def sft_inference(file: UploadFile = File(...)):
|
| 429 |
try:
|
| 430 |
tensor = preprocess(await file.read())
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
print(f"[SFT] Generated: {report}")
|
| 433 |
return {"report": report[:81]}
|
| 434 |
except Exception as e:
|
|
@@ -442,7 +429,12 @@ async def reward_inference(file: UploadFile = File(...)):
|
|
| 442 |
tensor = preprocess(await file.read())
|
| 443 |
|
| 444 |
# First get the SFT report to score
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
print(f"[REWARD] Scoring SFT report: {sft_report}")
|
| 447 |
|
| 448 |
if not sft_report.strip():
|
|
@@ -500,7 +492,12 @@ async def reward_inference(file: UploadFile = File(...)):
|
|
| 500 |
async def ppo_inference(file: UploadFile = File(...)):
|
| 501 |
try:
|
| 502 |
tensor = preprocess(await file.read())
|
| 503 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
print(f"[PPO] Generated: {report}")
|
| 505 |
return {"report": report}
|
| 506 |
except Exception as e:
|
|
|
|
| 141 |
repetition_penalty=1.3,
|
| 142 |
)
|
| 143 |
|
| 144 |
+
return generated_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
|
| 147 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
| 182 |
repetition_penalty=1.3,
|
| 183 |
)
|
| 184 |
|
| 185 |
+
return generated
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
| 410 |
async def sft_inference(file: UploadFile = File(...)):
|
| 411 |
try:
|
| 412 |
tensor = preprocess(await file.read())
|
| 413 |
+
generated_ids = sft_model.generate_reports(tensor)
|
| 414 |
+
report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
|
| 415 |
+
# Strip any leading "Projection: X." prefix that leaked from training data
|
| 416 |
+
if report.lower().startswith("projection:"):
|
| 417 |
+
parts = report.split(".", 1)
|
| 418 |
+
report = parts[1].strip() if len(parts) > 1 else report
|
| 419 |
print(f"[SFT] Generated: {report}")
|
| 420 |
return {"report": report[:81]}
|
| 421 |
except Exception as e:
|
|
|
|
| 429 |
tensor = preprocess(await file.read())
|
| 430 |
|
| 431 |
# First get the SFT report to score
|
| 432 |
+
sft_generated_ids = sft_model.generate_reports(tensor)
|
| 433 |
+
sft_report = tokenizer.decode(sft_generated_ids[0], skip_special_tokens=True).strip()
|
| 434 |
+
# Strip any leading "Projection: X." prefix that leaked from training data
|
| 435 |
+
if sft_report.lower().startswith("projection:"):
|
| 436 |
+
parts = sft_report.split(".", 1)
|
| 437 |
+
sft_report = parts[1].strip() if len(parts) > 1 else sft_report
|
| 438 |
print(f"[REWARD] Scoring SFT report: {sft_report}")
|
| 439 |
|
| 440 |
if not sft_report.strip():
|
|
|
|
| 492 |
async def ppo_inference(file: UploadFile = File(...)):
|
| 493 |
try:
|
| 494 |
tensor = preprocess(await file.read())
|
| 495 |
+
generated_ids = ppo_model.generate_reports(tensor)
|
| 496 |
+
report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
|
| 497 |
+
# Strip any leading "Projection: X." prefix that leaked from training data
|
| 498 |
+
if report.lower().startswith("projection:"):
|
| 499 |
+
parts = report.split(".", 1)
|
| 500 |
+
report = parts[1].strip() if len(parts) > 1 else report
|
| 501 |
print(f"[PPO] Generated: {report}")
|
| 502 |
return {"report": report}
|
| 503 |
except Exception as e:
|