import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os
from time import time
import pandas as pd

# ==============================
# Helper: Human-readable bytes
# ==============================
def sizeof_fmt(num, suffix="B"):
    # Formats bytes as human-readable (e.g. 1.5 GB)
    for unit in ["", "K", "M", "G", "T"]:
        if abs(num) < 1024.0:
            return f"{num:3.2f} {unit}{suffix}"
        num /= 1024.0
    return f"{num:.2f} P{suffix}"
# ==============================
# Core Model and Caching Logic
# ==============================
def generate(model, input_ids, past_key_values, max_new_tokens):
    """Greedy token-by-token generation that reuses the prefilled KV cache."""
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]
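# Note: only the newly sampled token is fed back on each step; all earlier
# keys/values come from `past_key_values`, so the prompt (document context) is
# encoded once instead of being re-processed for every generated token.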
def get_kv_cache(model, tokenizer, prompt):
    """Prepares and stores the key-value cache for the initial document/context."""
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]
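# The forward pass above fills `cache` in place; the returned length is the
# number of context tokens, which clean_up() can use to trim off any tokens
# appended by a later generation pass.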
def clean_up(cache, origin_len):
    """Trims the cache to only include the original document/context tokens."""
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
    return cache
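# For standard Hugging Face decoder models, cached tensors are laid out as
# (batch, num_kv_heads, seq_len, head_dim), so the slice above trims along the
# sequence dimension only.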
def calculate_cache_size(cache):
    """Calculate the total memory used by the key-value cache (past_key_values), in megabytes."""
    total_memory = 0
    for key in cache.key_cache:
        total_memory += key.element_size() * key.nelement()
    for value in cache.value_cache:
        total_memory += value.element_size() * value.nelement()
    return total_memory / (1024 * 1024)  # Convert bytes to MB
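# Rough sizing sketch (illustrative numbers, not measured from this model):
# a layer with 2 KV heads, head_dim 128 and 4000 cached tokens in float16 holds
# 2 (K and V) * 2 * 128 * 4000 * 2 bytes ~= 3.9 MiB; multiply by the layer
# count for the full cache.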
def load_model_and_tokenizer():
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True
    )
    return model, tokenizer
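# Note: device_map="auto" requires the `accelerate` package to be installed.
# On a CPU-only Space the model falls back to float32, which roughly doubles
# the memory footprint compared to float16 on GPU.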
def clone_cache(cache):
    new_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        new_cache.key_cache.append(key.clone())
        new_cache.value_cache.append(value.clone())
    return new_cache
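# clone_cache() gives each query its own copy of the prefilled document cache,
# so the answer tokens appended during generate() never leak into the cache
# that is reused for the next question.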
def load_document_and_cache(file_path):
    try:
        t2 = time()
        with open(file_path, 'r') as file:
            doc_text = file.read()
        doc_text_count = len(doc_text)
        # Rough character-to-token heuristic (~0.3 tokens per character, plus 30% headroom)
        max_length = int(1.3 * (doc_text_count * 0.3 + 1))
        # Cap the value at 16824
        if max_length > 16824:
            max_length = 16824
        print(f"model_max_length set to: {max_length}")
        model, tokenizer = load_model_and_tokenizer()
        tokenizer.model_max_length = max_length
        system_prompt = f"""
<|system|>
Answer concisely and precisely. You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
        t3 = time()
        print(f"KV cache build time: {t3 - t2:.2f} s")
        return cache, doc_text, doc_text_count, model, tokenizer
    except FileNotFoundError:
        st.error(f"Document file not found at {file_path}")
        return None, None, None, None, None
# ==============================
# Streamlit UI
# ==============================

# Initialize token counters
input_tokens_count = 0
generated_tokens_count = 0
output_tokens_count = 0

# Reset counters with a button
if st.button("Reset Token Counters"):
    input_tokens_count = 0
    generated_tokens_count = 0
    output_tokens_count = 0
    doc_text = None
    cache = None
    model = None
    tokenizer = None
    st.success("Token counters have been reset.")

st.title("DeepSeek QA: Supercharged Caching & Memory Dashboard")

uploaded_file = st.file_uploader("Upload your document (.txt)", type="txt")

# Initialize variables
doc_text = None
cache = None
model = None
tokenizer = None
if uploaded_file:
    log = []

    # PART 1: File Upload & Save
    t_start1 = time()
    temp_file_path = "temp_document.txt"
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    t_end1 = time()
    log.append(f"File Upload & Save Time: {t_end1 - t_start1:.2f} s")
    # print(f"File Upload & Save Time: {t_end1 - t_start1:.2f} s")

    # PART 2: Document and Cache Load
    t_start2 = time()
    cache, doc_text, doc_text_count, model, tokenizer = load_document_and_cache(temp_file_path)
    t_end2 = time()
    log.append(f"Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
    # print(f"Document & Cache Load Time: {t_end2 - t_start2:.2f} s")

    # PART 3: Document Preview Display
    t_start3 = time()
    with st.expander("Document Preview"):
        preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        st.text(preview)
    t_end3 = time()
    log.append(f"Document Preview Display Time: {t_end3 - t_start3:.2f} s")
    # print(f"Document Preview Display Time: {t_end3 - t_start3:.2f} s")

    t_start4 = time()
    # PART 4: Show Basic Info
    s_cache = calculate_cache_size(cache)
    t_end4 = time()
    log.append(f"Cache Size Calculation Time: {t_end4 - t_start4:.2f} s | Cache Size: {s_cache:.2f} MB")
    # print(f"Cache Size Calculation Time: {t_end4 - t_start4:.2f} s | size of the cache: {s_cache} MB")
    # st.info(
    #     f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | "
    #     f"Cache Size: {cache_size if cache_size == 'N/A' else f'{cache_size:.2f} KB'}"
    # )
    # =========================
    # User Query and Generation
    # =========================
    query = st.text_input("Ask a question about the document:")

    if query and st.button("Generate Answer"):
        with st.spinner("Generating answer..."):
            log.append("Query & Generation Steps:")

            # PART 4.1: Clone Cache
            t_start5 = time()
            current_cache = clone_cache(cache)
            t_end5 = time()
            # print(f"Clone Cache Time: {t_end5 - t_start5:.2f} s")
            log.append(f"Clone Cache Time: {t_end5 - t_start5:.2f} s")

            # PART 4.2: Tokenize Prompt
            t_start6 = time()
            full_prompt = f"""
<|user|>
Question: {query}
<|assistant|>
""".strip()
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_tokens_count += input_ids.shape[-1]
            t_end6 = time()
            print(f"Tokenization Time: {t_end6 - t_start6:.2f} s")
            log.append(f"Tokenization Time: {t_end6 - t_start6:.2f} s")
            # PART 4.3: Generate Answer
            t_start7 = time()
            output_ids = generate(model, input_ids, current_cache, max_new_tokens=128)
            last_generation_time = time() - t_start7
            # print(f"Generation Time: {last_generation_time:.2f} s")
            log.append(f"Generation Time: {last_generation_time:.2f} s")

            generated_tokens_count += output_ids.shape[-1]
            output_tokens_count = generated_tokens_count

            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            st.success("Answer:")
            st.write(response)
            print("***************************************************************************************")

            # Final Info Display
            # st.info(
            #     f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | "
            #     f"Cache Clone Time: {log[-3].split(': ')[1]} | Generation Time: {last_generation_time:.2f} s"
            # )
    # =========================
    # Show Log
    # =========================
    st.sidebar.header("Performance Log")
    for entry in log:
        st.sidebar.write(entry)

# =========================
# Sidebar: Cache Loader
# =========================
st.sidebar.header("Advanced Options")
st.sidebar.write("Load a previously saved cache for instant document context reuse.")

if st.sidebar.checkbox("Load saved cache"):
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    if cache_file:
        with open("temp_cache.pth", "wb") as f:
            f.write(cache_file.getvalue())
        try:
            loaded_cache = torch.load("temp_cache.pth")
            cache = loaded_cache
            st.sidebar.success("Cache loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Failed to load cache file: {e}")
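# Note: this app never writes a cache file itself. A compatible .pth file could
# be produced elsewhere with something like `torch.save(cache, "doc_cache.pth")`
# on a DynamicCache object (illustrative; the filename is arbitrary). On newer
# PyTorch versions (2.6+), torch.load defaults to weights_only=True, so loading
# a pickled cache object from a trusted source may additionally require
# torch.load(path, weights_only=False).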