Spaces:
Running
Running
| import gradio as gr | |
| import os, requests, io, json | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from groq import Groq | |
| from PIL import Image | |
| from datetime import datetime | |
| from huggingface_hub import HfApi, hf_hub_download | |
# Credentials come from the environment; an empty string means "not configured"
# and is what the feature functions test for before calling a remote service.
GROQ_KEY = os.environ.get("GROQ_API_KEY", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Hugging Face repo ids. HISTORY_REPO and PAPERS_DB_REPO are fetched with
# repo_type="dataset" below; CARDIOLAB_MODEL_REPO is not referenced in the
# visible part of this file.
HISTORY_REPO = "Saicharan21/cardiolab-chat-history"
PAPERS_DB_REPO = "Saicharan21/cardiolab-papers-db"
CARDIOLAB_MODEL_REPO = "Saicharan21/CardioLab-AI-Model"
# UI label -> Groq model id; research_chat() resolves the selected label here.
CHAT_MODELS = {
    "Llama 3.3 70B (Best)": "llama-3.3-70b-versatile",
    "Llama 3.1 8B (Fast)": "llama-3.1-8b-instant",
    "Llama 4 Scout": "meta-llama/llama-4-scout-17b-16e-instruct",
    "Llama 4 Maverick": "meta-llama/llama-4-maverick-17b-128e-instruct",
}
# Condensed lab reference facts that the chat functions append to their system prompts.
KNOWHOW = ("MCL: Sylgard 184 PDMS 10:1 ratio 48hr cure green laser PIV 70bpm 5L/min. "
           "TGT: Arduino Uno Stepper Motor 150mL blood 0 20 40 60min TAT PF1.2 hemolysis platelets. "
           "NORMAL: TAT below 8. PF1.2 below 2.0. Hemo below 20. Plt above 150. "
           "uPAD: Jaffe reaction creatinine picric acid orange-red. Normal 0.6-1.2 mg/dL. CKD above 1.5. "
           "MHV: 27mm SJM Regent bileaflet trileaflet monoleaflet pediatric. "
           "PIV: green laser 532nm. Normal velocity 0.5-2.0 m/s. Shear below 5 Pa. Risk above 10 Pa. "
           "Equipment: Heska HT5 analyzer PIV Tygon tubing Arduino Uno.")
| CSS = """ | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap'); | |
| body, .gradio-container { background: #f8fafc !important; font-family: Inter, sans-serif !important; } | |
| .tab-nav { background: #fff !important; border-bottom: 1px solid #e2e8f0 !important; padding: 0 8px !important; display: flex !important; flex-wrap: wrap !important; } | |
| .tab-nav button { background: transparent !important; color: #64748b !important; border: none !important; border-bottom: 2px solid transparent !important; border-radius: 0 !important; padding: 10px 12px !important; font-weight: 500 !important; font-size: 0.8em !important; white-space: nowrap !important; margin-bottom: -1px !important; } | |
| .tab-nav button:hover { color: #c1121f !important; background: #fff5f5 !important; } | |
| .tab-nav button.selected { color: #c1121f !important; border-bottom: 2px solid #c1121f !important; font-weight: 700 !important; background: transparent !important; } | |
| .message.user { background: linear-gradient(135deg, #c1121f, #e63946) !important; color: white !important; border-radius: 14px 14px 4px 14px !important; padding: 12px 16px !important; } | |
| .message.bot { background: #ffffff !important; color: #1a202c !important; border: 1px solid #e2e8f0 !important; border-left: 3px solid #c1121f !important; border-radius: 4px 14px 14px 14px !important; padding: 12px 16px !important; } | |
| textarea { background: #fff !important; color: #1a202c !important; border: 1px solid #e2e8f0 !important; border-radius: 10px !important; } | |
| textarea:focus { border-color: #c1121f !important; outline: none !important; box-shadow: 0 0 0 2px rgba(193,18,31,0.1) !important; } | |
| button.primary { background: #c1121f !important; color: white !important; border: none !important; border-radius: 8px !important; font-weight: 600 !important; } | |
| button.primary:hover { background: #a00e18 !important; transform: translateY(-1px) !important; } | |
| button.secondary { background: #f1f5f9 !important; color: #475569 !important; border: 1px solid #e2e8f0 !important; border-radius: 8px !important; } | |
| input[type=number] { background: #fff !important; color: #1a202c !important; border: 1px solid #e2e8f0 !important; border-radius: 8px !important; } | |
| label span { color: #475569 !important; font-weight: 500 !important; font-size: 0.82em !important; } | |
| ::-webkit-scrollbar { width: 5px; } | |
| ::-webkit-scrollbar-thumb { background: #c1121f; border-radius: 4px; } | |
| """ | |
| HEADER = """ | |
| <style> | |
| @keyframes hb{0%,100%{transform:scale(1)}15%{transform:scale(1.14)}30%{transform:scale(1)}45%{transform:scale(1.08)}60%{transform:scale(1)}} | |
| @keyframes ecg{from{stroke-dashoffset:400}to{stroke-dashoffset:0}} | |
| @keyframes fadeD{from{opacity:0;transform:translateY(-8px)}to{opacity:1;transform:translateY(0)}} | |
| </style> | |
| <div style="background:#fff;border-bottom:2px solid #c1121f;padding:14px 24px;display:flex;align-items:center;justify-content:space-between;box-shadow:0 1px 8px rgba(0,0,0,0.06);animation:fadeD 0.4s ease;"> | |
| <div style="display:flex;align-items:center;gap:10px;"> | |
| <div style="background:#eff6ff;border:1px solid #bfdbfe;border-radius:10px;padding:7px 12px;display:flex;align-items:center;gap:8px;"> | |
| <svg width="20" height="20" viewBox="0 0 100 100"> | |
| <circle cx="50" cy="35" r="28" fill="#0057a8"/><ellipse cx="50" cy="14" rx="18" ry="8" fill="#0057a8"/> | |
| <polygon points="35,12 37,5 40,12" fill="#e8a020"/><polygon points="40,11 43,4 46,11" fill="#e8a020"/> | |
| <polygon points="46,11 49,4 52,11" fill="#e8a020"/><polygon points="52,11 55,4 58,11" fill="#e8a020"/> | |
| <polygon points="58,12 61,5 64,12" fill="#e8a020"/> | |
| <rect x="38" y="30" width="24" height="18" rx="3" fill="#0057a8"/> | |
| <rect x="42" y="34" width="6" height="10" rx="2" fill="#e8a020"/> | |
| <rect x="36" y="46" width="28" height="6" rx="3" fill="#0057a8"/> | |
| </svg> | |
| <div> | |
| <div style="color:#1d4ed8;font-size:0.68em;font-weight:700;line-height:1.2;">SJSU</div> | |
| <div style="color:#374151;font-size:0.6em;line-height:1.2;">Biomedical Eng.</div> | |
| </div> | |
| </div> | |
| </div> | |
| <div style="display:flex;align-items:center;gap:14px;"> | |
| <svg width="80" height="22" viewBox="0 0 100 22"> | |
| <polyline points="0,11 18,11 23,3 27,19 31,1 35,17 39,11 100,11" fill="none" stroke="#c1121f" stroke-width="2" stroke-linecap="round" stroke-dasharray="400" style="animation:ecg 1.5s ease forwards;"/> | |
| </svg> | |
| <div style="display:flex;align-items:center;gap:12px;"> | |
| <div style="animation:hb 1.4s ease infinite;"> | |
| <svg width="34" height="30" viewBox="0 0 100 90"> | |
| <defs><radialGradient id="hg" cx="50%" cy="35%"><stop offset="0%" stop-color="#e63946"/><stop offset="100%" stop-color="#9b0a14"/></radialGradient></defs> | |
| <path d="M50 85 C50 85 5 55 5 30 C5 15 18 5 30 5 C38 5 45 9 50 15 C55 9 62 5 70 5 C82 5 95 15 95 30 C95 55 50 85 50 85Z" fill="url(#hg)"/> | |
| <polyline points="22,46 30,46 34,35 38,57 42,28 46,51 52,46 78,46" fill="none" stroke="white" stroke-width="3.5" stroke-linecap="round" opacity="0.95"/> | |
| </svg> | |
| </div> | |
| <div> | |
| <div style="font-size:1.6em;font-weight:700;color:#111;letter-spacing:-0.5px;line-height:1.1;">Cardio<span style="color:#c1121f;">Lab</span> AI</div> | |
| <div style="font-size:0.6em;color:#9ca3af;margin-top:1px;">SJSU Biomedical Engineering</div> | |
| </div> | |
| </div> | |
| <svg width="80" height="22" viewBox="0 0 100 22" style="transform:scaleX(-1);"> | |
| <polyline points="0,11 18,11 23,3 27,19 31,1 35,17 39,11 100,11" fill="none" stroke="#c1121f" stroke-width="2" stroke-linecap="round" stroke-dasharray="400" style="animation:ecg 1.8s ease forwards;"/> | |
| </svg> | |
| </div> | |
| <div style="display:flex;gap:6px;align-items:center;"> | |
| <span style="background:#fef2f2;border:1px solid #fecaca;color:#c1121f;padding:3px 10px;border-radius:20px;font-size:0.65em;font-weight:600;">RAG Active</span> | |
| <span style="background:#eff6ff;border:1px solid #bfdbfe;color:#1d4ed8;padding:3px 10px;border-radius:20px;font-size:0.65em;font-weight:600;">4 Models</span> | |
| <span style="background:#f0fdf4;border:1px solid #bbf7d0;color:#15803d;padding:3px 10px;border-radius:20px;font-size:0.65em;font-weight:600;">16 Papers</span> | |
| </div> | |
| </div> | |
| """ | |
| # ── PAPER DATABASE ───────────────────────────────────────── | |
# In-memory paper index: filled by load_papers(), read by search_papers().
CHUNKS = []            # loaded from chunks.json; indexed in parallel with METADATA
METADATA = []          # loaded from metadata.json; search_papers() reads each entry's "paper" key
EMBEDDINGS = None      # array loaded from embeddings.npy
PAPERS_LOADED = False  # becomes True only after load_papers() finishes without error
EMBEDDER = None        # SentenceTransformer instance used to embed queries
def load_papers():
    """Download the paper index and publish it in the module-level globals.

    Fetches ``chunks.json``, ``metadata.json`` and ``embeddings.npy`` from the
    ``PAPERS_DB_REPO`` dataset and creates the sentence embedder. Everything is
    loaded into locals first and assigned to the globals only once all steps
    succeeded, so a failure half-way never leaves a partially populated index.

    Any failure is reported on stdout and leaves ``PAPERS_LOADED`` False; the
    app keeps running without paper search.
    """
    global CHUNKS, METADATA, EMBEDDINGS, PAPERS_LOADED, EMBEDDER
    try:
        # Imported lazily so a missing optional dependency only disables search.
        from sentence_transformers import SentenceTransformer

        # Pass None rather than "" when no token is configured so that
        # huggingface_hub applies its default token resolution.
        token = HF_TOKEN or None

        def fetch(filename):
            return hf_hub_download(repo_id=PAPERS_DB_REPO, filename=filename,
                                   repo_type="dataset", token=token)

        chunks_path = fetch("chunks.json")
        meta_path = fetch("metadata.json")
        emb_path = fetch("embeddings.npy")
        with open(chunks_path, encoding="utf-8") as f:
            chunks = json.load(f)
        with open(meta_path, encoding="utf-8") as f:
            metadata = json.load(f)
        embeddings = np.load(emb_path)
        embedder = SentenceTransformer("all-MiniLM-L6-v2")
    except Exception as e:  # best effort: startup must not fail because of the index
        print("Paper load error: " + str(e))
        return
    CHUNKS, METADATA, EMBEDDINGS, EMBEDDER = chunks, metadata, embeddings, embedder
    PAPERS_LOADED = True
    print("Papers loaded: " + str(len(CHUNKS)) + " chunks")


# Populate the index once at import time.
load_papers()
def search_papers(query, n=4):
    """Return prompt context and hit details for the chunks most similar to *query*.

    Similarity is the cosine between the query embedding and every row of
    ``EMBEDDINGS``. Only the top *n* rows are considered, and of those only
    hits scoring above 0.25 are kept.

    Returns:
        ``(context, results)``. ``context`` is a text block in which each
        paper gets one ``=== FROM: <paper> ===`` header followed by up to 500
        characters of every accepted chunk. ``results`` is a list of
        ``{"chunk", "paper", "score"}`` dicts in descending score order.
        ``("", [])`` is returned when the index is not loaded or on any error.
    """
    if not PAPERS_LOADED or EMBEDDINGS is None or EMBEDDER is None:
        return "", []
    try:
        q_emb = EMBEDDER.encode([query])
        # The epsilon keeps all-zero vectors from dividing by zero.
        emb_norm = EMBEDDINGS / (np.linalg.norm(EMBEDDINGS, axis=1, keepdims=True) + 1e-10)
        q_norm = q_emb / (np.linalg.norm(q_emb) + 1e-10)
        scores = (emb_norm @ q_norm.T).flatten()
        top_idx = np.argsort(scores)[::-1][:n]

        parts = []
        results = []
        seen = set()
        for idx in top_idx:
            score = float(scores[idx])
            if not score > 0.25:
                continue
            chunk = CHUNKS[idx]
            paper = METADATA[idx]["paper"]
            results.append({"chunk": chunk, "paper": paper, "score": score})
            if paper not in seen:
                # One header per paper, emitted before its first chunk.
                parts.append("\n=== FROM: " + paper + " ===\n")
                seen.add(paper)
            parts.append(chunk[:500] + "\n")
        return "".join(parts), results
    except Exception as e:  # search is optional; callers fall back to no context
        print("Paper search error: " + str(e))
        return "", []
| # ── SESSION MANAGEMENT ───────────────────────────────────── | |
def load_all_sessions():
    """Fetch every saved chat session from the history dataset.

    Returns:
        The dict stored in ``chat_history.json`` (session name -> record), or
        an empty dict when no ``HF_TOKEN`` is configured, the download or
        parse fails, or the file does not contain a JSON object.

    NOTE(review): a transient download failure is indistinguishable from
    "no sessions yet" for callers, so a caller that saves right after a
    failed load will overwrite the stored history. Fixing that needs an
    interface change and is left for a follow-up.
    """
    if not HF_TOKEN:
        return {}
    try:
        path = hf_hub_download(repo_id=HISTORY_REPO, filename="chat_history.json",
                               repo_type="dataset", token=HF_TOKEN)
        with open(path, encoding="utf-8") as f:
            sessions = json.load(f)
    except Exception as e:  # includes "file not created yet" on first use
        print("Session load error: " + str(e))
        return {}
    # Callers index the result by session name; reject any other JSON shape.
    return sessions if isinstance(sessions, dict) else {}
def save_all_sessions(sessions):
    """Upload *sessions* as ``chat_history.json`` to the history dataset.

    Returns:
        True when the upload succeeded; False when no ``HF_TOKEN`` is
        configured or serialisation/upload failed (the reason is printed).
    """
    if not HF_TOKEN:
        return False
    try:
        payload = json.dumps(sessions, indent=2).encode()
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=payload,
            path_in_repo="chat_history.json",
            repo_id=HISTORY_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message="Update",
        )
    except Exception as e:  # callers only need success/failure
        print("Session save error: " + str(e))
        return False
    return True
