AIdeaText committed
Commit 55ca2dd · verified · Parent: b60aacc

Update app.py

Files changed (1): app.py (+39, -70)
app.py CHANGED
@@ -1,71 +1,54 @@
 import streamlit as st
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
-from typing import List, Dict
-import time
 
 class LlamaDemo:
     def __init__(self):
-        # Using TinyLlama, which is open source and doesn't require authentication
-        self.model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
         # Initialize in lazy loading fashion
-        self._model = None
-        self._tokenizer = None
+        self._pipe = None
 
     @property
-    def model(self):
-        if self._model is None:
-            self._model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
+    def pipe(self):
+        if self._pipe is None:
+            self._pipe = pipeline(
+                "text-generation",
+                model=self.model_name,
                 torch_dtype=torch.float16,
                 device_map="auto",
                 trust_remote_code=True
             )
-        return self._model
-
-    @property
-    def tokenizer(self):
-        if self._tokenizer is None:
-            self._tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name,
-                trust_remote_code=True
-            )
-        return self._tokenizer
+        return self._pipe
 
     def generate_response(self, prompt: str, max_length: int = 512) -> str:
-        # Format the prompt according to TinyLlama's chat template
-        chat_prompt = f"<|system|>You are a helpful AI assistant.</s><|user|>{prompt}</s><|assistant|>"
+        # Format prompt for Llama 2 chat
+        formatted_prompt = f"[INST] {prompt} [/INST]"
 
-        inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.model.device)
+        # Generate response using pipeline
+        response = self.pipe(
+            formatted_prompt,
+            max_new_tokens=max_length,
+            num_return_sequences=1,
+            temperature=0.7,
+            do_sample=True,
+            top_p=0.9
+        )[0]['generated_text']
 
-        # Generate response
-        with torch.no_grad():
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=max_length,
-                num_return_sequences=1,
-                temperature=0.7,
-                do_sample=True,
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-
-        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        # Remove the prompt from the response
-        response = response.split("<|assistant|>")[-1].strip()
-        return response
+        # Extract response after the instruction tag
+        return response.split("[/INST]")[-1].strip()
 
 def main():
     st.set_page_config(
-        page_title="Open Source Llama Demo",
+        page_title="Llama 2 Chat Demo",
         page_icon="🦙",
         layout="wide"
     )
 
-    st.title("🦙 Open Source Llama Demo")
+    st.title("🦙 Llama 2 Chat Demo")
 
-    # Initialize session state
+    # Initialize model
     if 'llama' not in st.session_state:
-        with st.spinner("Loading model... This might take a few minutes..."):
+        with st.spinner("Loading Llama 2... This might take a few minutes..."):
             st.session_state.llama = LlamaDemo()
 
     if 'chat_history' not in st.session_state:
@@ -73,17 +56,11 @@ def main():
 
     # Chat interface
     with st.container():
-        # Display chat history
         for message in st.session_state.chat_history:
-            role = message["role"]
-            content = message["content"]
-
-            with st.chat_message(role):
-                st.write(content)
+            with st.chat_message(message["role"]):
+                st.write(message["content"])
 
-        # Input for new message
         if prompt := st.chat_input("What would you like to discuss?"):
-            # Add user message to chat history
             st.session_state.chat_history.append({
                 "role": "user",
                 "content": prompt
@@ -92,32 +69,24 @@
             with st.chat_message("user"):
                 st.write(prompt)
 
-            # Show assistant response
            with st.chat_message("assistant"):
-                message_placeholder = st.empty()
-
                 with st.spinner("Thinking..."):
-                    response = st.session_state.llama.generate_response(prompt)
-                    message_placeholder.write(response)
-
-                # Add assistant response to chat history
-                st.session_state.chat_history.append({
-                    "role": "assistant",
-                    "content": response
-                })
+                    try:
+                        response = st.session_state.llama.generate_response(prompt)
+                        st.write(response)
+                        st.session_state.chat_history.append({
+                            "role": "assistant",
+                            "content": response
+                        })
+                    except Exception as e:
+                        st.error(f"Error: {str(e)}")
 
-    # Sidebar with settings and info
     with st.sidebar:
-        st.header("Settings")
-        max_length = st.slider("Maximum response length", 64, 1024, 512)
-
-        st.markdown("---")
         st.markdown("""
         ### About
-        This demo uses TinyLlama, an open source language model that's smaller but
-        still capable. It's perfect for demonstrations and testing.
+        This demo uses Llama-2-70B-chat, a large language model from Meta.
 
-        The model is loaded locally and doesn't require any authentication or API keys.
+        The model runs with automatic device mapping and mixed precision for optimal performance.
         """)
 
     if st.button("Clear Chat History"):
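A note on the prompt handling introduced above: the `[INST] {prompt} [/INST]` string built in generate_response is the single-turn form of the Llama 2 chat template; the full template also wraps a system instruction in a <<SYS>> block. A minimal sketch of that fuller format (illustrative only, the helper name and default system message are not part of this commit):

# Sketch of the fuller Llama 2 chat prompt that the single-turn [INST] string above abbreviates.
# The helper name and default system message are illustrative, not taken from app.py.
def format_llama2_prompt(user_msg: str,
                         system_msg: str = "You are a helpful AI assistant.") -> str:
    return f"[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{user_msg} [/INST]"

print(format_llama2_prompt("What is a llama?"))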
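Also worth noting: switching the checkpoint to meta-llama/Llama-2-70b-chat-hf means the weights are gated (Meta license acceptance plus an authenticated Hugging Face environment, e.g. `huggingface-cli login`) and need substantial GPU memory even in float16. Below is a minimal sketch of the same generation path driven outside the Streamlit UI, under those assumptions; the prompt text is only an example, and the app itself is started with `streamlit run app.py`.

# Sketch: the pipeline-based generation path from this commit, driven outside Streamlit.
# Assumes access to the gated meta-llama/Llama-2-70b-chat-hf weights and enough
# GPU memory for float16 70B inference; the prompt below is an arbitrary example.
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-2-70b-chat-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)

prompt = "[INST] Give one sentence about llamas. [/INST]"
result = pipe(prompt, max_new_tokens=64, do_sample=True, temperature=0.7, top_p=0.9)
print(result[0]["generated_text"].split("[/INST]")[-1].strip())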