| | import streamlit as st
|
| | import torch
|
| | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| | from peft import PeftModel
|
| | import os
|
| |
|
| |
|
# Configure the browser tab and layout; must be the first st.* call in the script.
st.set_page_config(
    page_title="TinyLlama Medical Assistant",
    page_icon="🩺",  # repaired: source had mojibake (UTF-8 bytes of U+1FA7A shown as Thai cp874)
    layout="wide"
)
|
| |
|
| |
|
# Inject static CSS for the chat bubbles and buttons.
# unsafe_allow_html is required to pass a raw <style> tag through st.markdown;
# the content is a fixed literal (no user input), so this is safe here.
st.markdown("""
<style>
.main {background-color: #f0f2f6;}
.stButton>button {
width: 100%;
background-color: #4F46E5;
color: white;
}
.chat-message {
padding: 1.5rem;
border-radius: 0.5rem;
margin-bottom: 1rem;
display: flex;
flex-direction: column;
}
.chat-message.user {
background-color: #4F46E5;
color: white;
}
.chat-message.assistant {
background-color: white;
border: 1px solid #e5e7eb;
}
</style>
""", unsafe_allow_html=True)
|
| |
|
| |
|
# NOTE(security): plaintext demo credentials hard-coded in source. Fine for a
# demo, but before any real deployment move these to st.secrets / an auth
# provider and compare salted hashes instead of raw strings.
USERS = {
    "admin": "admin123",
    "doctor": "doc123",
    "student": "student123"
}

# Appended verbatim to every assistant reply (see chat handler below).
MEDICAL_DISCLAIMER = """
⚠️ **Medical Disclaimer:** This response is for educational purposes only and is not a substitute for professional medical advice. Always consult a qualified healthcare provider.
"""
|
| |
|
| |
|
# Seed per-session defaults exactly once; later reruns leave existing values alone.
_SESSION_DEFAULTS = {
    "authenticated": False,   # login gate flag
    "messages": [],           # chat transcript: list of {"role", "content"}
    "model_loaded": False,    # set True once tokenizer/model are cached in session
}
for _name, _default in _SESSION_DEFAULTS.items():
    if _name not in st.session_state:
        st.session_state[_name] = _default
|
| |
|
| |
|
# --- Login gate: render the login form and halt the script until authenticated.
if not st.session_state.authenticated:
    col1, col2, col3 = st.columns([1, 2, 1])  # center the form in the middle column

    with col2:
        st.title("🔐 Medical Assistant Login")  # repaired mojibake emoji
        st.markdown("---")

        username = st.text_input("Username", key="login_username")
        password = st.text_input("Password", type="password", key="login_password")

        col_a, col_b = st.columns(2)

        with col_a:
            if st.button("Login", use_container_width=True):
                # NOTE(security): plaintext comparison against hard-coded demo
                # credentials; replace with hashed lookup for real use.
                if username in USERS and USERS[username] == password:
                    st.session_state.authenticated = True
                    st.session_state.username = username
                    # repaired: literal was split in two by a mis-decoded ✅ byte (0x85 → NEL)
                    st.success("✅ Login successful!")
                    st.rerun()
                else:
                    st.error("❌ Invalid credentials")

        with col_b:
            if st.button("Clear", use_container_width=True):
                st.rerun()

        st.markdown("---")
        st.info("""
**Demo Credentials:**
- admin / admin123
- doctor / doc123
- student / student123
""")

    # Prevent the rest of the app from rendering for unauthenticated visitors.
    st.stop()
|
| |
|
| |
|
@st.cache_resource(show_spinner=False)
def load_model():
    """Load TinyLlama in 4-bit and attach the fine-tuned LoRA adapters.

    Falls back to the plain base model when the adapter directory is missing.
    The result is cached process-wide by ``st.cache_resource`` so the model is
    loaded only once across reruns and sessions.

    Returns:
        tuple: ``(tokenizer, model)`` on success, ``(None, None)`` on failure.
    """
    try:
        base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        lora_path = "./tinyllama-medical-lora"

        if not os.path.exists(lora_path):
            st.error(f"❌ Model not found at {lora_path}")  # repaired mojibake emoji
            st.info("Using base model without fine-tuning...")
            lora_path = None  # signal "no adapters" to the branch below

        # 4-bit NF4 quantization keeps the 1.1B model within modest GPU memory.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True
        )

        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # TinyLlama ships without a pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )

        if lora_path:
            model = PeftModel.from_pretrained(model, lora_path)
            # repaired: literal was split in two by a mis-decoded ✅ byte (0x85 → NEL)
            st.success("✅ Fine-tuned model loaded successfully!")
        else:
            st.warning("⚠️ Using base model (not fine-tuned)")

        model.eval()  # inference mode: disable dropout, etc.

        return tokenizer, model

    except Exception as e:
        # Broad catch is deliberate: surface any load failure in the UI
        # instead of crashing the whole app; callers check for None.
        st.error(f"Error loading model: {str(e)}")
        return None, None