def get_session_list():
    """Return saved session names in reverse stored order.

    When nothing is stored, a single placeholder entry is returned instead of
    an empty list.
    """
    sessions = load_all_sessions()
    if not sessions:
        return ["No saved sessions"]
    names = list(sessions.keys())
    names.reverse()
    return names
def save_session(history, name):
    """Store *history* under *name* and refresh the session dropdown.

    Args:
        history: chat messages to persist; an empty value saves nothing.
        name: session name; a blank name is replaced by a timestamped one.

    Returns:
        ``(status_text, dropdown_update)``. The dropdown is only changed when
        the save actually succeeded; otherwise it would be pointed at a value
        that is not among its choices.
    """
    if not history:
        return "Nothing to save", gr.update()
    if not name or not name.strip():
        name = "Chat " + datetime.now().strftime("%b %d %H:%M")
    sessions = load_all_sessions()
    sessions[name] = {"messages": history, "saved_at": datetime.now().isoformat()}
    if not save_all_sessions(sessions):
        return "Save failed", gr.update()
    return "Saved: " + name, gr.update(choices=get_session_list(), value=name)
def load_session(name):
    """Look up a saved session by name.

    Returns ``(messages, status_text)``; ``messages`` is an empty list when
    nothing was selected or the name is unknown.
    """
    # A blank selection or the "No saved sessions" placeholder is not loadable.
    if not name or "No saved" in name:
        return [], "Select a session"
    stored = load_all_sessions()
    if name not in stored:
        return [], "Not found"
    record = stored[name]
    return record["messages"], "Loaded: " + name
def delete_session(name):
    """Remove the session *name* from the stored history.

    Returns:
        ``(status_text, dropdown_update)``. The dropdown is refreshed only
        after the updated history was uploaded; if the upload fails the
        status reads "Delete failed" instead of claiming success.
    """
    if not name or "No saved" in name:
        return "Select a session", gr.update()
    sessions = load_all_sessions()
    if name not in sessions:
        return "Not found", gr.update()
    del sessions[name]
    if not save_all_sessions(sessions):
        return "Delete failed", gr.update()
    choices = get_session_list()
    return "Deleted: " + name, gr.update(choices=choices, value=choices[0] if choices else None)
def new_chat():
    """Return the reset values for a new conversation: empty history, empty text, status."""
    fresh_history = []
    return fresh_history, "", "New chat"
| # ── SEARCH ───────────────────────────────────────────────── | |
def get_pubmed(query, n=3):
    """Return newline-separated PubMed URLs for the newest matches of *query*.

    The query is AND-ed with a fixed set of lab topic terms and searched in
    title/abstract, sorted by date, limited to *n* ids.

    Returns:
        One ``https://pubmed.ncbi.nlm.nih.gov/<id>`` line per hit, or ``""``
        when there are no hits or the lookup fails (best effort).
    """
    try:
        term = query + " AND (heart valve OR hemodynamics OR microfluidic OR thrombogen OR creatinine OR CKD)"
        r = requests.get(
            "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
            params={"db": "pubmed", "term": term, "retmax": n, "retmode": "json",
                    "sort": "date", "field": "tiab"},
            timeout=10,
        )
        r.raise_for_status()
        ids = r.json()["esearchresult"]["idlist"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        print("PubMed lookup failed: " + str(e))
        return ""
    return "\n".join("https://pubmed.ncbi.nlm.nih.gov/" + i for i in ids) if ids else ""
def _expand_search_query(query):
    """Ask the fast Groq model for MeSH-style terms; fall back to *query* on any problem."""
    if not GROQ_KEY:
        return query
    try:
        client = Groq(api_key=GROQ_KEY)
        resp = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[{"role": "system", "content": "Biomedical PubMed expert. Convert to MeSH terms. Return ONLY terms."},
                      {"role": "user", "content": "Optimize: " + query}],
            max_tokens=60)
        return (resp.choices[0].message.content or "").strip() or query
    except Exception as e:  # expansion is an optional nicety
        print("Query expansion skipped: " + str(e))
        return query


def _pubmed_section(expanded):
    """Build the "PUBMED:" text section for *expanded*; "" when there are no hits.

    Network and parse errors propagate to the caller.
    """
    import xml.etree.ElementTree as ET

    forced = expanded + " AND (heart valve OR hemodynamics OR microfluidic OR thrombogen OR creatinine OR PIV OR CKD)"
    r = requests.get("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
                     params={"db": "pubmed", "term": forced, "retmax": 8, "retmode": "json",
                             "sort": "date", "field": "tiab"}, timeout=12)
    ids = r.json()["esearchresult"]["idlist"]
    if not ids:
        return ""
    r2 = requests.get("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi",
                      params={"db": "pubmed", "id": ",".join(ids), "retmode": "xml",
                              "rettype": "abstract"}, timeout=12)
    root = ET.fromstring(r2.content)
    lines = ["PUBMED:\n"]
    for article in root.findall(".//PubmedArticle"):
        title_el = article.find(".//ArticleTitle")
        pmid_el = article.find(".//PMID")
        if title_el is None or pmid_el is None:
            continue
        # itertext() keeps text nested in inline markup (<i>, <sub>, ...),
        # which .text alone would truncate or drop.
        title = "".join(title_el.itertext()).strip() or "No title"
        pmid = pmid_el.text or ""
        year_el = article.find(".//PubDate/Year")
        year = (year_el.text if year_el is not None else "") or ""
        lines.append(title[:85] + " (" + year + ")\n")
        lines.append(" https://pubmed.ncbi.nlm.nih.gov/" + pmid + "\n\n")
    return "".join(lines)


def _scholar_section(expanded):
    """Build the "SEMANTIC SCHOLAR:" text section (papers from 2015 on); best effort."""
    lines = []
    try:
        r3 = requests.get("https://api.semanticscholar.org/graph/v1/paper/search",
                          params={"query": expanded, "limit": 5,
                                  "fields": "title,year,url,citationCount"}, timeout=12)
        papers = r3.json().get("data") or []
        lines.append("SEMANTIC SCHOLAR:\n")
        for p in papers:
            year = p.get("year", 0) or 0
            if int(year) < 2015:
                continue
            # The API may return explicit nulls; treat them as empty strings.
            entry = (p.get("title") or "")[:85] + " (" + str(year) + ")"
            cites = p.get("citationCount", 0)
            if cites:
                entry += " | " + str(cites) + " citations"
            entry += "\n " + (p.get("url") or "") + "\n\n"
            lines.append(entry)
    except Exception as e:  # keep whatever was collected so far
        print("Semantic Scholar lookup failed: " + str(e))
    return "".join(lines)


def quick_search(query):
    """Search PubMed and Semantic Scholar for *query* and format a text report.

    The query is first expanded to MeSH-style terms when a Groq key is
    configured. The report always ends with an SJSU ScholarWorks search link.

    Returns:
        The formatted report, "Please enter a topic." for a blank query, or
        "Search error: <reason>" when the PubMed lookup fails.
    """
    if not query or not query.strip():
        return "Please enter a topic."
    try:
        expanded = _expand_search_query(query)
        out = "QUERY: " + query + "\n" + "=" * 40 + "\n\n"
        out += _pubmed_section(expanded)
        out += _scholar_section(expanded)
        out += "SJSU SCHOLARWORKS:\n"
        out += " https://scholarworks.sjsu.edu/do/search/?q=" + requests.utils.quote(query) + "&context=6781027"
        return out
    except Exception as e:
        return "Search error: " + str(e)
| # ── CHAT ─────────────────────────────────────────────────── | |
def research_chat(message, history, chat_model):
    """Answer *message* with the selected Groq model, grounded in the paper index.

    Args:
        message: the user's text; blank input is ignored.
        history: list of ``{"role", "content"}`` dicts; extended in place with
            the new user and assistant turns.
        chat_model: a key of ``CHAT_MODELS``; unknown labels fall back to
            ``llama-3.3-70b-versatile``.

    Returns:
        ``("", history)`` - the empty string clears the input box. Failures
        are reported as an assistant turn starting with "Error: ".
    """
    if not message or not message.strip():
        return "", history
    if history is None:
        history = []

    if not GROQ_KEY:
        reply = "Error: Add GROQ_API_KEY to Space Settings."
    else:
        try:
            model_id = CHAT_MODELS.get(chat_model, "llama-3.3-70b-versatile")
            client = Groq(api_key=GROQ_KEY)
            paper_context, paper_results = search_papers(message, n=4)
            if paper_context:
                system_prompt = ("You are CardioLab AI for SJSU Biomedical Engineering. "
                                 "Answer using SJSU CardioLab research papers below. Cite paper names."
                                 + "\n" + "SJSU PAPERS:" + "\n" + paper_context + "\n"
                                 + "KNOWLEDGE: " + KNOWHOW)
            else:
                system_prompt = "You are CardioLab AI for SJSU Biomedical Engineering. " + KNOWHOW

            msgs = [{"role": "system", "content": system_prompt}]
            # Forward only role/content; non-dict history entries are skipped.
            msgs.extend({"role": item["role"], "content": item["content"]}
                        for item in history if isinstance(item, dict))
            msgs.append({"role": "user", "content": message})

            resp = client.chat.completions.create(model=model_id, messages=msgs, max_tokens=800)
            # The API may return no text; keep the footer logic working.
            reply = resp.choices[0].message.content or ""

            if paper_results:
                # dict.fromkeys de-duplicates while keeping rank order.
                unique_papers = list(dict.fromkeys(r["paper"] for r in paper_results))
                reply += "\n\nSources:"
                for p in unique_papers[:3]:
                    reply += "\n - " + p.replace(".pdf", "").replace("_", " ")
            pubmed = get_pubmed(message, n=2)
            if pubmed:
                reply += "\nPubMed: " + pubmed
        except Exception as e:
            reply = "Error: " + str(e)

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": reply})
    return "", history
def voice_chat(audio, history):
    """Transcribe a recording with Whisper and answer it like a chat message.

    Args:
        audio: path of the recorded audio file, or None when nothing was recorded.
        history: list of ``{"role", "content"}`` dicts; extended in place.

    Returns:
        The updated history. Problems are reported as an assistant turn
        rather than raised.
    """
    if history is None:
        history = []
    if audio is None:
        history.append({"role": "assistant", "content": "Please record first."})
        return history
    # Same guard as the text chat: fail with a clear hint instead of an
    # opaque client error when the key is missing.
    if not GROQ_KEY:
        history.append({"role": "assistant", "content": "Error: Add GROQ_API_KEY to Space Settings."})
        return history
    try:
        client = Groq(api_key=GROQ_KEY)
        with open(audio, "rb") as f:
            tx = client.audio.transcriptions.create(file=("audio.wav", f, "audio/wav"),
                                                    model="whisper-large-v3")
        paper_context, _ = search_papers(tx.text, n=3)
        if paper_context:
            system = "You are CardioLab AI. Use these SJSU papers:" + "\n" + paper_context + "\n" + KNOWHOW
        else:
            system = "You are CardioLab AI. " + KNOWHOW
        msgs = [{"role": "system", "content": system}]
        msgs.extend({"role": item["role"], "content": item["content"]}
                    for item in history if isinstance(item, dict))
        msgs.append({"role": "user", "content": tx.text})
        resp = client.chat.completions.create(model="llama-3.3-70b-versatile",
                                              messages=msgs, max_tokens=500)
        history.append({"role": "user", "content": "Voice: " + tx.text})
        history.append({"role": "assistant", "content": resp.choices[0].message.content})
        return history
    except Exception as e:
        history.append({"role": "assistant", "content": "Voice error: " + str(e)})
        return history
| # ── PHASE D ──────────────────────────────────────────────── | |
def generate_protocol(experiment_type, specific_params):
    """Generate a complete lab protocol for *experiment_type* with the 70B model.

    The request is enriched with optional user parameters, a canned lab
    context whose key occurs in the experiment type (case-insensitive), and up
    to 600 characters of matching paper context.

    Returns the model's text, or a short message when the key or experiment
    type is missing or the call fails.
    """
    if not GROQ_KEY:
        return "Error: Add GROQ_API_KEY"
    if not experiment_type:
        return "Select experiment type"
    try:
        client = Groq(api_key=GROQ_KEY)
        paper_context, _ = search_papers(experiment_type, n=4)

        lab_ctx = {
            "MCL": "Sylgard 184 PDMS 10:1 ratio 48hr cure. Tygon tubing. 70bpm 5L/min.",
            "PIV": "Green laser 532nm. Normal velocity 0.5-2.0 m/s. Shear below 5 Pa.",
            "Thrombogenicity": "Arduino Uno stepper motor 48V. 150mL fresh blood. Sample 0 20 40 60 min. Heska HT5. TAT below 8 ng/mL. PF1.2 below 2.0 nmol/L.",
            "uPAD": "Whatman filter paper. Wax printer 120C. Jaffe reaction picric acid.",
            "FSI": "COMSOL ALE mesh. Blood 1060 kg/m3 0.0035 Pa.s.",
        }
        # First key (in insertion order) contained in the experiment type wins.
        wanted = experiment_type.lower()
        extra = ""
        for key, text in lab_ctx.items():
            if key.lower() in wanted:
                extra = text
                break

        system_msg = ("You are CardioLab AI protocol generator for SJSU. Generate COMPLETE protocol with: "
                      "1.OBJECTIVE 2.MATERIALS AND EQUIPMENT 3.SAFETY 4.PROCEDURE 5.DATA COLLECTION "
                      "6.ANALYSIS 7.EXPECTED RESULTS with normal ranges 8.TROUBLESHOOTING. "
                      "Use exact SJSU CardioLab values.")

        # Assemble the user message line by line; optional lines are skipped.
        request_lines = ["Generate protocol for: " + experiment_type]
        if specific_params and specific_params.strip():
            request_lines.append("Parameters: " + specific_params)
        if extra:
            request_lines.append("Context: " + extra)
        if paper_context:
            request_lines.append("SJSU papers: " + paper_context[:600])
        user_msg = "\n".join(request_lines)

        resp = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "system", "content": system_msg},
                      {"role": "user", "content": user_msg}],
            max_tokens=1200)
        return resp.choices[0].message.content
    except Exception as e:
        return "Error: " + str(e)
def generate_report(data_description, experiment_type, results):
    """Draft a structured research report for *experiment_type* with the 70B model.

    Args:
        data_description: optional free-text description of the data.
        experiment_type: required; also used to look up paper context.
        results: optional free-text results to include.

    Returns the model's text, or a short message when the key or experiment
    type is missing or the call fails.
    """
    if not GROQ_KEY:
        return "Error: Add GROQ_API_KEY"
    # Same guard as generate_protocol: without it a missing type surfaced as
    # a raw string-concatenation TypeError message.
    if not experiment_type or not experiment_type.strip():
        return "Select experiment type"
    try:
        client = Groq(api_key=GROQ_KEY)
        paper_context, _ = search_papers(experiment_type, n=3)
        system_msg = ("You are CardioLab AI report writer for SJSU. Generate professional research report with: "
                      "1.ABSTRACT 2.INTRODUCTION 3.MATERIALS AND METHODS 4.RESULTS AND DISCUSSION "
                      "5.CONCLUSION 6.RECOMMENDATIONS 7.REFERENCES. Academic style.")
        user_msg = "Write report for: " + experiment_type
        if data_description and data_description.strip():
            user_msg += "\n" + "Description: " + data_description
        if results and results.strip():
            user_msg += "\n" + "Results: " + results
        if paper_context:
            user_msg += "\n" + "SJSU papers: " + paper_context[:600]
        resp = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "system", "content": system_msg},
                      {"role": "user", "content": user_msg}],
            max_tokens=1500)
        return resp.choices[0].message.content
    except Exception as e:
        return "Error: " + str(e)
def generate_hypothesis(research_area, current_findings):
    """Ask the 70B model for three testable hypotheses about *research_area*.

    Args:
        research_area: required topic; also used to look up paper context.
        current_findings: optional free-text findings to build on.

    Returns the model's text, or a short message when the key or research
    area is missing or the call fails.
    """
    if not GROQ_KEY:
        return "Error: Add GROQ_API_KEY"
    # Without this guard a missing area surfaced as a raw TypeError message.
    if not research_area or not research_area.strip():
        return "Enter a research area"
    try:
        client = Groq(api_key=GROQ_KEY)
        paper_context, _ = search_papers(research_area, n=3)
        system_msg = ("You are CardioLab AI research assistant for SJSU. Generate 3 testable hypotheses. "
                      "For each: H0 null, H1 alternative, rationale, suggested experiment, expected outcome.")
        user_msg = "Hypotheses for: " + research_area
        if current_findings and current_findings.strip():
            user_msg += "\n" + "Findings: " + current_findings
        if paper_context:
            user_msg += "\n" + "SJSU papers: " + paper_context[:500]
        resp = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "system", "content": system_msg},
                      {"role": "user", "content": user_msg}],
            max_tokens=1000)
        return resp.choices[0].message.content
    except Exception as e:
        return "Error: " + str(e)
| # ── ANALYSIS TOOLS ───────────────────────────────────────── | |
def analyze_upad_photo(image):
    """Estimate creatinine from the colour of the central zone of a uPAD photo.

    The mean R, G and B of the central 30% x 30% window are measured and the
    reading is computed as ``0.018 * (R - B) - 0.3`` (floored at 0), then
    mapped to a stage/action text.

    Args:
        image: a PIL image or an array accepted by ``Image.fromarray``.

    Returns:
        ``(annotated_image, report_text)``; the image has the sampled window
        outlined in green. ``(None, message)`` when no image was given or
        processing failed.
    """
    if image is None:
        return None, "Upload a uPAD photo first."
    try:
        from PIL import ImageDraw

        img = image if isinstance(image, Image.Image) else Image.fromarray(image)
        # Normalise to 3-channel RGB so grayscale/palette/alpha inputs work
        # with the channel indexing and the RGB outline colour below.
        img = img.convert("RGB")
        arr = np.array(img)
        h, w = arr.shape[:2]
        y1, y2, x1, x2 = int(h * 0.35), int(h * 0.65), int(w * 0.35), int(w * 0.65)
        zone = arr[y1:y2, x1:x2]
        R = float(np.mean(zone[:, :, 0]))
        G = float(np.mean(zone[:, :, 1]))
        B = float(np.mean(zone[:, :, 2]))
        c = max(0, round(0.018 * (R - B) - 0.3, 2))

        # (exclusive upper bound in mg/dL, stage, action), checked in order.
        stages = (
            (1.2, "Normal", "Monitor annually."),
            (1.5, "Borderline", "Repeat in 3 months."),
            (3.0, "Stage 2 CKD", "Consult nephrologist."),
            (6.0, "Stage 3-4 CKD", "Immediate consultation."),
        )
        s, a = "Stage 5 CKD", "Emergency care."
        for limit, stage, action in stages:
            if c < limit:
                s, a = stage, action
                break

        ri = img.copy()
        ImageDraw.Draw(ri).rectangle([x1, y1, x2, y2], outline=(0, 255, 0), width=3)
        report = ("R:" + str(round(R, 1)) + " G:" + str(round(G, 1)) + " B:" + str(round(B, 1)) + "\n"
                  + "Creatinine: " + str(c) + " mg/dL" + "\n"
                  + "Stage: " + s + "\n"
                  + "Action: " + a)
        return ri, report
    except Exception as e:
        return None, "Error: " + str(e)
def mk_chart(fn, title, bg, fg, gc, ac, pb):
    """Render one themed chart to a PIL image.

    Args:
        fn: callback that draws onto the Axes it is given.
        title: default chart title, used only when *fn* did not set one.
        bg, fg, gc, ac, pb: figure background, text, grid/spine, tick and
            plot-area colours.

    Returns:
        The rendered chart as a PIL image, detached from the PNG buffer.

    The figure is always closed, also when *fn* raises, so failed renders do
    not accumulate open figures.
    """
    fig2, ax = plt.subplots(figsize=(8, 5))
    try:
        fig2.patch.set_facecolor(bg)
        ax.set_facecolor(pb)
        fn(ax)
        # Keep a title the callback chose (e.g. one carrying a computed
        # status) instead of overwriting it with the default.
        if not ax.get_title():
            ax.set_title(title, color=fg, fontweight="bold", fontsize=13, pad=8)
        ax.tick_params(colors=ac, labelsize=10)
        ax.grid(True, alpha=0.3, color=gc, linestyle="--")
        for sp in ("top", "right"):
            ax.spines[sp].set_visible(False)
        for sp in ("bottom", "left"):
            ax.spines[sp].set_color(gc)
        fig2.tight_layout()
        buf = io.BytesIO()
        fig2.savefig(buf, format="png", facecolor=bg, bbox_inches="tight", dpi=130)
        buf.seek(0)
        # copy() forces the pixel data to load so the image outlives the buffer.
        return Image.open(buf).copy()
    finally:
        plt.close(fig2)
def analyze_piv_csv(file, theme="White"):
    """Plot and summarise a PIV measurement CSV.

    Column names are lower-cased and stripped. The velocity and shear columns
    are found by keyword; when no name matches, the first (respectively next
    distinct) numeric column other than the time/frame column is used.

    Args:
        file: uploaded file - an object exposing ``.name`` or a path string.
        theme: "White" for the light palette, anything else for the dark one.

    Returns:
        ``(velocity_img, shear_img, scatter_img, summary_img, text)``; the
        images are None when nothing was uploaded, no numeric column exists,
        or an error occurred.
    """
    if file is None:
        return None, None, None, None, "Upload PIV CSV first."
    try:
        df = pd.read_csv(getattr(file, "name", file))
        cols = [c.lower().strip() for c in df.columns]
        df.columns = cols
        num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        if not num_cols:
            return None, None, None, None, "No numeric columns."

        light = theme == "White"
        bg = "#fff" if light else "#0a1628"
        fg = "#1a202c" if light else "white"
        gc = "#e2e8f0" if light else "#2d4a8a"
        ac = "#4a5568" if light else "#a8b2d8"
        pb = "#f7fafc" if light else "#132340"

        x = np.arange(len(df))
        tc = next((c for c in cols if "time" in c or "frame" in c), None)
        xv = df[tc].values if tc else x
        # Never fall back to the time axis itself as a measurement column,
        # and never use one column for both velocity and shear.
        candidates = [c for c in num_cols if c != tc]
        vc = next((c for c in cols if any(k in c for k in ("vel", "speed", "v_mag"))),
                  candidates[0] if candidates else None)
        sc2 = next((c for c in cols if any(k in c for k in ("shear", "stress", "tau", "wss"))),
                   next((c for c in candidates if c != vc), None))

        def pv(ax):
            if vc:
                ax.plot(xv, df[vc], color="#c1121f", linewidth=2.5, marker="o", markersize=5)
                ax.fill_between(xv, df[vc], alpha=0.15, color="#c1121f")
                ax.axhline(y=2.0, color="#f59e0b", linestyle="--", linewidth=2, label="Risk 2.0 m/s")
                ax.set_ylabel("Velocity (m/s)", color=ac)
                ax.legend(fontsize=9, labelcolor=fg, facecolor=pb)

        def ps(ax):
            if sc2:
                ax.plot(xv, df[sc2], color="#0057a8", linewidth=2.5, marker="s", markersize=5)
                ax.fill_between(xv, df[sc2], alpha=0.15, color="#0057a8")
                ax.axhline(y=5, color="#f59e0b", linestyle="--", linewidth=2, label="Caution 5 Pa")
                ax.axhline(y=10, color="#c1121f", linestyle="--", linewidth=2, label="Risk 10 Pa")
                ax.set_ylabel("Shear (Pa)", color=ac)
                ax.legend(fontsize=9, labelcolor=fg, facecolor=pb)

        def psc(ax):
            if vc and sc2:
                s3 = ax.scatter(df[vc], df[sc2], c=x, cmap="RdYlGn_r", s=90,
                                edgecolors=fg, linewidth=0.5, zorder=5)
                cb = plt.colorbar(s3, ax=ax, label="Time")
                cb.ax.yaxis.label.set_color(fg)
                cb.ax.tick_params(colors=ac)
                ax.axvline(x=2.0, color="#f59e0b", linestyle="--", linewidth=2)
                ax.axhline(y=10, color="#c1121f", linestyle="--", linewidth=2)
                ax.set_xlabel("Velocity (m/s)", color=ac)
                ax.set_ylabel("Shear (Pa)", color=ac)

        def psum(ax):
            ax.axis("off")
            risk = []
            st = "CLINICAL SUMMARY" + "\n" + "=" * 20 + "\n\n"
            for col in num_cols[:3]:
                mn = round(df[col].mean(), 3)
                mx = round(df[col].max(), 3)
                st += col[:14] + ":" + "\n" + " Mean: " + str(mn) + "\n" + " Max: " + str(mx) + "\n\n"
                if "vel" in col and mx > 2.0:
                    risk.append("HIGH VELOCITY")
                if "shear" in col and mx > 10:
                    risk.append("HIGH SHEAR")
            bc = "#c1121f" if risk else "#2ecc71"
            st += "=" * 20 + "\n" + ("OVERALL: HIGH RISK" if risk else "OVERALL: LOW RISK")
            ax.text(0.05, 0.97, st, transform=ax.transAxes, color=fg, fontsize=10, va="top",
                    fontfamily="monospace",
                    bbox=dict(boxstyle="round,pad=0.8", facecolor=pb, edgecolor=bc, linewidth=2.5))

        i1 = mk_chart(pv, "Velocity Profile", bg, fg, gc, ac, pb)
        i2 = mk_chart(ps, "Wall Shear Stress", bg, fg, gc, ac, pb)
        i3 = mk_chart(psc, "Velocity vs Shear", bg, fg, gc, ac, pb)
        i4 = mk_chart(psum, "Clinical Summary", bg, fg, gc, ac, pb)

        ai = ""
        if GROQ_KEY:
            try:
                client = Groq(api_key=GROQ_KEY)
                resp = client.chat.completions.create(
                    model="llama-3.3-70b-versatile",
                    messages=[{"role": "system", "content": "PIV expert SJSU CardioLab."},
                              {"role": "user", "content": "PIV from 27mm SJM Regent:" + "\n" + df.describe().to_string()[:500]}],
                    max_tokens=250)
                ai = "\n" + "AI: " + resp.choices[0].message.content
            except Exception as e:  # the commentary is optional; charts still return
                print("PIV AI commentary skipped: " + str(e))
        return i1, i2, i3, i4, "PIV: " + str(len(df)) + " rows" + ai
    except Exception as e:
        return None, None, None, None, "Error: " + str(e)
