Bonosa2 commited on
Commit
c7e3aa0
Β·
verified Β·
1 Parent(s): afa25f5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import os
4
+ import traceback
5
+ import pandas as pd
6
+ import torch
7
+ import gradio as gr
8
+ from transformers import (
9
+ logging,
10
+ AutoProcessor,
11
+ AutoTokenizer,
12
+ AutoModelForImageTextToText
13
+ )
14
+ from sklearn.model_selection import train_test_split
15
+
16
+ # ─── Silence irrelevant warnings ───────────────────────────────────────────────
17
+ logging.set_verbosity_error()
18
+
19
+ # ─── Configuration ────────────────────────────────────────────────────────────
20
+ HF_TOKEN = os.environ.get("HF_TOKEN")
21
+ if not HF_TOKEN:
22
+ raise RuntimeError("Missing HF_TOKEN in env vars – set it under Space Settings β†’ Secrets")
23
+ MODEL_ID = "google/gemma-3n-e2b-it"
24
+
25
+ # ─── Fast startup: load only processor & tokenizer ─────────────────────────────
26
+ processor = AutoProcessor.from_pretrained(
27
+ MODEL_ID, trust_remote_code=True, token=HF_TOKEN
28
+ )
29
+ tokenizer = AutoTokenizer.from_pretrained(
30
+ MODEL_ID, trust_remote_code=True, token=HF_TOKEN
31
+ )
32
+
33
+ # ─── Heavy work runs on button click ───────────────────────────────────────────
34
+ def generate_and_export():
35
+ try:
36
+ # 1) Lazy‑load the full FP16 model
37
+ model = AutoModelForImageTextToText.from_pretrained(
38
+ MODEL_ID,
39
+ trust_remote_code=True,
40
+ token=HF_TOKEN,
41
+ torch_dtype=torch.float16,
42
+ device_map="auto"
43
+ )
44
+ device = next(model.parameters()).device
45
+
46
+ # 2) Text→SOAP helper
47
+ def to_soap(text: str) -> str:
48
+ inputs = processor.apply_chat_template(
49
+ [
50
+ {"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
51
+ {"role":"user", "content":[{"type":"text","text":text}]}
52
+ ],
53
+ add_generation_prompt=True,
54
+ tokenize=True,
55
+ return_tensors="pt",
56
+ return_dict=True
57
+ ).to(device)
58
+ out = model.generate(
59
+ **inputs,
60
+ max_new_tokens=400,
61
+ do_sample=True,
62
+ top_p=0.95,
63
+ temperature=0.1,
64
+ pad_token_id=processor.tokenizer.eos_token_id,
65
+ use_cache=False
66
+ )
67
+ prompt_len = inputs["input_ids"].shape[-1]
68
+ return processor.batch_decode(
69
+ out[:, prompt_len:], skip_special_tokens=True
70
+ )[0].strip()
71
+
72
+ # 3) Generate 20 doc notes + ground truths
73
+ docs, gts = [], []
74
+ for i in range(1, 21):
75
+ doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
76
+ docs.append(doc)
77
+ gts.append(to_soap(doc))
78
+ if i % 5 == 0:
79
+ torch.cuda.empty_cache()
80
+
81
+ # 4) Split into 15 train / 5 test
82
+ df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
83
+ train_df, test_df = train_test_split(df, test_size=5, random_state=42)
84
+
85
+ os.makedirs("outputs", exist_ok=True)
86
+
87
+ # 5) Inference on train split β†’ outputs/inference.tsv
88
+ train_preds = [to_soap(d) for d in train_df["doc_note"]]
89
+ inf = train_df.reset_index(drop=True).copy()
90
+ inf["id"] = inf.index + 1
91
+ inf["predicted_soap"] = train_preds
92
+ inf[["id","ground_truth_soap","predicted_soap"]].to_csv(
93
+ "outputs/inference.tsv", sep="\t", index=False
94
+ )
95
+
96
+ # 6) Inference on test split β†’ outputs/eval.csv
97
+ test_preds = [to_soap(d) for d in test_df["doc_note"]]
98
+ pd.DataFrame({
99
+ "id": range(1, len(test_preds) + 1),
100
+ "predicted_soap": test_preds
101
+ }).to_csv("outputs/eval.csv", index=False)
102
+
103
+ # 7) Return status + file paths for download
104
+ return (
105
+ "βœ… Done with 20 notes (15 train / 5 test)!",
106
+ "outputs/inference.tsv",
107
+ "outputs/eval.csv"
108
+ )
109
+
110
+ except Exception as e:
111
+ traceback.print_exc()
112
+ return (f"❌ Error: {e}", None, None)
113
+
114
+ # ─── Gradio UI ─────────────────────────────────────────────────────────────────
115
+ with gr.Blocks() as demo:
116
+ gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
117
+ btn = gr.Button("Generate & Export 20 Notes")
118
+ status = gr.Textbox(interactive=False, label="Status")
119
+ inf_file = gr.File(label="Download inference.tsv")
120
+ eval_file= gr.File(label="Download eval.csv")
121
+
122
+ btn.click(
123
+ fn=generate_and_export,
124
+ inputs=None,
125
+ outputs=[status, inf_file, eval_file]
126
+ )
127
+
128
+ if __name__ == "__main__":
129
+ demo.launch()