Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -119,13 +119,14 @@ def clean_up(cache, origin_len):
|
|
| 119 |
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
|
| 120 |
return cache
|
| 121 |
|
| 122 |
-
def
|
|
|
|
| 123 |
total_memory = 0
|
| 124 |
-
for key in
|
| 125 |
total_memory += key.element_size() * key.nelement()
|
| 126 |
-
for value in
|
| 127 |
total_memory += value.element_size() * value.nelement()
|
| 128 |
-
return total_memory/(1024*1024)
|
| 129 |
|
| 130 |
@st.cache_resource
|
| 131 |
def load_model_and_tokenizer(doc_text_count):
|
|
@@ -198,7 +199,7 @@ if uploaded_file:
|
|
| 198 |
st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
|
| 199 |
query = st.text_input("π Ask a question about the document:")
|
| 200 |
if query and st.button("Generate Answer"):
|
| 201 |
-
with st.spinner("Generating answer...
|
| 202 |
current_cache = clone_cache(cache)
|
| 203 |
t_clone_end = time()
|
| 204 |
Cache_create_time = t_clone_end - t1
|
|
@@ -213,7 +214,7 @@ if uploaded_file:
|
|
| 213 |
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 214 |
t_gen_end = time()
|
| 215 |
last_generation_time = t_gen_end - t_gen_start
|
| 216 |
-
cache_mem_bytes = cache
|
| 217 |
st.success("Answer:")
|
| 218 |
st.write(response)
|
| 219 |
st.info(f"Cache create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s ")
|
|
|
|
| 119 |
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
|
| 120 |
return cache
|
| 121 |
|
| 122 |
+
def calculate_cache_size(cache):
    """Return the memory used by the key-value cache, in mebibytes (MiB).

    The size of each tensor is ``element_size() * nelement()`` bytes; the
    sizes of every entry in ``cache.key_cache`` and ``cache.value_cache``
    are added together and the total is divided by ``1024 * 1024``.

    Args:
        cache: Object exposing ``key_cache`` and ``value_cache`` iterables
            whose items provide ``element_size()`` and ``nelement()``.

    Returns:
        The combined size as a float number of MiB (not bytes); ``0.0``
        when both caches are empty.
    """
    bytes_per_mib = 1024 * 1024
    key_bytes = sum(k.element_size() * k.nelement() for k in cache.key_cache)
    value_bytes = sum(
        v.element_size() * v.nelement() for v in cache.value_cache
    )
    # The result is deliberately scaled to MiB, matching the value callers
    # have always received from this function.
    return (key_bytes + value_bytes) / bytes_per_mib
|
| 130 |
|
| 131 |
@st.cache_resource
|
| 132 |
def load_model_and_tokenizer(doc_text_count):
|
|
|
|
| 199 |
st.text(doc_text[:500] + "..." if len(doc_text) > 500 else doc_text)
|
| 200 |
query = st.text_input("π Ask a question about the document:")
|
| 201 |
if query and st.button("Generate Answer"):
|
| 202 |
+
with st.spinner("Generating answer..."):
|
| 203 |
current_cache = clone_cache(cache)
|
| 204 |
t_clone_end = time()
|
| 205 |
Cache_create_time = t_clone_end - t1
|
|
|
|
| 214 |
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 215 |
t_gen_end = time()
|
| 216 |
last_generation_time = t_gen_end - t_gen_start
|
| 217 |
+
cache_mem_bytes = calculate_cache_size(cache)
|
| 218 |
st.success("Answer:")
|
| 219 |
st.write(response)
|
| 220 |
st.info(f"Cache create Time: {Cache_create_time:.2f} s | Generation Time: {last_generation_time:.2f} s ")
|