"""Streamlit chat demo for a Llama 2 chat model.

Lazily loads the model and tokenizer on first use so the page renders
quickly, and keeps the conversation in Streamlit session state.
"""

import time  # kept: present in the original file's imports
from typing import Dict, List  # kept: present in the original file's imports

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class LlamaDemo:
    """Wraps a causal LM with lazy loading and a simple generate API."""

    def __init__(self) -> None:
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        # Lazy loading: defer the expensive model/tokenizer download until
        # the first property access so app startup stays fast.
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        """Load the model on first access (fp16, automatic device placement)."""
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        return self._model

    @property
    def tokenizer(self):
        """Load the tokenizer on first access."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    def generate_response(self, prompt: str, max_length: int = 512) -> str:
        """Generate a completion for *prompt*.

        Args:
            prompt: User text fed directly to the model.
            max_length: Upper bound on *generated* tokens (passed as
                ``max_new_tokens``).

        Returns:
            The generated continuation, without the echoed prompt.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                # BUGFIX: was max_length=, which counts the prompt tokens
                # toward the limit; max_new_tokens matches the intended
                # "maximum response length" semantics.
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # BUGFIX: previously decoded the full sequence and did
        # response.replace(prompt, ""), which corrupts the output whenever
        # the model repeats any substring of the prompt. Decode only the
        # newly generated token slice instead.
        prompt_token_count = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()


def main() -> None:
    """Render the Streamlit chat UI and drive generation."""
    st.set_page_config(
        # BUGFIX: page said "Llama 3.1" but the loaded model is
        # meta-llama/Llama-2-7b-chat-hf — label now matches the model.
        page_title="Llama 2 Demo",
        page_icon="🦙",
        layout="wide",
    )
    st.title("🦙 Llama 2 Demo")

    # Initialize session state on first run.
    if "llama" not in st.session_state:
        st.session_state.llama = LlamaDemo()
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Sidebar is rendered BEFORE the chat handler so the slider value is
    # in scope when generating. BUGFIX: the slider was previously rendered
    # after generation and its value was never passed to
    # generate_response(), so the setting was silently ignored.
    with st.sidebar:
        st.header("Settings")
        max_length = st.slider("Maximum response length", 64, 1024, 512)
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            # NOTE(review): st.experimental_rerun() is deprecated/removed in
            # newer Streamlit (use st.rerun() on >= 1.27) — kept as-is to
            # avoid breaking whatever version this app is pinned to.
            st.experimental_rerun()

    # Chat interface.
    with st.container():
        # Replay the conversation so far.
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # Handle a newly submitted message.
        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append(
                {"role": "user", "content": prompt}
            )
            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Generating response..."):
                    # BUGFIX: forward the sidebar setting (was omitted).
                    response = st.session_state.llama.generate_response(
                        prompt, max_length=max_length
                    )
                message_placeholder.write(response)

            st.session_state.chat_history.append(
                {"role": "assistant", "content": response}
            )


if __name__ == "__main__":
    main()