kouki321 committed on
Commit 25dcf7e · verified · 1 Parent(s): c42bfcb

Update app.py

Files changed (1): app.py (+32, -103)
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers.cache_utils import DynamicCache
 import os
 from time import time
 import pandas as pd
-import psutil
 
 
 # ==============================
@@ -18,7 +17,6 @@ def sizeof_fmt(num, suffix="B"):
         num /= 1024.0
     return f"{num:.2f} P{suffix}"
 
-
 # ==============================
 # Core Model and Caching Logic
 # ==============================
@@ -81,7 +79,7 @@ def load_model_and_tokenizer(doc_text_count):
     tokenizer = AutoTokenizer.from_pretrained(
         model_name,
         trust_remote_code=True,
-        model_max_length=2*round(doc_text_count * 0.3 + 1)
+        model_max_length=1.3*round(doc_text_count * 0.3 + 1)
     )
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
@@ -122,95 +120,6 @@ def load_document_and_cache(file_path):
         st.error(f"Document file not found at {file_path}")
         return None, None, None, None, None, None
 
-# ==============================
-# System & Cache Resource Stats
-# ==============================
-def get_system_stats(doc_text=None, cache_mem_bytes=0):
-    ram = psutil.virtual_memory()
-    cpu = psutil.cpu_percent()
-    disk = psutil.disk_usage('/')
-    used, total = ram.used, ram.total
-    stats = {
-        "Input Tokens": st.session_state.get('input_tokens_count', 0),
-        "Output Tokens": st.session_state.get('output_tokens_count', 0),
-        "Generated Tokens": st.session_state.get('generated_tokens_count', 0),
-        "Document Size (chars)": len(doc_text) if doc_text else 0,
-        "Document Size (KB)": f"{len(doc_text.encode('utf-8')) / 1024:.2f}" if doc_text else 0,
-    }
-    if torch.cuda.is_available():
-        gpu_mem_alloc = torch.cuda.memory_allocated()
-        gpu_mem_total = torch.cuda.get_device_properties(0).total_memory
-        stats["GPU Used"] = sizeof_fmt(gpu_mem_alloc)
-        stats["GPU Total"] = sizeof_fmt(gpu_mem_total)
-        stats["GPU Usage (%)"] = round(100 * gpu_mem_alloc / gpu_mem_total, 2) if gpu_mem_total else 0
-    else:
-        stats["GPU Used"] = "N/A"
-        stats["GPU Total"] = "N/A"
-        stats["GPU Usage (%)"] = "N/A"
-
-    stats["KV Cache Memory Used"] = sizeof_fmt(cache_mem_bytes)
-    stats["KV Cache as % RAM"] = f"{(cache_mem_bytes / total) * 100:.2f}%" if total > 0 else "N/A"
-    stats["KV Cache as % GPU"] = (
-        f"{(cache_mem_bytes / torch.cuda.get_device_properties(0).total_memory) * 100:.2f}%"
-        if torch.cuda.is_available() else "N/A"
-    )
-    return stats
-
-def cache_stats_table(cache):
-    if cache is None:
-        return pd.DataFrame(), 0
-    rows = []
-    total_mem = 0
-    for i, (key, value) in enumerate(zip(cache.key_cache, cache.value_cache)):
-        key_mem = key.element_size() * key.nelement()
-        value_mem = value.element_size() * value.nelement()
-        total_mem += key_mem + value_mem
-        row = {
-            "Layer": i,
-            "Key Shape": str(tuple(key.shape)),
-            "Value Shape": str(tuple(value.shape)),
-            "Total Mem": sizeof_fmt(key_mem + value_mem),
-            "Last Key Tokens": str(tuple(key[..., -1:, :].shape)),
-            "Last Value Tokens": str(tuple(value[..., -1:, :].shape)),
-        }
-        rows.append(row)
-    return pd.DataFrame(rows), total_mem
-
-def resource_dashboard(cache, doc_text, generation_time=None, cache_clone_time=None):
-    cache_df, cache_mem_bytes = cache_stats_table(cache)
-    stats = get_system_stats(doc_text, cache_mem_bytes)
-    st.sidebar.header("🚦 Live Resource & Cache Dashboard")
-    st.sidebar.caption("See how your document and answers use your computer's memory and processing power. The KV Cache lets you answer questions super-fast!")
-    # Use st.table for small stats tables for better rendering
-    stats_table = pd.DataFrame(list(stats.items()), columns=["Metric", "Value"])
-    st.sidebar.table(stats_table)
-    if torch.cuda.is_available() and stats["GPU Usage (%)"] != "N/A":
-        gpu_pct = float(stats["GPU Usage (%)"])
-        st.sidebar.progress(int(min(gpu_pct, 100)), text=f"GPU Usage: {gpu_pct:.1f}%")
-    cache_pct_str = stats["KV Cache as % RAM"]
-    if isinstance(cache_pct_str, str) and cache_pct_str.endswith('%'):
-        try:
-            cache_pct = float(cache_pct_str[:-1])
-        except ValueError:
-            cache_pct = 0
-    else:
-        cache_pct = 0
-    st.sidebar.progress(int(min(cache_pct, 100)), text=f"KV Cache as RAM: {cache_pct:.1f}%")
-    if generation_time is not None or cache_clone_time is not None:
-        time_rows = []
-        if generation_time is not None:
-            time_rows.append({"Step": "Answer Generation", "Time (s)": f"{generation_time:.2f}"})
-        if cache_clone_time is not None:
-            time_rows.append({"Step": "Cache Copy", "Time (s)": f"{cache_clone_time:.2f}"})
-        st.sidebar.table(pd.DataFrame(time_rows))
-    with st.sidebar.expander("🧠 KV Cache Details (per Layer)", expanded=True):
-        st.markdown(
-            "The table below shows the shape, dtype, size, and memory used for each layer's cache in the neural network. Efficient caching speeds up new questions."
-        )
-        if not cache_df.empty:
-            st.dataframe(cache_df, use_container_width=True, height=340)
-        else:
-            st.info("No cache yet. Upload a document to see caching details.")
 
 # Initialize session state variables
 if 'generated_tokens_count' not in st.session_state:
@@ -241,15 +150,33 @@ if uploaded_file:
     with open(temp_file_path, "wb") as f:
         f.write(uploaded_file.getvalue())
     cache, origin_len, doc_text, doc_text_count, model, tokenizer = load_document_and_cache(temp_file_path)
+
+    # Document Info Display
     with st.expander("📄 Document Preview"):
         st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
+
+    # Collect System Stats AFTER doc upload
+    cache_df, cache_mem_bytes = cache_stats_table(cache)
+    stats = get_system_stats(doc_text, cache_mem_bytes)
+
+    # Track Time
+    t1 = time()
+
+    # Generate Info Line (Initial)
+    st.info(
+        f"Document Chars: {len(doc_text)} | Size: {stats['Document Size (KB)']} KB | "
+        f"GPU Used: {stats['GPU Used']} | GPU Usage: {stats['GPU Usage (%)']}% | "
+        f"KV Cache Memory: {stats['KV Cache Memory Used']} | "
+        f"Cache as % RAM: {stats['KV Cache as % RAM']} | Cache as % GPU: {stats['KV Cache as % GPU']}"
+    )
+
     query = st.text_input("🔎 Ask a question about the document:")
     if query and st.button("Generate Answer"):
         with st.spinner("Generating answer... (watch the sidebar for memory usage)"):
-            st.sidebar.write(f"Document character count: {len(doc_text)}")
             current_cache = clone_cache(cache)
             t_clone_end = time()
             Cache_create_time = t_clone_end - t1
+
             full_prompt = f"""
 <|user|>
 Question: {query}
@@ -257,8 +184,8 @@ if uploaded_file:
 """.strip()
             input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
             max_new_tokens = max(32, int(input_ids.shape[-1] * 0.3))
-            print(f"Max new tokens: {max_new_tokens}")
             st.session_state.input_tokens_count += input_ids.shape[-1]
+
             t_gen_start = time()
             output_ids = generate(model, input_ids, current_cache, max_new_tokens=max_new_tokens)
             generated_tokens_count = output_ids.shape[-1]
@@ -267,17 +194,19 @@ if uploaded_file:
             response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
             t_gen_end = time()
             last_generation_time = t_gen_end - t_gen_start
+
             st.success("Answer:")
             st.write(response)
-            st.info(f"Cache create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s")
-            if st.button("💾 Save Cache"):
-                torch.save(clean_up(clone_cache(cache), origin_len), "saved_cache.pth")
-                st.success("Cache saved successfully!")
-            # Add Reset button at the end
-            resource_dashboard(cache, doc_text, last_generation_time, last_cache_clone_time)
-else:
-    st.info("Please upload a document to start.")
-    resource_dashboard(None, None)
+
+            # Unified Info Line AFTER Generation
+            st.info(
+                f"Document Chars: {len(doc_text)} | Size: {stats['Document Size (KB)']} KB | "
+                f"GPU Used: {stats['GPU Used']} | GPU Usage: {stats['GPU Usage (%)']}% | "
+                f"KV Cache Memory: {stats['KV Cache Memory Used']} | "
+                f"Cache as % RAM: {stats['KV Cache as % RAM']} | Cache as % GPU: {stats['KV Cache as % GPU']} | "
+                f"Cache Create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s"
+            )
+
 
 # Sidebar: Load a previously saved cache
 st.sidebar.header("🛠️ Advanced Options")
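Note on the one functional change outside the UI: model_max_length goes from 2x to 1.3x of round(doc_text_count * 0.3 + 1), i.e. roughly 0.3 tokens budgeted per document character with about 30% headroom instead of 100%. Below is a quick, illustrative check of that arithmetic in plain Python; token_budget is a made-up helper for this note, not something in app.py.

def token_budget(doc_text_count, headroom=1.3):
    # Mirrors the expression now used in load_model_and_tokenizer:
    #   model_max_length = 1.3 * round(doc_text_count * 0.3 + 1)
    # i.e. ~0.3 tokens per character plus proportional headroom.
    return headroom * round(doc_text_count * 0.3 + 1)

for chars in (1_000, 10_000, 50_000):
    old_budget = 2.0 * round(chars * 0.3 + 1)   # budget before this commit
    new_budget = token_budget(chars)            # budget after this commit
    print(f"{chars} chars -> old {old_budget:.0f} tokens, new {new_budget:.0f} tokens")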
 
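The new inline st.info summary takes its cache figures from cache_stats_table and get_system_stats; the per-layer accounting in cache_stats_table amounts to summing element_size() * nelement() over the DynamicCache's key_cache and value_cache tensors. A minimal standalone sketch of that idea follows (kv_cache_bytes is an illustrative name, not a function in app.py).

def kv_cache_bytes(cache):
    # Sum the memory held by every layer's key and value tensors.
    # transformers' DynamicCache exposes them as the key_cache / value_cache
    # lists of tensors, which is exactly what cache_stats_table iterates over.
    total = 0
    for key, value in zip(cache.key_cache, cache.value_cache):
        total += key.element_size() * key.nelement()
        total += value.element_size() * value.nelement()
    return total

# Feeding the result through sizeof_fmt gives the human-readable figure
# reported as "KV Cache Memory" in the info line.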
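For context, the clone-then-generate flow kept by this commit (clone_cache(cache) followed by generate(model, input_ids, current_cache, ...)) is the usual prefix-caching pattern: prefill the document once, copy the resulting KV cache for each question, and let decoding continue from the copy. Below is a rough sketch of that pattern using the stock transformers API rather than app.py's own helpers; the model name is a placeholder, and the exact past_key_values handling in generate() depends on the transformers version.

import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder; app.py's model_name is not shown in this diff
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name)

doc_ids = tokenizer("...long document text...", return_tensors="pt").input_ids

# Prefill once: run the document through the model and keep its KV cache.
with torch.no_grad():
    doc_cache = model(doc_ids, use_cache=True).past_key_values

# Per question: deep-copy the prefilled cache so the original stays reusable,
# then let generation continue from the cached document prefix.
prompt_ids = tokenizer("<|user|>\nQuestion: ...", return_tensors="pt").input_ids
full_ids = torch.cat([doc_ids, prompt_ids], dim=-1)
output_ids = model.generate(
    full_ids,
    past_key_values=copy.deepcopy(doc_cache),
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0, full_ids.shape[-1]:], skip_special_tokens=True))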