junaidbaber committed
Commit 97a2367 · verified · 1 Parent(s): 03b1321

Update app.py

Files changed (1)
  1. app.py +55 -118
app.py CHANGED
@@ -1,144 +1,82 @@
 import streamlit as st
-from huggingface_hub import login
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import torch
-from transformers import BitsAndBytesConfig
 import os
 
 def initialize_model():
-    """Initialize the model and tokenizer with CPU support"""
-    # Log in to Hugging Face
-    token = os.environ.get("hf")
-    if token:
-        login(token)
-
-    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    """Initialize a small and fast model for CPU"""
+    # Using a tiny model optimized for CPU
+    model_id = "facebook/opt-125m" # Much smaller model (125M parameters)
 
     try:
-        # Try with regular CPU mode first (simpler and more reliable)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
+        # Initialize the pipeline directly - more efficient than loading model separately
+        pipe = pipeline(
+            "text-generation",
+            model=model_id,
             device_map="cpu",
-            trust_remote_code=True,
-            low_cpu_mem_usage=True
+            model_kwargs={"low_cpu_mem_usage": True}
         )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        return pipe, tokenizer
     except Exception as e:
         print(f"Error loading model: {str(e)}")
         raise e
 
-    # Ensure padding token is defined
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    return model, tokenizer
-
-def format_prompt(user_input, conversation_history=[]):
-    """Format the prompt according to TinyLlama's expected chat format"""
-    messages = []
-
-    # Add conversation history
-    for turn in conversation_history:
-        messages.append({"role": "user", "content": turn["user"]})
-        messages.append({"role": "assistant", "content": turn["assistant"]})
-
-    # Add current user input
-    messages.append({"role": "user", "content": user_input})
-
-    # Format into TinyLlama chat format
-    formatted_prompt = "<|system|>You are a helpful AI assistant.</s>"
-
-    for message in messages:
-        if message["role"] == "user":
-            formatted_prompt += f"<|user|>{message['content']}</s>"
-        else:
-            formatted_prompt += f"<|assistant|>{message['content']}</s>"
-
-    formatted_prompt += "<|assistant|>"
-    return formatted_prompt
-
-def generate_response(model, tokenizer, prompt, conversation_history):
+def generate_response(pipe, tokenizer, prompt, conversation_history):
     """Generate model response"""
     try:
-        # Format prompt using TinyLlama's chat template
-        formatted_prompt = format_prompt(prompt, conversation_history[:-1])
+        # Format conversation context
+        context = ""
+        for turn in conversation_history[-3:]: # Only use last 3 turns for efficiency
+            context += f"Human: {turn['user']}\nAssistant: {turn['assistant']}\n"
 
-        # Tokenize input
-        inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True)
+        # Create the full prompt
+        full_prompt = f"{context}Human: {prompt}\nAssistant:"
 
-        # Move inputs to the same device as the model
-        device = next(model.parameters()).device
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-
-        # Calculate max new tokens
-        input_length = inputs["input_ids"].shape[1]
-        max_model_length = 1024
-        max_new_tokens = min(150, max_model_length - input_length)
-
-        # Generate response
-        outputs = model.generate(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            max_new_tokens=max_new_tokens,
+        # Generate response with conservative parameters
+        response = pipe(
+            full_prompt,
+            max_new_tokens=50, # Limit response length
             temperature=0.7,
             top_p=0.9,
-            pad_token_id=tokenizer.pad_token_id,
+            num_return_sequences=1,
             do_sample=True,
-            min_length=10,
-            no_repeat_ngram_size=3,
-            eos_token_id=tokenizer.encode("</s>")[0] # Set end token
-        )
-
-        # Decode response and extract only the assistant's message
-        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
+        )[0]['generated_text']
 
-        # Extract only the last assistant response
-        assistant_response = full_response.split("<|assistant|>")[-1].split("</s>")[0].strip()
-
-        return assistant_response if assistant_response else "I apologize, but I couldn't generate a proper response."
+        # Extract only the assistant's response
+        try:
+            assistant_response = response.split("Assistant:")[-1].strip()
+            if not assistant_response:
+                return "I apologize, but I couldn't generate a proper response."
+            return assistant_response
+        except:
+            return response.split(prompt)[-1].strip()
 
-    except RuntimeError as e:
-        if "out of memory" in str(e):
-            torch.cuda.empty_cache()
-            return "I apologize, but I ran out of memory. Please try a shorter message or clear the chat history."
-        else:
-            return f"An error occurred: {str(e)}"
+    except Exception as e:
+        return f"An error occurred: {str(e)}"
 
 def main():
-    st.set_page_config(
-        page_title="LLM Chat Interface",
-        page_icon="🤖",
-        layout="wide"
-    )
+    st.set_page_config(page_title="LLM Chat Interface", page_icon="🤖")
 
-    # Add CSS to make the chat interface more compact
-    st.markdown("""
-        <style>
-        .stChat {
-            padding-top: 0rem;
-        }
-        .stChatMessage {
-            padding: 0.5rem;
-        }
-        </style>
-    """, unsafe_allow_html=True)
-
-    st.title("Chat with TinyLlama 🤖")
+    st.title("💬 Quick Chat Assistant")
 
-    # Initialize session state for chat history
+    # Initialize session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
+
+    if "model_loaded" not in st.session_state:
+        st.session_state.model_loaded = False
 
     # Initialize model (only once)
-    if "model" not in st.session_state:
-        with st.spinner("Loading the model... This might take a minute..."):
+    if not st.session_state.model_loaded:
+        with st.spinner("Loading the model... (this should take just a few seconds)"):
            try:
-                model, tokenizer = initialize_model()
-                st.session_state.model = model
+                pipe, tokenizer = initialize_model()
+                st.session_state.pipe = pipe
                st.session_state.tokenizer = tokenizer
-                st.success("Model loaded successfully!")
+                st.session_state.model_loaded = True
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return
@@ -151,7 +89,7 @@ def main():
             st.write(message["assistant"])
 
     # Chat input
-    if prompt := st.chat_input("What would you like to know?"):
+    if prompt := st.chat_input("Ask me anything!"):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)
@@ -163,7 +101,7 @@ def main():
        st.session_state.chat_history.append(current_turn)
 
        response = generate_response(
-            st.session_state.model,
+            st.session_state.pipe,
            st.session_state.tokenizer,
            prompt,
            st.session_state.chat_history
@@ -172,23 +110,22 @@ def main():
            st.write(response)
            st.session_state.chat_history[-1]["assistant"] = response
 
-        # Manage context window
+        # Keep only last 5 turns
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]
 
-    # Sidebar controls
+    # Sidebar
    with st.sidebar:
-        st.title("Controls")
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()
 
        st.markdown("---")
        st.markdown("""
-        ### Model Info
-        - Using TinyLlama 1.1B Chat
-        - CPU optimized
-        - Context window: 1024 tokens
+        ### Chat Info
+        - Using OPT-125M model
+        - Optimized for quick responses
+        - Best for short conversations
        """)
 
 if __name__ == "__main__":
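
For reference, a minimal standalone sketch of the pipeline-based generation path this commit switches to, runnable outside Streamlit. It is not part of the commit; it assumes transformers, torch, and accelerate are installed (accelerate is needed for device_map="cpu") and that facebook/opt-125m can be downloaded. The example prompt string is made up for illustration.

from transformers import pipeline, AutoTokenizer

model_id = "facebook/opt-125m"

# Build the same text-generation pipeline the updated initialize_model() creates.
pipe = pipeline(
    "text-generation",
    model=model_id,
    device_map="cpu",
    model_kwargs={"low_cpu_mem_usage": True},
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Mirror generate_response(): plain "Human:/Assistant:" framing, conservative sampling.
prompt = "Human: What can you help me with?\nAssistant:"  # hypothetical example input
result = pipe(
    prompt,
    max_new_tokens=50,
    temperature=0.7,
    top_p=0.9,
    num_return_sequences=1,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
)[0]["generated_text"]

# Keep only the text after the last "Assistant:" marker, as the app does.
print(result.split("Assistant:")[-1].strip())

Since OPT-125M is a plain causal language model with no chat fine-tuning, the output will be rough; the Human:/Assistant: framing is just the lightweight prompt convention used in the updated app.py.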