# Spaces: Sleeping
import logging

import streamlit as st
import torch
from langchain import LLMChain
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Module-level logger; MedicalChatbot reports initialization and errors through it.
logger = logging.getLogger(__name__)
# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
class MedicalChatbot:
    """Medical question-answering chatbot backed by a local LaMini-Flan-T5 model.

    A transformers text2text pipeline generates the answers. The pipeline is
    wrapped in LangChain's ``HuggingFacePipeline`` so that it can be used as
    the ``llm`` of an ``LLMChain``, and a windowed conversation memory feeds
    recent exchanges back into each prompt.
    """

    # Prompt text sent to the model; ``{chat_history}`` is filled from memory
    # and ``{question}`` from the caller's input.
    _PROMPT_TEMPLATE = (
        "\n"
        "Use your knowledge and the conversation history to answer the question.\n"
        "If you're unsure, say so and suggest consulting a healthcare professional.\n"
        "Chat History: {chat_history}\n"
        "Question: {question}\n"
        "Answer:"
    )

    # Returned without calling the model when the question is blank.
    _EMPTY_INPUT_MESSAGE = "Please enter a question so I can help."

    def __init__(self) -> None:
        """Initialize the Medical Chatbot without a PDF document."""
        logger.info("Initializing Medical Chatbot with pre-trained model knowledge")
        # Load tokenizer/model and build the generation pipeline.
        self.setup_model()
        # Windowed conversation memory (k=3). History is returned as a plain
        # string because it is interpolated into a text PromptTemplate; with
        # return_messages=True the prompt would contain the repr of a list of
        # message objects instead of readable dialogue.
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",
            return_messages=False,
            k=3,
        )

    def setup_model(self) -> None:
        """Load the LaMini model and build the text-generation pipeline.

        Sets ``self.tokenizer``, ``self.model``, ``self.pipe`` (the raw
        transformers pipeline) and ``self.llm`` (the LangChain wrapper around
        ``self.pipe`` that ``generate_response`` hands to ``LLMChain``).

        Raises:
            Exception: Any error raised while loading the model or building
                the pipeline is logged and re-raised unchanged.
        """
        try:
            model_id = "MBZUAI/LaMini-Flan-T5-783M"
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
            self.pipe = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_length=512,
                do_sample=True,
                temperature=0.3,
                top_p=0.95,
                # First CUDA device when available, otherwise CPU (-1).
                device=0 if torch.cuda.is_available() else -1,
            )
            # LLMChain expects a LangChain language model, not a bare
            # transformers pipeline, so wrap the pipeline before use.
            self.llm = HuggingFacePipeline(pipeline=self.pipe)
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.exception("Model initialization failed: %s", e)
            raise

    def generate_response(self, user_input: str) -> str:
        """Generate a response to the user's question using model knowledge.

        Args:
            user_input: The question typed by the user.

        Returns:
            The model's answer followed by the disclaimer text. A blank
            question yields a short request to enter one (the model is not
            called), and any failure during generation yields a fixed apology
            message instead of raising.
        """
        # Skip the (slow) model call when there is nothing to answer.
        if not user_input or not user_input.strip():
            return self._EMPTY_INPUT_MESSAGE

        try:
            prompt = PromptTemplate(
                input_variables=["question", "chat_history"],
                template=self._PROMPT_TEMPLATE,
            )
            # The chain is built per call so it always uses the current
            # ``self.memory``; the memory supplies ``chat_history``.
            chain = LLMChain(
                llm=self.llm,
                prompt=prompt,
                memory=self.memory,
            )
            response = chain.run(question=user_input)
            return response + self.get_disclaimer()
        except Exception as e:
            # Best-effort boundary: the UI shows the apology, the log keeps
            # the traceback.
            logger.exception("Error generating response: %s", e)
            return "I apologize, but I encountered an error. Please try again."

    def get_disclaimer(self) -> str:
        """Return the disclaimer text appended to every generated answer."""
        return "\n\nDISCLAIMER: This information is for educational purposes only. Please consult healthcare professionals for medical advice."
def init_session_state():
    """Seed Streamlit session state with the keys the app reads.

    ``chatbot`` starts as ``None`` and ``messages`` as an empty list; keys
    that are already present are left untouched.
    """
    initial_values = {"chatbot": None, "messages": []}
    for key, value in initial_values.items():
        if key not in st.session_state:
            st.session_state[key] = value
def main():
    """Render the Streamlit chat UI and drive the chatbot.

    Configures the page, seeds session state, draws a sidebar with a
    "New Chat" button, creates the ``MedicalChatbot`` on first use, replays
    the stored conversation, and handles one new question per script run.
    """
    # Page configuration
    st.set_page_config(
        page_title="Medical Chatbot",
        page_icon="🏥",
        layout="wide",
    )

    # Initialize session state
    init_session_state()

    st.title("Medical Chatbot Assistant 🏥")

    # Sidebar
    with st.sidebar:
        st.header("Configuration")
        # New chat button: drop the displayed transcript and the chatbot's
        # conversation memory, then rerun the script to redraw the page.
        if st.button("New Chat"):
            st.session_state.messages = []
            if st.session_state.chatbot:
                st.session_state.chatbot.memory.clear()
            st.rerun()

    # Initialize chatbot if needed (kept in session state across reruns)
    if st.session_state.chatbot is None:
        with st.spinner("Initializing chatbot..."):
            try:
                st.session_state.chatbot = MedicalChatbot()
                st.success("Chatbot initialized successfully!")
            except Exception as e:
                # Show the failure in the page; the chat interface below is
                # skipped because ``chatbot`` is still None.
                st.error(f"Error initializing chatbot: {e}")

    # Chat interface
    if st.session_state.chatbot:
        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # Chat input
        if prompt := st.chat_input("Ask your medical question..."):
            # Add user message
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.write(prompt)

            # Generate response
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = st.session_state.chatbot.generate_response(prompt)
                st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
# Start the app only when this file is executed as a script, so importing the
# module (e.g. from a test) does not launch the UI or load the model.
if __name__ == "__main__":
    main()