kouki321 committed on
Commit
6167a87
·
verified ·
1 Parent(s): f5890f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -119,13 +119,14 @@ def clean_up(cache, origin_len):
119
  cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
120
  return cache
121
 
122
- def get_cache_memory(self):
 
123
  total_memory = 0
124
- for key in self.key_cache:
125
  total_memory += key.element_size() * key.nelement()
126
- for value in self.value_cache:
127
  total_memory += value.element_size() * value.nelement()
128
- return total_memory/(1024*1024)
129
 
130
  @st.cache_resource
131
  def load_model_and_tokenizer(doc_text_count):
@@ -198,7 +199,7 @@ if uploaded_file:
198
  st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
199
  query = st.text_input("πŸ”Ž Ask a question about the document:")
200
  if query and st.button("Generate Answer"):
201
- with st.spinner("Generating answer... ()"):
202
  current_cache = clone_cache(cache)
203
  t_clone_end = time()
204
  Cache_create_time = t_clone_end - t1
@@ -213,7 +214,7 @@ if uploaded_file:
213
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
214
  t_gen_end = time()
215
  last_generation_time = t_gen_end - t_gen_start
216
- cache_mem_bytes = cache.get_cache_memory()
217
  st.success("Answer:")
218
  st.write(response)
219
  st.info(f"Cache create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s ")
 
119
  cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
120
  return cache
121
 
122
def calculate_cache_size(cache):
    """Calculate the total memory used by the key-value cache, in megabytes (MB).

    Note: despite the original docstring saying "bytes", the value is divided
    by 1024 * 1024 before being returned, so the unit is MB.

    Args:
        cache: An object exposing ``key_cache`` and ``value_cache``, each an
            iterable of tensor-like objects providing ``element_size()`` and
            ``nelement()`` (presumably a transformers ``DynamicCache`` —
            TODO confirm against the caller).

    Returns:
        float: Combined size of all key and value tensors in MB.
    """
    # element_size() * nelement() is the per-tensor byte count.
    key_bytes = sum(k.element_size() * k.nelement() for k in cache.key_cache)
    value_bytes = sum(v.element_size() * v.nelement() for v in cache.value_cache)
    # Convert bytes -> megabytes (1 MB = 1024 * 1024 bytes).
    return (key_bytes + value_bytes) / (1024 * 1024)
130
 
131
  @st.cache_resource
132
  def load_model_and_tokenizer(doc_text_count):
 
199
  st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
200
  query = st.text_input("πŸ”Ž Ask a question about the document:")
201
  if query and st.button("Generate Answer"):
202
+ with st.spinner("Generating answer..."):
203
  current_cache = clone_cache(cache)
204
  t_clone_end = time()
205
  Cache_create_time = t_clone_end - t1
 
214
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
215
  t_gen_end = time()
216
  last_generation_time = t_gen_end - t_gen_start
217
+ cache_mem_bytes = calculate_cache_size(cache)
218
  st.success("Answer:")
219
  st.write(response)
220
  st.info(f"Cache create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s ")