import streamlit as st
from transformers import pipeline
import torch
# Configure the page and set the title of the Streamlit app
st.set_page_config(page_title="Hugging Face Chat", page_icon="πŸ€—")
st.title("πŸ€— Hugging Face Model Chat")
# Add a sidebar for model selection
with st.sidebar:
    st.header("Model Selection")

    # A dictionary of available models
    model_options = {
        "NVIDIA Nemotron 3 8B": "nvidia/nemotron-3-8b-chat-4k-sft",
        "Meta Llama 3.1 8B": "meta-llama/Llama-3.1-8B-Instruct",
        "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
        "Gemma 7B It": "google/gemma-7b-it",
    }

    selected_model_name = st.selectbox("Choose a model:", list(model_options.keys()))
    model_id = model_options[selected_model_name]

    st.markdown("---")
    st.markdown("This app allows you to chat with different open-source Large Language Models from the Hugging Face Hub.")
    st.markdown("Select a model from the dropdown and start chatting!")
# Caching the model loading to improve performance
@st.cache_resource
def load_model(model_id):
    """Loads the selected model and tokenizer from Hugging Face."""
    try:
        # Use "text-generation" pipeline for chat models
        pipe = pipeline(
            "text-generation",
            model=model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        return pipe
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None
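
# Optional, hedged sketch (not part of the original app): instruction-tuned models
# usually ship their own chat template, which can give better results than the
# generic "User:/Assistant:" prompt used further below. The helper name
# `build_chat_prompt` is illustrative; it relies on the standard
# tokenizer.apply_chat_template API from transformers and assumes the selected
# model actually provides a chat template.
def build_chat_prompt(pipe, messages):
    """Render chat messages with the model's own template when one is available."""
    tokenizer = pipe.tokenizer
    if getattr(tokenizer, "chat_template", None):
        # add_generation_prompt=True appends the marker that cues the assistant turn
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # Fall back to the simple format used elsewhere in this app
    return "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in messages) + "\nAssistant:"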
# Load the selected model
pipe = load_model(model_id)
# Initialize chat history in session state
if "messages" not in st.session_state:
st.session_state.messages = []
# Display prior chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Get user input
if prompt := st.chat_input("What would you like to ask?"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate a response from the model
    if pipe:
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                # Prepare the prompt for the model.
                # Note: different models expect different prompt formats; this is a
                # generic approach, and only the latest user message is sent, so the
                # model does not see earlier turns of the conversation.
                formatted_prompt = f"User: {prompt}\nAssistant:"

                # Generate the response
                response = pipe(
                    formatted_prompt,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    top_k=50
                )

                # Extract the generated text
                if response and len(response) > 0 and "generated_text" in response[0]:
                    # The output usually echoes the prompt, so keep only the text
                    # after the final "Assistant:" marker.
                    assistant_response = response[0]["generated_text"].split("Assistant:")[-1].strip()
                else:
                    assistant_response = "Sorry, I couldn't generate a response."

                st.markdown(assistant_response)

        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": assistant_response})
    else:
        st.error("Model not loaded. Cannot generate a response.")