AE-Shree committed on
Commit
28046a7
·
1 Parent(s): 58b68f2

Deploy BioStack RLHF Medical Demo

Browse files
Files changed (1) hide show
  1. server.py +20 -23
server.py CHANGED
@@ -141,16 +141,7 @@ class SFTVisionT5Model(nn.Module):
141
  repetition_penalty=1.3,
142
  )
143
 
144
- reports = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
145
- # Strip any leading "Projection: X." prefix that leaked from training data
146
- cleaned = []
147
- for r in reports:
148
- if r.lower().startswith("projection:"):
149
- # Remove the first "Projection: X." segment
150
- parts = r.split(".", 1)
151
- r = parts[1].strip() if len(parts) > 1 else r
152
- cleaned.append(r)
153
- return cleaned
154
 
155
 
156
  # ─────────────────────────────────────────────────────────────────────────────
@@ -191,16 +182,7 @@ class PPOVisionT5Model(nn.Module):
191
  repetition_penalty=1.3,
192
  )
193
 
194
- reports = tokenizer.batch_decode(generated, skip_special_tokens=True)
195
- # Strip any leading "Projection: X." prefix that leaked from training data
196
- cleaned = []
197
- for r in reports:
198
- if r.lower().startswith("projection:"):
199
- # Remove the first "Projection: X." segment
200
- parts = r.split(".", 1)
201
- r = parts[1].strip() if len(parts) > 1 else r
202
- cleaned.append(r)
203
- return cleaned
204
 
205
 
206
  # ─────────────────────────────────────────────────────────────────────────────
@@ -428,7 +410,12 @@ def health():
428
  async def sft_inference(file: UploadFile = File(...)):
429
  try:
430
  tensor = preprocess(await file.read())
431
- report = sft_model.generate_reports(tensor)[0]
 
 
 
 
 
432
  print(f"[SFT] Generated: {report}")
433
  return {"report": report[:81]}
434
  except Exception as e:
@@ -442,7 +429,12 @@ async def reward_inference(file: UploadFile = File(...)):
442
  tensor = preprocess(await file.read())
443
 
444
  # First get the SFT report to score
445
- sft_report = sft_model.generate_reports(tensor)[0]
 
 
 
 
 
446
  print(f"[REWARD] Scoring SFT report: {sft_report}")
447
 
448
  if not sft_report.strip():
@@ -500,7 +492,12 @@ async def reward_inference(file: UploadFile = File(...)):
500
  async def ppo_inference(file: UploadFile = File(...)):
501
  try:
502
  tensor = preprocess(await file.read())
503
- report = ppo_model.generate_reports(tensor)[0]
 
 
 
 
 
504
  print(f"[PPO] Generated: {report}")
505
  return {"report": report}
506
  except Exception as e:
 
141
  repetition_penalty=1.3,
142
  )
143
 
144
+ return generated_ids
 
 
 
 
 
 
 
 
 
145
 
146
 
147
  # ─────────────────────────────────────────────────────────────────────────────
 
182
  repetition_penalty=1.3,
183
  )
184
 
185
+ return generated
 
 
 
 
 
 
 
 
 
186
 
187
 
188
  # ─────────────────────────────────────────────────────────────────────────────
 
410
  async def sft_inference(file: UploadFile = File(...)):
411
  try:
412
  tensor = preprocess(await file.read())
413
+ generated_ids = sft_model.generate_reports(tensor)
414
+ report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
415
+ # Strip any leading "Projection: X." prefix that leaked from training data
416
+ if report.lower().startswith("projection:"):
417
+ parts = report.split(".", 1)
418
+ report = parts[1].strip() if len(parts) > 1 else report
419
  print(f"[SFT] Generated: {report}")
420
  return {"report": report[:81]}
421
  except Exception as e:
 
429
  tensor = preprocess(await file.read())
430
 
431
  # First get the SFT report to score
432
+ sft_generated_ids = sft_model.generate_reports(tensor)
433
+ sft_report = tokenizer.decode(sft_generated_ids[0], skip_special_tokens=True).strip()
434
+ # Strip any leading "Projection: X." prefix that leaked from training data
435
+ if sft_report.lower().startswith("projection:"):
436
+ parts = sft_report.split(".", 1)
437
+ sft_report = parts[1].strip() if len(parts) > 1 else sft_report
438
  print(f"[REWARD] Scoring SFT report: {sft_report}")
439
 
440
  if not sft_report.strip():
 
492
  async def ppo_inference(file: UploadFile = File(...)):
493
  try:
494
  tensor = preprocess(await file.read())
495
+ generated_ids = ppo_model.generate_reports(tensor)
496
+ report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
497
+ # Strip any leading "Projection: X." prefix that leaked from training data
498
+ if report.lower().startswith("projection:"):
499
+ parts = report.split(".", 1)
500
+ report = parts[1].strip() if len(parts) > 1 else report
501
  print(f"[PPO] Generated: {report}")
502
  return {"report": report}
503
  except Exception as e: