kouki321 committed
Commit e2c4e20 · verified · 1 Parent(s): fe55efb

Update app.py

Files changed (1)
  1. app.py +114 -22
app.py CHANGED
@@ -5,6 +5,7 @@ from transformers.cache_utils import DynamicCache
 import os
 from time import time
 import pandas as pd
+import psutil
 
 
 # ==============================
@@ -80,7 +81,7 @@ def load_model_and_tokenizer(doc_text_count):
     tokenizer = AutoTokenizer.from_pretrained(
         model_name,
         trust_remote_code=True,
-        model_max_length= 1.5*round(doc_text_count * 0.3 + 1)
+        model_max_length=2*round(doc_text_count * 0.3 + 1)
     )
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
@@ -103,7 +104,7 @@ def load_document_and_cache(file_path):
         t2=time()
         with open(file_path, "r", encoding="utf-8") as f:
             doc_text = f.read()
-        doc_text_count= len(doc_text)
+        doc_text_count = len(doc_text)
         model, tokenizer = load_model_and_tokenizer(doc_text_count)
         system_prompt = f"""
         <|system|>
@@ -116,11 +117,108 @@ def load_document_and_cache(file_path):
         cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
         t3=time()
         print(f"{t3-t2}")
-        return cache, origin_len, doc_text ,doc_text_count,model, tokenizer
+        return cache, origin_len, doc_text, doc_text_count
     except FileNotFoundError:
         st.error(f"Document file not found at {file_path}")
         return None, None, None, None
 
+# ==============================
+# System & Cache Resource Stats
+# ==============================
+def get_system_stats(doc_text=None, cache_mem_bytes=0):
+    ram = psutil.virtual_memory()
+    cpu = psutil.cpu_percent()
+    disk = psutil.disk_usage('/')
+    used, total = ram.used, ram.total
+    stats = {
+        "Input Tokens": st.session_state.get('input_tokens_count', 0),
+        "Output Tokens": st.session_state.get('output_tokens_count', 0),
+        "Generated Tokens": st.session_state.get('generated_tokens_count', 0),
+        "Document Size (chars)": len(doc_text) if doc_text else 0,
+        "Document Size (KB)": f"{len(doc_text.encode('utf-8')) / 1024:.2f}" if doc_text else 0,
+    }
+    if torch.cuda.is_available():
+        gpu_mem_alloc = torch.cuda.memory_allocated()
+        gpu_mem_total = torch.cuda.get_device_properties(0).total_memory
+        stats["GPU Used"] = sizeof_fmt(gpu_mem_alloc)
+        stats["GPU Total"] = sizeof_fmt(gpu_mem_total)
+        stats["GPU Usage (%)"] = round(100 * gpu_mem_alloc / gpu_mem_total, 2) if gpu_mem_total else 0
+    else:
+        stats["GPU Used"] = "N/A"
+        stats["GPU Total"] = "N/A"
+        stats["GPU Usage (%)"] = "N/A"
+
+    stats["KV Cache Memory Used"] = sizeof_fmt(cache_mem_bytes)
+    stats["KV Cache as % RAM"] = f"{(cache_mem_bytes / total) * 100:.2f}%" if total > 0 else "N/A"
+    stats["KV Cache as % GPU"] = (
+        f"{(cache_mem_bytes / torch.cuda.get_device_properties(0).total_memory) * 100:.2f}%"
+        if torch.cuda.is_available() else "N/A"
+    )
+    return stats
+
+def cache_stats_table(cache):
+    if cache is None:
+        return pd.DataFrame(), 0
+    rows = []
+    total_mem = 0
+    for i, (key, value) in enumerate(zip(cache.key_cache, cache.value_cache)):
+        key_mem = key.element_size() * key.nelement()
+        value_mem = value.element_size() * value.nelement()
+        total_mem += key_mem + value_mem
+        row = {
+            "Layer": i,
+            "Key Shape": str(tuple(key.shape)),
+            "Value Shape": str(tuple(value.shape)),
+            "Total Mem": sizeof_fmt(key_mem + value_mem),
+            "Last Key Tokens": str(tuple(key[..., -1:, :].shape)),
+            "Last Value Tokens": str(tuple(value[..., -1:, :].shape)),
+        }
+        rows.append(row)
+    return pd.DataFrame(rows), total_mem
+
+def resource_dashboard(cache, doc_text, generation_time=None, cache_clone_time=None):
+    cache_df, cache_mem_bytes = cache_stats_table(cache)
+    stats = get_system_stats(doc_text, cache_mem_bytes)
+    st.sidebar.header("🚦 Live Resource & Cache Dashboard")
+    st.sidebar.caption("See how your document and answers use your computer's memory and processing power. The KV Cache lets you answer questions super-fast!")
+    stats_table = pd.DataFrame(stats, index=["Value"]).T
+    st.sidebar.dataframe(stats_table, use_container_width=True, height=420)
+    if torch.cuda.is_available() and stats["GPU Usage (%)"] != "N/A":
+        gpu_pct = float(stats["GPU Usage (%)"])
+        st.sidebar.progress(int(min(gpu_pct, 100)), text=f"GPU Usage: {gpu_pct:.1f}%")
+    cache_pct_str = stats["KV Cache as % RAM"]
+    if isinstance(cache_pct_str, str) and cache_pct_str.endswith('%'):
+        try:
+            cache_pct = float(cache_pct_str[:-1])
+        except ValueError:
+            cache_pct = 0
+    else:
+        cache_pct = 0
+    st.sidebar.progress(int(min(cache_pct, 100)), text=f"KV Cache as RAM: {cache_pct:.1f}%")
+    if generation_time is not None or cache_clone_time is not None:
+        time_rows = []
+        if generation_time is not None:
+            time_rows.append({"Step": "Answer Generation", "Time (s)": f"{generation_time:.2f}"})
+        if cache_clone_time is not None:
+            time_rows.append({"Step": "Cache Copy", "Time (s)": f"{cache_clone_time:.2f}"})
+        st.sidebar.table(pd.DataFrame(time_rows))
+    with st.sidebar.expander("🧠 KV Cache Details (per Layer)", expanded=True):
+        st.markdown(
+            "The table below shows the shape, dtype, size, and memory used for each layer's cache in the neural network. Efficient caching speeds up new questions."
+        )
+        if not cache_df.empty:
+            st.dataframe(cache_df, use_container_width=True, height=340)
+        else:
+            st.info("No cache yet. Upload a document to see caching details.")
+
+# Initialize session state variables
+if 'generated_tokens_count' not in st.session_state:
+    st.session_state.generated_tokens_count = 0
+if 'input_tokens_count' not in st.session_state:
+    st.session_state.input_tokens_count = 0
+if 'output_tokens_count' not in st.session_state:
+    st.session_state.output_tokens_count = 0
+
 # ==============================
 # Main Streamlit UI and Workflow
 # ==============================
@@ -129,22 +227,24 @@ st.title("🚀 DeepSeek QA: Supercharged Caching & Memory Dashboard")
 
 uploaded_file = st.file_uploader("📝 Upload your document (.txt)", type="txt")
 doc_text = None
-doc_text_count= None
+doc_text_count = None
 cache = None
 origin_len = None
 last_generation_time = None
-
+last_cache_clone_time = None
 t1 = time()
 if uploaded_file:
     temp_file_path = "temp_document.txt"
     with open(temp_file_path, "wb") as f:
         f.write(uploaded_file.getvalue())
-    cache, origin_len, doc_text ,doc_text_count,model, tokenizer = load_document_and_cache(temp_file_path)
+    cache, origin_len, doc_text, doc_text_count = load_document_and_cache(temp_file_path)
     with st.expander("📄 Document Preview"):
        st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
    query = st.text_input("🔎 Ask a question about the document:")
    if query and st.button("Generate Answer"):
-        with st.spinner("Generating answer..."):
+        with st.spinner("Generating answer... (watch the sidebar for memory usage)"):
+            model, tokenizer = load_model_and_tokenizer(doc_text_count)
+            st.sidebar.write(f"Document character count: {len(doc_text)}")
             current_cache = clone_cache(cache)
             t_clone_end = time()
             Cache_create_time = t_clone_end - t1
@@ -154,33 +254,25 @@ if uploaded_file:
             <|assistant|>
             """.strip()
             input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
+            st.session_state.input_tokens_count += input_ids.shape[-1]
             t_gen_start = time()
             output_ids = generate(model, input_ids, current_cache)
+            generated_tokens_count = output_ids.shape[-1]
+            st.session_state.generated_tokens_count += generated_tokens_count
+            st.session_state.output_tokens_count = generated_tokens_count
             response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
             t_gen_end = time()
             last_generation_time = t_gen_end - t_gen_start
-
             st.success("Answer:")
             st.write(response)
-            st.info(f"Cache create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s ")
-
-
+            st.info(f"Cache create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s")
    if st.button("💾 Save Cache"):
        torch.save(clean_up(clone_cache(cache), origin_len), "saved_cache.pth")
        st.success("Cache saved successfully!")
-    if query and st.button("calcul cache mb "):
-        t12=time()
-        cache_mem_bytes = calculate_cache_size(cache)
-        t123=time()
-        time_to=t123-t12
-
-        doc_text = len(doc_text)
-        st.info(f"time_to_calculate_cache_size: {time_to:} s | cache mem bytes {cache_mem_bytes} MB ")
-        st.info(f"doc_text_count: {doc_text_count:} char ")
-
+    resource_dashboard(cache, doc_text, last_generation_time, last_cache_clone_time)
 else:
    st.info("Please upload a document to start.")
-
+    resource_dashboard(None, None)
 
 # Sidebar: Load a previously saved cache
 st.sidebar.header("🛠️ Advanced Options")
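
Note: the added dashboard code calls a sizeof_fmt helper that is defined elsewhere in app.py and is not part of this diff. As a rough orientation only, a byte-formatting helper of that kind typically looks like the hypothetical sketch below; the actual implementation in app.py may differ.

# Hypothetical sketch only -- the real sizeof_fmt in app.py may differ.
def sizeof_fmt(num, suffix="B"):
    # Step through binary unit prefixes until the value drops below 1024.
    for unit in ("", "Ki", "Mi", "Gi", "Ti"):
        if abs(num) < 1024.0:
            return f"{num:.2f} {unit}{suffix}"
        num /= 1024.0
    return f"{num:.2f} Pi{suffix}"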