def analyze_tgt_csv(file, theme="White"):
    """Plot the four thrombogenicity markers (TAT, PF1.2, free Hgb, platelets) from a CSV.

    Column names are lower-cased and stripped. Each marker column is found by
    keyword; when no name matches, numeric columns other than the time column
    are used positionally (1st TAT, 2nd PF1.2, 3rd hemoglobin, 4th platelets).

    Args:
        file: uploaded file - an object exposing ``.name`` or a path string.
        theme: "White" for the light palette, anything else for the dark one.

    Returns:
        ``(tat_img, pf_img, hgb_img, platelet_img, text)``; the images are
        None when nothing was uploaded, no numeric column exists, or an
        error occurred.
    """
    if file is None:
        return None, None, None, None, "Upload TGT CSV first."
    try:
        df = pd.read_csv(getattr(file, "name", file))
        cols = [c.lower().strip() for c in df.columns]
        df.columns = cols
        num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        if not num_cols:
            return None, None, None, None, "No numeric columns."

        light = theme == "White"
        bg = "#fff" if light else "#0a1628"
        fg = "#1a202c" if light else "white"
        gc = "#e2e8f0" if light else "#2d4a8a"
        ac = "#4a5568" if light else "#a8b2d8"
        pb = "#f7fafc" if light else "#132340"

        tc = next((c for c in cols if "time" in c or "min" in c), None)
        # Positional fallbacks must not count the time axis as a marker.
        candidates = [c for c in num_cols if c != tc]

        def pick(keywords, position):
            fallback = candidates[position] if len(candidates) > position else None
            return next((c for c in cols if any(k in c for k in keywords)), fallback)

        tatc = pick(("tat",), 0)
        pfc = pick(("pf",), 1)
        hc = pick(("hemo",), 2)
        plc = pick(("platelet", "plt"), 3)

        def mk2(dc, color, yl, lim, ll, title, bar=False, low_is_bad=False):
            def fn(ax):
                if not (dc and dc in df.columns):
                    return
                xp = df[tc].values if tc else range(len(df))
                yp = df[dc].values
                if bar:
                    bs = ax.bar(range(len(yp)), yp, color=color, alpha=0.85, edgecolor=bg, width=0.6)
                    for b, v in zip(bs, yp):
                        ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.5, str(round(v, 1)),
                                ha="center", va="bottom", color=fg, fontsize=10, fontweight="bold")
                else:
                    ax.plot(xp, yp, color=color, linewidth=3, marker="o", markersize=8)
                    ax.fill_between(xp, yp, alpha=0.15, color=color)
                    for xi, yi in zip(xp, yp):
                        ax.annotate(str(round(yi, 1)), (xi, yi), textcoords="offset points",
                                    xytext=(0, 10), ha="center", color=fg, fontsize=10, fontweight="bold")
                ax.axhline(y=lim, color="#f59e0b", linestyle="--", linewidth=2.5, label=ll)
                ax.legend(fontsize=10, labelcolor=fg, facecolor=pb)
                ax.set_ylabel(yl, color=ac)
                if low_is_bad:
                    # For markers whose normal range is ABOVE the limit the
                    # worst value is the minimum, not the maximum.
                    worst = round(float(np.min(yp)), 2)
                    summary = "Min: " + str(worst) + " - " + ("LOW" if worst < lim else "NORMAL")
                else:
                    worst = round(float(np.max(yp)), 2)
                    summary = "Max: " + str(worst) + " - " + ("HIGH" if worst > lim else "NORMAL")
                ax.set_title(title + "\n" + summary, color=fg, fontweight="bold", fontsize=12)
            return mk_chart(fn, title, bg, fg, gc, ac, pb)

        i1 = mk2(tatc, "#c1121f", "TAT (ng/mL)", 8, "Normal: 8", "TAT")
        i2 = mk2(pfc, "#0057a8", "PF1.2", 2.0, "Normal: 2.0", "PF1.2")
        i3 = mk2(hc, "#2ecc71", "Free Hgb (mg/L)", 20, "Normal: 20", "Free Hemoglobin", bar=True)
        i4 = mk2(plc, "#e8a020", "Platelets", 150, "Normal>150", "Platelets", low_is_bad=True)

        ai = ""
        if GROQ_KEY:
            try:
                client = Groq(api_key=GROQ_KEY)
                resp = client.chat.completions.create(
                    model="llama-3.3-70b-versatile",
                    messages=[{"role": "system", "content": "Hematology expert. Thrombogenicity risk."},
                              {"role": "user", "content": "TGT:" + "\n" + df.describe().to_string()[:500]}],
                    max_tokens=250)
                ai = "\n" + "AI: " + resp.choices[0].message.content
            except Exception as e:  # the commentary is optional; charts still return
                print("TGT AI commentary skipped: " + str(e))
        return i1, i2, i3, i4, "TGT: " + str(len(df)) + " rows" + ai
    except Exception as e:
        return None, None, None, None, "Error: " + str(e)
def generate_image(prompt):
    """Generate a biomedical illustration for *prompt* via the HF inference router.

    When a Groq key is configured the prompt is first rewritten by the 70B
    model into a short description plus a detailed image prompt; if that
    fails the raw prompt is used. FLUX.1-schnell is tried first, then SDXL.

    Returns:
        ``(image_or_None, status_text, description)``.
    """
    if not prompt or not prompt.strip():
        return None, "Enter description.", ""
    if not HF_TOKEN:
        return None, "Add HF_TOKEN.", ""
    try:
        enhanced, desc = prompt, ""
        if GROQ_KEY:
            try:
                client = Groq(api_key=GROQ_KEY)
                resp = client.chat.completions.create(
                    model="llama-3.3-70b-versatile",
                    messages=[{"role": "system", "content": "Format: DESCRIPTION: [2 sentences] PROMPT: [detailed image prompt]"},
                              {"role": "user", "content": "Biomedical image: " + prompt}],
                    max_tokens=200)
                full = resp.choices[0].message.content or ""
                if "DESCRIPTION:" in full and "PROMPT:" in full:
                    desc = full.split("DESCRIPTION:")[1].split("PROMPT:")[0].strip()
                    # Never send an empty prompt to the image model.
                    enhanced = full.split("PROMPT:")[1].strip() or prompt
            except Exception as e:  # enhancement is optional
                print("Prompt enhancement skipped: " + str(e))

        headers = {"Authorization": "Bearer " + HF_TOKEN, "Content-Type": "application/json"}
        endpoints = (
            "https://router.huggingface.co/hf-inference/models/black-forest-labs/FLUX.1-schnell",
            "https://router.huggingface.co/hf-inference/models/stabilityai/stable-diffusion-xl-base-1.0",
        )
        for url in endpoints:
            try:
                r = requests.post(url, headers=headers,
                                  json={"inputs": enhanced, "parameters": {"num_inference_steps": 8}},
                                  timeout=60)
                if r.status_code == 200:
                    return Image.open(io.BytesIO(r.content)), "Generated!", desc
            except Exception as e:  # try the next model
                print("Image model failed (" + url + "): " + str(e))
                continue
        return None, "Models busy.", desc
    except Exception as e:
        return None, "Error: " + str(e), ""
def piv_manual(v, s, h):
    """Classify manually entered PIV readings.

    Args:
        v: Maximum velocity in m/s; values above 2.0 are labelled
            "HIGH-stenosis", otherwise "NORMAL".
        s: Wall shear in Pa; above 10 is "HIGH-thrombosis", above 5 is
            "ELEVATED", otherwise "NORMAL".
        h: Heart rate in bpm; echoed back without classification.

    Returns:
        A three-line summary string, or an "Error: ..." message when the
        velocity or shear value is missing or not numeric (e.g. a cleared
        number field, which arrives as None).
    """
    try:
        velocity, shear = float(v), float(s)
    except (TypeError, ValueError):
        return "Error: enter numeric values for velocity and shear."

    velocity_label = "HIGH-stenosis" if velocity > 2.0 else "NORMAL"
    if shear > 10:
        shear_label = "HIGH-thrombosis"
    elif shear > 5:
        shear_label = "ELEVATED"
    else:
        shear_label = "NORMAL"

    # The raw inputs (not the float conversions) are echoed so the text shows
    # exactly what the user typed.
    return (
        "Velocity: " + str(v) + " m/s - " + velocity_label + chr(10)
        + "Shear: " + str(s) + " Pa - " + shear_label + chr(10)
        + "HR: " + str(h) + " bpm"
    )
