import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        # Load the model and tokenizer lazily, on first access
        self._model = None
        self._tokenizer = None
    @property
    def model(self):
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        return self._model

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer
    def generate_response(self, prompt: str, max_length: int = 512) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Generate a response without tracking gradients
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
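
    # Optional helper (not wired into the app below): Llama-2-chat models expect an
    # [INST]-style prompt, so multi-turn results are usually better when the chat
    # history goes through the tokenizer's chat template instead of raw text. This is
    # a minimal sketch; format_chat_prompt is an added name, and it assumes a
    # transformers version whose tokenizer ships apply_chat_template.
    def format_chat_prompt(self, chat_history: list) -> str:
        # chat_history: list of {"role": "user" | "assistant", "content": str} dicts
        return self.tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=True,
        )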


def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide",
    )
    st.title("🦙 Llama 2 Chat Demo")
    # Initialize session state
    if 'llama' not in st.session_state:
        st.session_state.llama = LlamaDemo()
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    # Chat interface
    with st.container():
        # Display chat history
        for message in st.session_state.chat_history:
            role = message["role"]
            content = message["content"]
            with st.chat_message(role):
                st.write(content)
        # Input for new message
        if prompt := st.chat_input("What would you like to discuss?"):
            # Add user message to chat history
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.write(prompt)
            # Show assistant response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Generating response..."):
                    # Use the response length chosen in the sidebar; falls back to 512
                    # on the first run, before the slider widget exists
                    response = st.session_state.llama.generate_response(
                        prompt,
                        max_length=st.session_state.get("max_length", 512),
                    )
                message_placeholder.write(response)

            # Add assistant response to chat history
            st.session_state.chat_history.append({
                "role": "assistant",
                "content": response
            })
    # Sidebar with settings
    with st.sidebar:
        st.header("Settings")
        # The key exposes the chosen value as st.session_state["max_length"]
        st.slider("Maximum response length", 64, 1024, 512, key="max_length")

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()


if __name__ == "__main__":
    main()
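
# To run this demo locally (assuming the script is saved as app.py, you have access to
# the gated meta-llama checkpoint on the Hugging Face Hub, and a GPU with enough memory
# for the fp16 weights):
#
#   pip install streamlit torch transformers accelerate
#   streamlit run app.py
#
# In a Streamlit app, st.cache_resource is a common alternative to the manual lazy
# loading above, since it keeps a single model instance alive across reruns and
# sessions. A minimal sketch (get_demo is a hypothetical helper, not used above):
#
#   @st.cache_resource
#   def get_demo() -> LlamaDemo:
#       return LlamaDemo()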