|
| |
|
| |
|
# Main page header (only reached after a successful login).
st.title("🩺 TinyLlama Medical Assistant")  # repaired mojibake emoji
st.caption(f"Logged in as: **{st.session_state.username}**")
|
| |
|
| |
|
# --- Sidebar: model loading, generation knobs, example prompts, session actions.
with st.sidebar:
    st.header("⚙️ Settings")  # repaired mojibake emoji

    # Lazy-load the model once per session; load_model() itself is also
    # process-cached via st.cache_resource.
    if not st.session_state.model_loaded:
        with st.spinner("Loading fine-tuned model..."):
            tokenizer, model = load_model()
            if tokenizer and model:
                st.session_state.tokenizer = tokenizer
                st.session_state.model = model
                st.session_state.model_loaded = True

    st.markdown("---")

    # Sliders: (label, min, max, default, step) — read by the chat handler below.
    st.subheader("Generation Parameters")
    temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
    max_tokens = st.slider("Max New Tokens", 32, 256, 100, 8)
    top_p = st.slider("Top-p", 0.1, 1.0, 0.9, 0.05)

    st.markdown("---")

    st.subheader("💡 Example Queries")  # repaired mojibake emoji
    example_queries = [
        "What is Paracetamol used for?",
        "Tell me about Ibuprofen",
        "What is Metformin?",
        "Uses of Amoxicillin",
        "What is Atorvastatin for?"
    ]

    # One button per example; clicking queues it as a user message, and the
    # rerun lets the main chat loop render and answer it.
    for query in example_queries:
        if st.button(query, key=f"example_{query}", use_container_width=True):
            st.session_state.messages.append({"role": "user", "content": query})
            st.rerun()

    st.markdown("---")

    if st.button("🗑️ Clear Chat", use_container_width=True):  # repaired mojibake emoji
        st.session_state.messages = []
        st.rerun()

    if st.button("🚪 Logout", use_container_width=True):  # repaired mojibake emoji
        st.session_state.authenticated = False
        st.session_state.messages = []
        st.rerun()
|
| |
|
| |
|
# Re-render the whole transcript on every rerun so the chat persists on screen.
for past_msg in st.session_state.messages:
    with st.chat_message(past_msg["role"]):
        st.markdown(past_msg["content"])
|
| |
|
| |
|
# --- Chat handler: take a user question, run generation, render + store reply.
if prompt := st.chat_input("Ask a medical question..."):

    # Record and echo the user's message immediately.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            if st.session_state.model_loaded:
                try:
                    # Alpaca-style instruction template — must match the format
                    # used during fine-tuning, including the blank line.
                    formatted_prompt = f"""### Instruction:
{prompt}

### Response:
"""

                    # Move token tensors to the same device as the model
                    # (device_map="auto" may have sharded/placed it on GPU).
                    inputs = st.session_state.tokenizer(
                        formatted_prompt,
                        return_tensors="pt"
                    ).to(st.session_state.model.device)

                    # no_grad: inference only, skip autograd bookkeeping.
                    # Sampling params come from the sidebar sliders.
                    with torch.no_grad():
                        outputs = st.session_state.model.generate(
                            **inputs,
                            max_new_tokens=max_tokens,
                            temperature=temperature,
                            top_p=top_p,
                            do_sample=True,
                            pad_token_id=st.session_state.tokenizer.eos_token_id
                        )

                    # Decode the full sequence (prompt + completion).
                    response = st.session_state.tokenizer.decode(
                        outputs[0],
                        skip_special_tokens=True
                    )

                    # Keep only the text after the final response marker,
                    # stripping the echoed prompt from the decoded output.
                    if "### Response:" in response:
                        response = response.split("### Response:")[-1].strip()

                    # Always attach the medical disclaimer to model output.
                    full_response = f"{response}\n\n{MEDICAL_DISCLAIMER}"

                    st.markdown(full_response)
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": full_response
                    })

                except Exception as e:
                    # Surface generation failures in-line and keep them in the
                    # transcript so the conversation state stays consistent.
                    error_msg = f"Error generating response: {str(e)}"
                    st.error(error_msg)
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": error_msg
                    })
            else:
                # Model never finished loading (see sidebar) — tell the user.
                error_msg = "Model not loaded. Please refresh the page."
                st.error(error_msg)
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": error_msg
                })
|
| |
|
| |
|
# Page footer.
st.markdown("---")
st.caption("Fine-tuned TinyLlama 1.1B with LoRA on Allopathic Medicine Dataset")