def tgt_manual(t, p, h, pl, tm):
    """Score manually entered thrombogenicity-test values.

    One risk point is counted for each of: TAT above 15, PF1.2 above 2.0,
    hemoglobin above 50, platelets below 150. Three or more points give
    "HIGH RISK", exactly two give "MODERATE", fewer give "LOW RISK".

    NOTE(review): the TAT (15) and hemoglobin (50) cut-offs here differ from
    the 8 and 20 limits drawn on the TGT CSV charts — confirm this is
    intentional.

    Args:
        t: TAT value.
        p: PF1.2 value.
        h: Free hemoglobin value.
        pl: Platelet count.
        tm: Time in minutes; accepted for the UI wiring but not used in the
            score or the output.

    Returns:
        A three-line summary string ending in the risk label, or an
        "Error: ..." message when any scored value is missing or not numeric
        (e.g. a cleared number field, which arrives as None).
    """
    try:
        tat, pf12, hemo, platelets = float(t), float(p), float(h), float(pl)
    except (TypeError, ValueError):
        return "Error: enter numeric values for TAT, PF1.2, hemoglobin and platelets."

    # Booleans sum as 0/1, giving the number of out-of-range markers.
    risk = sum((tat > 15, pf12 > 2.0, hemo > 50, platelets < 150))
    if risk >= 3:
        label = "HIGH RISK"
    elif risk >= 2:
        label = "MODERATE"
    else:
        label = "LOW RISK"

    return (
        "TAT:" + str(t) + " PF1.2:" + str(p) + chr(10)
        + "Hemo:" + str(h) + " Plt:" + str(pl) + chr(10)
        + label
    )
