# Chat with swiss-ai/Apertus-8B-Instruct-2509 in Streamlit.
# Requires: pip install streamlit torch transformers accelerate bitsandbytes

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Set page configuration
st.set_page_config(
    page_title="Apertus-8B Chat",
    page_icon="🤖",
    layout="wide"
)

# Add a title to the app
st.title("🤖 Chat with Apertus-8B-Instruct")
st.caption("A Streamlit app running swiss-ai/Apertus-8B-Instruct-2509")

# --- MODEL LOADING ---

@st.cache_resource
def load_model():
    """Loads the model and tokenizer with 4-bit quantization."""
    model_id = "swiss-ai/Apertus-8B-Instruct-2509"

    # Configure 4-bit NF4 quantization to reduce memory usage
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",  # Automatically maps model layers to available hardware (GPU/CPU)
    )
    return tokenizer, model

# Load the model and display a spinner while doing so
with st.spinner("Loading Apertus-8B model... This might take a moment."):
    tokenizer, model = load_model()

# --- CHAT INTERFACE ---

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("What would you like to ask?"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # --- GENERATION ---
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Format the full chat history with the model's chat template so the
            # instruct model sees proper role markers instead of a bare string
            input_ids = tokenizer.apply_chat_template(
                st.session_state.messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(model.device)

            # Generate a response
            outputs = model.generate(
                input_ids,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95
            )

            # Decode only the newly generated tokens (everything after the
            # prompt), which avoids fragile string-replace cleanup of the
            # echoed prompt
            response = tokenizer.decode(
                outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
            ).strip()
            st.markdown(response)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})
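
# --- USAGE (a sketch; the filename app.py is an assumption, not from the source) ---
# Streamlit apps are launched with Streamlit's CLI rather than `python app.py`:
#
#   streamlit run app.py
#
# A CUDA-capable GPU is assumed here, since bitsandbytes 4-bit quantization
# primarily targets CUDA; an 8B model quantized to 4 bits needs roughly
# 5-6 GB of free VRAM for the weights plus generation overhead.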