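# A minimal Streamlit chat front-end for open-source models on the Hugging Face Hub.
# Assumed dependencies (not stated in the original listing): streamlit, transformers,
# torch, and accelerate (required by device_map="auto"). Launch the app with
# `streamlit run <this file>`.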
import streamlit as st
from transformers import pipeline
import torch

# Set the title of the Streamlit app
st.set_page_config(page_title="Hugging Face Chat", page_icon="🤗")
st.title("🤗 Hugging Face Model Chat")

# Add a sidebar for model selection
with st.sidebar:
    st.header("Model Selection")

    # A dictionary of available models
    model_options = {
        "NVIDIA Nemotron 3 8B": "nvidia/nemotron-3-8b-chat-4k-sft",
        "Meta Llama 3.1 8B": "meta-llama/Llama-3.1-8B-Instruct",
        "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
        "Gemma 7B It": "google/gemma-7b-it",
    }
    selected_model_name = st.selectbox("Choose a model:", list(model_options.keys()))
    model_id = model_options[selected_model_name]

    st.markdown("---")
    st.markdown("This app allows you to chat with different open-source Large Language Models from the Hugging Face Hub.")
    st.markdown("Select a model from the dropdown and start chatting!")

# Cache the model loading to improve performance
@st.cache_resource
def load_model(model_id):
    """Loads the selected model and tokenizer from Hugging Face."""
    try:
        # Use the "text-generation" pipeline for chat models
        pipe = pipeline(
            "text-generation",
            model=model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        return pipe
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None

# Load the selected model
pipe = load_model(model_id)

# Initialize chat history in session state
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display prior chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Get user input
if prompt := st.chat_input("What would you like to ask?"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate a response from the model
    if pipe:
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                # Prepare the prompt for the model.
                # Note: Different models may have different prompt formats.
                # This is a generic approach.
                formatted_prompt = f"User: {prompt}\nAssistant:"

                # Generate the response
                response = pipe(
                    formatted_prompt,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    top_k=50,
                )

                # Extract the generated text
                if response and len(response) > 0 and "generated_text" in response[0]:
                    # The output often includes the prompt, so we clean it up.
                    assistant_response = response[0]["generated_text"].split("Assistant:")[-1].strip()
                else:
                    assistant_response = "Sorry, I couldn't generate a response."

                st.markdown(assistant_response)

        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": assistant_response})
    else:
        st.error("Model not loaded. Cannot generate a response.")