import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import time
import gc
import os
import psutil

# Configuration
BASE_MODEL = "microsoft/phi-2"
ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"
DEBUG = False  # Set to True to enable debug prints


# Memory monitoring
def get_memory_usage():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB


class ModelWrapper:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.loaded = False

    def load_model(self):
        if not self.loaded:
            try:
                # Force CPU usage
                os.environ["CUDA_VISIBLE_DEVICES"] = ""

                # Clear memory
                gc.collect()

                if DEBUG:
                    print(f"Memory before loading: {get_memory_usage():.2f} MB")

                print("Loading tokenizer...")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    BASE_MODEL,
                    trust_remote_code=True,
                    padding_side="left"
                )
                self.tokenizer.pad_token = self.tokenizer.eos_token

                if DEBUG:
                    print(f"Memory after tokenizer: {get_memory_usage():.2f} MB")

                print("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL,
                    torch_dtype=torch.float32,
                    device_map="cpu",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    offload_folder="offload"
                )

                if DEBUG:
                    print(f"Memory after base model: {get_memory_usage():.2f} MB")

                print("Loading LoRA adapter...")
                self.model = PeftModel.from_pretrained(
                    base_model,
                    ADAPTER_MODEL,
                    torch_dtype=torch.float32,
                    device_map="cpu"
                )

                # Free up memory
                del base_model
                gc.collect()

                if DEBUG:
                    print(f"Memory after adapter: {get_memory_usage():.2f} MB")

                self.model.eval()
                print("Model loading complete!")
                self.loaded = True

            except Exception as e:
                print(f"Error during model loading: {str(e)}")
                raise

    def generate_response(self, prompt, max_length=256, temperature=0.7, top_p=0.9):
        if not self.loaded:
            self.load_model()

        try:
            # Use shorter prompts to save memory
            if "function" in prompt.lower() and "python" in prompt.lower():
                enhanced_prompt = f"Write Python function: {prompt}"
            elif any(word in prompt.lower() for word in ["explain", "what is", "how does", "describe"]):
                enhanced_prompt = f"Explain briefly: {prompt}"
            else:
                enhanced_prompt = prompt

            if DEBUG:
                print(f"Enhanced prompt: {enhanced_prompt}")

            # Tokenize input with a shorter max length
            inputs = self.tokenizer(
                enhanced_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=256,  # Reduced for memory
                padding=True
            ).to("cpu")

            # Generate with minimal parameters
            start_time = time.time()
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=min(max_length, 256),  # Strict limit
                    min_length=10,  # Reduced minimum
                    temperature=min(0.5, temperature),
                    top_p=min(0.85, top_p),
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    num_return_sequences=1,
                    num_beams=1,  # Single-beam sampling to save memory
                )

            # Decode response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            if DEBUG:
                print(f"Raw response: {response}")

            # Clean up the response
            if response.startswith(enhanced_prompt):
                response = response[len(enhanced_prompt):].strip()
                if DEBUG:
                    print(f"After prompt removal: {response}")

            # Basic cleanup only
            cleaned_response = response.replace("Human:", "").replace("Assistant:", "")
            if DEBUG and cleaned_response != response:
                print(f"After conversation removal: {cleaned_response}")
            response = cleaned_response

            # Ensure code examples are properly formatted
            if "```python" not in response and "def " in response:
                response = "```python\n" + response + "\n```"

            # Simple validation
            if len(response.strip()) < 10:
                if DEBUG:
                    print("Response validation failed - using fallback")
                if "function" in prompt.lower():
                    fallback_response = """```python
def add_numbers(a, b):
    return a + b
```"""
                else:
                    fallback_response = "I apologize, but I couldn't generate a response. Please try with a simpler prompt."
                response = fallback_response

            # Clear memory after generation
            gc.collect()

            generation_time = time.time() - start_time
            return response, generation_time

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            raise


# Initialize model wrapper
model_wrapper = ModelWrapper()


def generate_text(prompt, max_length=256, temperature=0.5, top_p=0.85):
    """Gradio interface function"""
    try:
        if not prompt.strip():
            return "Please enter a prompt."

        response, gen_time = model_wrapper.generate_response(
            prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p
        )

        return f"Generated in {gen_time:.2f} seconds:\n\n{response}"
    except Exception as e:
        print(f"Error in generate_text: {str(e)}")
        return f"Error generating response: {str(e)}\nPlease try again with a shorter prompt."


# Create a very lightweight Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your prompt here...",
            lines=3
        ),
        gr.Slider(
            minimum=64,
            maximum=256,
            value=192,
            step=32,
            label="Maximum Length",
            info="Keep this low for CPU"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=0.7,
            value=0.4,
            step=0.1,
            label="Temperature",
            info="Lower is better for CPU"
        ),
        gr.Slider(
            minimum=0.5,
            maximum=0.9,
            value=0.8,
            step=0.1,
            label="Top P",
            info="Controls diversity"
        ),
    ],
    outputs=gr.Textbox(label="Generated Response", lines=6),
    title="Phi-2 QLoRA Assistant (CPU-Optimized)",
    description="""This is a lightweight CPU version of the fine-tuned Phi-2 model.

Tips:
- Keep prompts short and specific
- Use a lower maximum length (128-192) for faster responses
- Use a lower temperature (0.3-0.5) for more reliable responses
""",
    examples=[
        ["Write a Python function to calculate factorial", 192, 0.4, 0.8],
        ["Explain machine learning simply", 192, 0.4, 0.8],
        ["Write a short email to schedule a meeting", 192, 0.4, 0.8]
    ],
    cache_examples=False,
    concurrency_limit=1  # Correct Gradio 4.x parameter for limiting concurrency
)

if __name__ == "__main__":
    # Launch without the deprecated .queue() call
    demo.launch(max_threads=1)  # Limit the number of worker threads
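
# Usage note (a sketch, assuming this file is saved as app.py and Gradio's
# default server settings): run `python app.py`, then open the local URL that
# Gradio prints (http://127.0.0.1:7860 by default). The model is loaded lazily
# on the first request, so expect that request to take much longer than
# subsequent ones on CPU.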