import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Hugging Face repository details
MODEL_ID = "meta-llama/CodeLlama-7b-Instruct-hf"
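# Note: meta-llama repositories are gated on the Hugging Face Hub. This is an
# assumption about your environment, but you will likely need to accept the
# model license and expose an access token (e.g. an HF_TOKEN secret in a Space)
# before from_pretrained() can download the weights.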
def load_model():
    """Load the Hugging Face model and tokenizer."""
    try:
        st.write("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
        st.write("Model and tokenizer successfully loaded.")
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None

# Load the model and tokenizer once per session. Without caching, Streamlit
# re-runs this script (and would reload the 7B model) on every interaction.
@st.cache_resource
def get_model():
    return load_model()

tokenizer, model = get_model()
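# A minimal sketch of a CPU-safe fallback, assuming the Space may not always
# have a GPU: float16 weights generally need CUDA, so the dtype could be
# chosen from torch.cuda.is_available() before calling from_pretrained():
#
#   dtype = torch.float16 if torch.cuda.is_available() else torch.float32
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID, device_map="auto", torch_dtype=dtype
#   )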
# Streamlit UI
st.title("Medical Chatbot")
st.write("This chatbot provides medical assistance. Type your question below!")

if model is None or tokenizer is None:
    st.error("Model failed to load. Please check the Hugging Face model path or environment configuration.")
else:
    user_input = st.text_input("You:", placeholder="Enter your medical question here...", key="input_box")
    if st.button("Send"):
        if user_input.strip():
            # Construct the prompt
            SYSTEM_PROMPT = "You are a helpful medical assistant. Provide accurate and concise answers."
            full_prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nAssistant:"

            # Tokenize the input and move it to the same device as the model
            # (hard-coding "cuda" would crash on a CPU-only machine).
            inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True).to(model.device)

            try:
                # Generate the response
                outputs = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=200,   # Limit the length of the generated reply
                    do_sample=True,       # Required for temperature/top_p to take effect
                    temperature=0.7,      # Control randomness
                    top_p=0.9,            # Nucleus (top-p) sampling
                    pad_token_id=tokenizer.eos_token_id,
                )

                # Decode and display only the text after the final "Assistant:" marker
                response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("Assistant:")[-1].strip()
                st.write(f"**Model:** {response}")
            except Exception as e:
                st.error(f"Error generating response: {e}")
        else:
            st.warning("Please enter a valid question.")
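
# A possible alternative for prompt construction: CodeLlama-7b-Instruct was
# trained on the Llama-2 [INST] chat format, so the plain "User:/Assistant:"
# prompt above may underperform. A sketch using the tokenizer's built-in chat
# template (this assumes the tokenizer ships a chat_template, which current
# transformers releases of this model do):
#
#   messages = [
#       {"role": "system", "content": SYSTEM_PROMPT},
#       {"role": "user", "content": user_input},
#   ]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt"
#   ).to(model.device)
#   outputs = model.generate(input_ids, max_new_tokens=200, do_sample=True)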