| # ── UI ───────────────────────────────────────────────────── | |
# Builds the Gradio interface: a header, a status banner, one tab per tool and a
# closing HTML strip. Each tab creates its widgets and wires them to a handler
# function defined elsewhere in this file.
# NOTE(review): indentation was lost in this copy of the file, so the exact nesting
# of widgets inside Row/Column containers cannot be confirmed from here.
| with gr.Blocks(title="CardioLab AI - SJSU", css=CSS) as demo: | |
| gr.HTML(HEADER) | |
# Static banner: the chunk and paper counts below are literal text, not computed.
| gr.HTML("""<div style="background:#f0fdf4;border:1px solid #bbf7d0;border-radius:8px;padding:8px 16px;margin:6px 0;text-align:center;"> | |
| <span style="color:#166534;font-size:0.8em;font-weight:500;">RAG Active: 417 chunks from 16 SJSU papers · Fine-tuned Model · Select model using radio buttons in Chat tab</span></div>""") | |
| with gr.Tabs(): | |
# Chat tab: session controls (new / load / save / delete), model selector,
# chatbot and message box.
| with gr.Tab("Chat"): | |
| with gr.Row(): | |
| with gr.Column(scale=1, min_width=200): | |
| gr.HTML("""<div style="background:#fef2f2;border:1px solid #fecaca;border-radius:10px;padding:12px;margin-bottom:8px;"> | |
| <div style="display:flex;align-items:center;gap:6px;margin-bottom:3px;"> | |
| <svg width="12" height="11" viewBox="0 0 100 90"><path d="M50 85 C50 85 5 55 5 30 C5 15 18 5 30 5 C38 5 45 9 50 15 C55 9 62 5 70 5 C82 5 95 15 95 30 C95 55 50 85 50 85Z" fill="#c1121f"/></svg> | |
| <span style="color:#c1121f;font-weight:700;font-size:0.82em;">CardioLab</span></div> | |
| <div style="color:#9ca3af;font-size:0.7em;">Conversations</div></div>""") | |
| new_chat_btn = gr.Button("+ New Chat", variant="secondary") | |
# The dropdown choices are read once, when the UI is built; save/delete handlers
# list the dropdown among their outputs.
| session_dropdown = gr.Dropdown(choices=get_session_list(), label="Saved Sessions", interactive=True) | |
| load_btn = gr.Button("Load Session", variant="primary") | |
| session_name_box = gr.Textbox(placeholder="Name this session...", label="Session Name", lines=1) | |
| with gr.Row(): | |
| save_btn = gr.Button("Save", variant="primary", scale=2) | |
| delete_btn = gr.Button("Del", variant="secondary", scale=1) | |
| session_status = gr.Textbox(label="", lines=1, interactive=False, container=False) | |
| with gr.Column(scale=4): | |
| chat_model_radio = gr.Radio( | |
| choices=list(CHAT_MODELS.keys()), | |
| value="Llama 3.3 70B (Best)", | |
| label="Select AI Model", | |
| container=True | |
| ) | |
| chatbot = gr.Chatbot(label="", height=400, show_label=False, container=False) | |
| with gr.Row(): | |
| msg_box = gr.Textbox(placeholder="Ask anything — AI searches 16 SJSU papers + PubMed...", label="", lines=2, scale=5, container=False) | |
| with gr.Column(scale=1, min_width=80): | |
| send_btn = gr.Button("Send", variant="primary") | |
| clear_btn = gr.Button("Clear", variant="secondary") | |
# The Send button and submitting the message box both call research_chat with the
# same inputs and outputs.
| send_btn.click(research_chat, inputs=[msg_box, chatbot, chat_model_radio], outputs=[msg_box, chatbot]) | |
| msg_box.submit(research_chat, inputs=[msg_box, chatbot, chat_model_radio], outputs=[msg_box, chatbot]) | |
# Clear resets the chatbot to an empty list and the message box to an empty string.
| clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg_box]) | |
| new_chat_btn.click(new_chat, outputs=[chatbot, msg_box, session_status]) | |
| save_btn.click(save_session, inputs=[chatbot, session_name_box], outputs=[session_status, session_dropdown]) | |
| load_btn.click(load_session, inputs=session_dropdown, outputs=[chatbot, session_status]) | |
| delete_btn.click(delete_session, inputs=session_dropdown, outputs=[session_status, session_dropdown]) | |
# Voice tab: a microphone recording (handed to voice_chat as a file path) with its
# own chatbot, separate from the Chat tab's.
| with gr.Tab("Voice"): | |
| voice_chatbot = gr.Chatbot(label="", height=360, show_label=False) | |
| audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Question") | |
| with gr.Row(): | |
| voice_btn = gr.Button("Ask by Voice", variant="primary") | |
| voice_clear = gr.Button("Clear", variant="secondary") | |
| voice_btn.click(voice_chat, inputs=[audio_input, voice_chatbot], outputs=voice_chatbot) | |
| voice_clear.click(lambda: [], outputs=voice_chatbot) | |
# Papers tab: a search box wired to quick_search on both button click and submit.
| with gr.Tab("Papers"): | |
| gr.Markdown("### Search PubMed + Semantic Scholar + SJSU ScholarWorks") | |
| with gr.Row(): | |
| search_input = gr.Textbox(placeholder="e.g. bileaflet mechanical heart valve thrombogenicity hemodynamics", label="Research Topic", scale=4) | |
| search_btn = gr.Button("Search", variant="primary", scale=1) | |
| search_output = gr.Textbox(label="Results", lines=22) | |
| search_btn.click(quick_search, inputs=search_input, outputs=search_output) | |
| search_input.submit(quick_search, inputs=search_input, outputs=search_output) | |
# PIV CSV tab: CSV upload + theme choice go to analyze_piv_csv, whose outputs fill
# four chart images and one text box.
| with gr.Tab("PIV CSV"): | |
| with gr.Row(): | |
| piv_file = gr.File(label="Upload PIV CSV", file_types=[".csv"], scale=3) | |
| piv_theme = gr.Radio(["White","Dark"], value="White", label="Theme", scale=1) | |
| piv_btn = gr.Button("Analyze PIV Data", variant="primary") | |
| piv_result = gr.Textbox(label="AI Analysis", lines=4) | |
| with gr.Row(): | |
| piv_c1 = gr.Image(label="Velocity Profile", type="pil") | |
| piv_c2 = gr.Image(label="Shear Stress", type="pil") | |
| with gr.Row(): | |
| piv_c3 = gr.Image(label="Velocity vs Shear", type="pil") | |
| piv_c4 = gr.Image(label="Clinical Summary", type="pil") | |
| piv_btn.click(analyze_piv_csv, inputs=[piv_file, piv_theme], outputs=[piv_c1, piv_c2, piv_c3, piv_c4, piv_result]) | |
# TGT CSV tab: same shape as the PIV CSV tab, wired to analyze_tgt_csv.
| with gr.Tab("TGT CSV"): | |
| with gr.Row(): | |
| tgt_file = gr.File(label="Upload TGT CSV", file_types=[".csv"], scale=3) | |
| tgt_theme = gr.Radio(["White","Dark"], value="White", label="Theme", scale=1) | |
| tgt_btn = gr.Button("Analyze TGT Data", variant="primary") | |
| tgt_result = gr.Textbox(label="AI Assessment", lines=4) | |
| with gr.Row(): | |
| tgt_c1 = gr.Image(label="TAT", type="pil"); tgt_c2 = gr.Image(label="PF1.2", type="pil") | |
| with gr.Row(): | |
| tgt_c3 = gr.Image(label="Hemoglobin", type="pil"); tgt_c4 = gr.Image(label="Platelets", type="pil") | |
| tgt_btn.click(analyze_tgt_csv, inputs=[tgt_file, tgt_theme], outputs=[tgt_c1, tgt_c2, tgt_c3, tgt_c4, tgt_result]) | |
# uPAD tab: photo analysis via analyze_upad_photo (the photo is passed as a numpy
# array), followed by a manual RGB estimate.
| with gr.Tab("uPAD"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| photo_input = gr.Image(label="Upload uPAD Photo", type="numpy", height=260) | |
| analyze_btn = gr.Button("Analyze uPAD Photo", variant="primary") | |
| with gr.Column(): | |
| photo_img = gr.Image(label="Detection Zone", type="pil", height=260) | |
| photo_text = gr.Textbox(label="CKD Result", lines=8) | |
| analyze_btn.click(analyze_upad_photo, inputs=photo_input, outputs=[photo_img, photo_text]) | |
| gr.Markdown("**Manual RGB:**") | |
| with gr.Row(): | |
| r = gr.Number(label="R", value=210); g = gr.Number(label="G", value=140); b = gr.Number(label="B", value=80) | |
| out3 = gr.Textbox(label="Result", lines=3) | |
# Manual estimate: creatinine = max(0, round(0.02*(R-B)-0.5, 2)) mg/dL, labelled
# "Normal" below 1.2, "Borderline" below 1.5, otherwise "CKD". G is passed to the
# lambda but is not used by the formula; the expression is evaluated three times.
| gr.Button("Analyze RGB", variant="secondary").click( | |
| lambda r, g, b: "Creatinine: " + str(max(0,round(0.02*(r-b)-0.5,2))) + " mg/dL" + chr(10) + | |
| ("Normal" if max(0,round(0.02*(r-b)-0.5,2)) < 1.2 else "Borderline" if max(0,round(0.02*(r-b)-0.5,2)) < 1.5 else "CKD"), | |
| inputs=[r, g, b], outputs=out3) | |
# AI Image tab: the text prompt goes to generate_image, which returns
# (image, status, description) in that order.
| with gr.Tab("AI Image"): | |
| with gr.Row(): | |
| img_prompt = gr.Textbox(placeholder="e.g. 27mm bileaflet mechanical heart valve cross section", label="Describe image", lines=2, scale=4) | |
| with gr.Column(scale=1): | |
| img_btn = gr.Button("Generate", variant="primary") | |
| img_status = gr.Textbox(label="Status", lines=1) | |
| img_desc = gr.Textbox(label="AI Description", lines=2, interactive=False) | |
| img_output = gr.Image(label="Generated Image", type="pil", height=400) | |
| img_btn.click(generate_image, inputs=img_prompt, outputs=[img_output, img_status, img_desc]) | |
# PIV Manual tab: three numeric inputs passed to piv_manual as (v, s, h).
| with gr.Tab("PIV Manual"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| v = gr.Number(label="Max Velocity m/s", value=1.8) | |
| s = gr.Number(label="Wall Shear Pa", value=6.5) | |
| h = gr.Number(label="Heart Rate bpm", value=72) | |
| piv_out = gr.Textbox(label="Result", lines=4) | |
| gr.Button("Analyze PIV", variant="primary").click(piv_manual, inputs=[v, s, h], outputs=piv_out) | |
# TGT Manual tab: five numeric inputs passed to tgt_manual in the order
# TAT, PF1.2, hemoglobin, platelets, time.
| with gr.Tab("TGT Manual"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| t1 = gr.Number(label="TAT ng/mL", value=18); t2 = gr.Number(label="PF1.2", value=2.5) | |
| t3 = gr.Number(label="Hemoglobin mg/L", value=60); t4 = gr.Number(label="Platelets", value=140) | |
| t5 = gr.Number(label="Time min", value=40); out2 = gr.Textbox(label="Result", lines=6) | |
| gr.Button("Analyze TGT", variant="primary").click(tgt_manual, inputs=[t1, t2, t3, t4, t5], outputs=out2) | |
# Protocol Generator tab: experiment type + free-text parameters -> generate_protocol.
| with gr.Tab("Protocol Generator"): | |
| gr.Markdown("### Generate complete lab protocols from SJSU CardioLab knowledge") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| proto_type = gr.Dropdown( | |
| choices=["MCL Setup", "PIV Experiment", "Thrombogenicity Tester Blood Clotting Test", | |
| "uPAD Fabrication", "uPAD Creatinine Test", "FSI COMSOL Simulation", "Valve Testing"], | |
| value="Thrombogenicity Tester Blood Clotting Test", label="Experiment Type") | |
| proto_params = gr.Textbox(placeholder="e.g. 27mm SJM valve 70bpm porcine blood", label="Specific Parameters", lines=2) | |
| proto_btn = gr.Button("Generate Protocol", variant="primary") | |
| with gr.Column(scale=2): | |
| proto_output = gr.Textbox(label="Generated Protocol", lines=28) | |
| proto_btn.click(generate_protocol, inputs=[proto_type, proto_params], outputs=proto_output) | |
# Report Writer tab: note the handler receives (description, study type, results),
# which is not the order the widgets are created in.
| with gr.Tab("Report Writer"): | |
| gr.Markdown("### Generate professional research reports") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| report_exp = gr.Dropdown( | |
| choices=["MCL PIV Flow Analysis", "TGT Thrombogenicity Study", "uPAD CKD Detection", | |
| "FSI Simulation Study", "Heart Valve Comparison"], | |
| value="TGT Thrombogenicity Study", label="Study Type") | |
| report_desc = gr.Textbox(placeholder="e.g. TGT with 27mm SJM bileaflet at 70bpm 150mL porcine blood", label="Experiment Description", lines=3) | |
| report_results = gr.Textbox(placeholder="e.g. TAT=12.3 PF1.2=2.8 Hemo=45 Plt=142", label="Your Results", lines=2) | |
| report_btn = gr.Button("Generate Report", variant="primary") | |
| with gr.Column(scale=2): | |
| report_output = gr.Textbox(label="Generated Report", lines=28) | |
| report_btn.click(generate_report, inputs=[report_desc, report_exp, report_results], outputs=report_output) | |
# Hypothesis Generator tab: research area + current findings -> generate_hypothesis.
| with gr.Tab("Hypothesis Generator"): | |
| gr.Markdown("### Generate testable research hypotheses") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| hyp_area = gr.Dropdown( | |
| choices=["Bileaflet MHV Thrombogenicity", "uPAD CKD Detection Accuracy", | |
| "PIV Flow Characterization", "FSI Simulation Validation", "Valve Design Comparison"], | |
| value="Bileaflet MHV Thrombogenicity", label="Research Area") | |
| hyp_findings = gr.Textbox(placeholder="Current observations from your experiments", label="Current Findings", lines=3) | |
| hyp_btn = gr.Button("Generate Hypotheses", variant="primary") | |
| with gr.Column(scale=2): | |
| hyp_output = gr.Textbox(label="Research Hypotheses", lines=25) | |
| hyp_btn.click(generate_hypothesis, inputs=[hyp_area, hyp_findings], outputs=hyp_output) | |
# Closing HTML strip with version, attribution and licence text (all literal).
| gr.HTML("""<div style="text-align:center;padding:12px;border-top:1px solid #e2e8f0;background:#f8fafc;margin-top:8px;"> | |
| <span style="color:#9ca3af;font-size:0.72em;">CardioLab AI v40 · SJSU Biomedical Engineering · | |
| Inspired by <a href="https://github.com/snap-stanford/Biomni" style="color:#c1121f;text-decoration:none;">Biomni Stanford</a> | |
| · Apache 2.0 · $0 Cost</span></div>""") | |
# Start the app with Gradio's default launch options.
| demo.launch() | |