# TestOneLlama / app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
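
# meta-llama/Llama-2-7b-chat-hf is a gated checkpoint, so the first download
# requires an authenticated Hugging Face account. A minimal sketch of logging in
# up front, assuming the token lives in an HF_TOKEN environment variable (the
# variable name and this whole step are assumptions, not part of the original app):
import os
from huggingface_hub import login

if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])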

class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        # Lazy loading: defer the expensive download/load until first use
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,  # half precision to fit the 7B weights in less memory
                device_map="auto",          # let accelerate place weights on GPU/CPU
            )
        return self._model

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Generate a response; max_new_tokens caps only the reply length,
        # whereas max_length would also have counted the prompt tokens
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens; slicing off the prompt tokens
        # is more reliable than stripping the prompt string out of the decoded text
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
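

# Llama-2 chat checkpoints were fine-tuned on an [INST]-wrapped prompt template,
# so passing raw text (as generate_response does) works but usually yields weaker
# replies. A minimal sketch of the single-turn template; the helper name and the
# default system prompt are assumptions, not part of the original app:
def format_llama2_prompt(user_message: str,
                         system_prompt: str = "You are a helpful assistant.") -> str:
    """Wrap one user turn in the Llama-2 chat template."""
    return (
        f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
        f"{user_message} [/INST]"
    )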


def main():
    st.set_page_config(
        page_title="Llama 2 Demo",
        page_icon="🦙",
        layout="wide",
    )
    st.title("🦙 Llama 2 Demo")

    # Initialize session state so the model wrapper and chat history
    # survive Streamlit's rerun-on-interaction model
    if 'llama' not in st.session_state:
        st.session_state.llama = LlamaDemo()
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
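
    # Note: session_state keeps one LlamaDemo per browser session. To share a
    # single copy of the weights across all sessions, st.cache_resource is the
    # usual alternative; a minimal sketch (this loader is an assumption, not
    # what this app does):
    #
    #   @st.cache_resource
    #   def load_llama() -> LlamaDemo:
    #       return LlamaDemo()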

    # Sidebar with settings; rendered before the chat so the slider value is
    # defined by the time a reply is generated (the original read it afterwards
    # and never passed it to generate_response)
    with st.sidebar:
        st.header("Settings")
        max_new_tokens = st.slider("Maximum response length (tokens)", 64, 1024, 512)
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()  # st.experimental_rerun() is removed in current Streamlit

    # Chat interface
    with st.container():
        # Replay the stored conversation on each rerun
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # Input for a new message
        if prompt := st.chat_input("What would you like to discuss?"):
            # Record and display the user turn
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.write(prompt)

            # Generate and display the assistant turn
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Generating response..."):
                    response = st.session_state.llama.generate_response(
                        prompt, max_new_tokens=max_new_tokens
                    )
                message_placeholder.write(response)

            # Persist the assistant turn
            st.session_state.chat_history.append({"role": "assistant", "content": response})
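

# For token-by-token display instead of a spinner, transformers'
# TextIteratorStreamer can feed st.write_stream. A minimal sketch, assuming
# generation is moved to a background thread and `demo` is the LlamaDemo
# instance (none of this is in the original app):
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(demo.tokenizer, skip_prompt=True,
#                                   skip_special_tokens=True)
#   inputs = demo.tokenizer(prompt, return_tensors="pt").to(demo.model.device)
#   Thread(target=demo.model.generate,
#          kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512)).start()
#   response = message_placeholder.write_stream(streamer)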


if __name__ == "__main__":
    main()
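
# Run locally with: streamlit run app.py
# (assumes streamlit, torch, transformers, and accelerate are installed)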