import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
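
# Dependencies: streamlit, torch, transformers
# (e.g. pip install streamlit torch transformers)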

st.set_page_config(page_title="ChatDoctor", page_icon="🩺")
st.title("🩺 ChatDoctor - Medical Assistant")


# Cache the model and tokenizer so they load once per server process,
# not on every Streamlit rerun.
@st.cache_resource
def load_model():
    model = AutoModelForCausalLM.from_pretrained("abhiyanta/chatDoctor", use_cache=True)
    tokenizer = AutoTokenizer.from_pretrained("abhiyanta/chatDoctor")
    return model, tokenizer


model, tokenizer = load_model()
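
# Optional: move the model to a GPU when one is available (assumes the
# checkpoint fits in VRAM); inputs are sent to model.device below either way.
# model = model.to("cuda" if torch.cuda.is_available() else "cpu")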

# Alpaca-style prompt template: instruction, optional input, and the
# slot the model completes.
alpaca_prompt = "### Instruction:\n{0}\n\n### Input:\n{1}\n\n### Output:\n{2}"
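
# Example: alpaca_prompt.format("What causes migraines?", "", "") produces
# "### Instruction:\nWhat causes migraines?\n\n### Input:\n\n\n### Output:\n"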

user_input = st.text_input("Ask your medical question:")

if st.button("Ask ChatDoctor"):
    if user_input:
        # Only the instruction slot is filled; input and output stay empty.
        formatted_prompt = alpaca_prompt.format(user_input, "", "")

        # Tokenize and match the model's device (CPU unless moved above).
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)

        st.write("**ChatDoctor:**")

        # TextStreamer prints tokens to the server console as they are
        # generated; the decoded reply below is what the browser shows.
        text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with st.spinner("Generating response..."):
            with torch.inference_mode():
                generated_ids = model.generate(**inputs, streamer=text_streamer, max_new_tokens=1000)

        # Drop the prompt tokens so only the model's answer is displayed.
        response = tokenizer.decode(
            generated_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        st.write(response)
    else:
        st.warning("Please enter a question to ask ChatDoctor.")
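
# Optional: stream tokens into the browser instead of the server console.
# A minimal sketch, assuming the same model/tokenizer as above; it would
# replace the TextStreamer block inside the button handler.
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=1000))
# thread.start()
# placeholder, partial = st.empty(), ""
# for new_text in streamer:
#     partial += new_text
#     placeholder.write(partial)
# thread.join()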

st.markdown("---")
st.caption("Powered by Hugging Face 🤗")