"""QLORA_phi2 / app_fixed.py

CPU-optimized Gradio app serving the fine-tuned Phi-2 QLoRA assistant.
"""
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import time
import gc
import os
import psutil
# Configuration
BASE_MODEL = "microsoft/phi-2"
ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"
DEBUG = False # Set to True to enable debug prints
# Memory monitoring
def get_memory_usage():
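    """Return the resident set size of the current process in megabytes."""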
process = psutil.Process(os.getpid())
return process.memory_info().rss / (1024 * 1024) # MB
class ModelWrapper:
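    """Lazy-loading wrapper around the Phi-2 base model and its LoRA adapter."""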
def __init__(self):
self.model = None
self.tokenizer = None
self.loaded = False
def load_model(self):
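        """Load the tokenizer, base model, and LoRA adapter onto the CPU; no-op if already loaded."""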
if not self.loaded:
try:
                # Force CPU usage by hiding any GPUs from torch
                os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Clear memory
gc.collect()
if DEBUG:
print(f"Memory before loading: {get_memory_usage():.2f} MB")
print("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
trust_remote_code=True,
padding_side="left"
)
self.tokenizer.pad_token = self.tokenizer.eos_token
if DEBUG:
print(f"Memory after tokenizer: {get_memory_usage():.2f} MB")
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float32,
device_map="cpu",
trust_remote_code=True,
                    attn_implementation="eager",  # flash attention is unavailable on CPU
low_cpu_mem_usage=True,
offload_folder="offload"
)
if DEBUG:
print(f"Memory after base model: {get_memory_usage():.2f} MB")
print("Loading LoRA adapter...")
self.model = PeftModel.from_pretrained(
base_model,
ADAPTER_MODEL,
torch_dtype=torch.float32,
device_map="cpu"
)
                # Drop the local reference; the PeftModel still wraps the base model
                del base_model
gc.collect()
if DEBUG:
print(f"Memory after adapter: {get_memory_usage():.2f} MB")
self.model.eval()
print("Model loading complete!")
self.loaded = True
except Exception as e:
print(f"Error during model loading: {str(e)}")
raise
def generate_response(self, prompt, max_length=256, temperature=0.7, top_p=0.9):
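        """Generate a completion for `prompt` and return (response_text, generation_seconds)."""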
if not self.loaded:
self.load_model()
try:
# Use shorter prompts to save memory
if "function" in prompt.lower() and "python" in prompt.lower():
enhanced_prompt = f"""Write Python function: {prompt}"""
elif any(word in prompt.lower() for word in ["explain", "what is", "how does", "describe"]):
enhanced_prompt = f"""Explain briefly: {prompt}"""
else:
enhanced_prompt = prompt
if DEBUG:
print(f"Enhanced prompt: {enhanced_prompt}")
# Tokenize input with shorter max length
inputs = self.tokenizer(
enhanced_prompt,
return_tensors="pt",
truncation=True,
max_length=256, # Reduced for memory
padding=True
).to("cpu")
# Generate with minimal parameters
start_time = time.time()
with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=min(max_length, 256),  # strict limit; max_length would also count prompt tokens
                    min_new_tokens=10,
                    temperature=min(0.5, temperature),  # clamp sampling params for CPU stability
                    top_p=min(0.85, top_p),
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    num_return_sequences=1,
                    num_beams=1  # single beam; early_stopping/length_penalty only affect beam search
                )
            generation_time = time.time() - start_time
# Decode response
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
if DEBUG:
print(f"Raw response: {response}")
# Clean up the response
if response.startswith(enhanced_prompt):
response = response[len(enhanced_prompt):].strip()
if DEBUG:
print(f"After prompt removal: {response}")
# Basic cleanup only
cleaned_response = response.replace("Human:", "").replace("Assistant:", "")
if DEBUG and cleaned_response != response:
print(f"After conversation removal: {cleaned_response}")
response = cleaned_response
# Ensure code examples are properly formatted
if "```python" not in response and "def " in response:
response = "```python\n" + response + "\n```"
# Simple validation
if len(response.strip()) < 10:
if DEBUG:
print("Response validation failed - using fallback")
if "function" in prompt.lower():
fallback_response = """```python
def add_numbers(a, b):
return a + b
```"""
else:
fallback_response = "I apologize, but I couldn't generate a response. Please try with a simpler prompt."
response = fallback_response
# Clear memory after generation
gc.collect()
return response, generation_time
except Exception as e:
print(f"Error during generation: {str(e)}")
raise
# Shared model wrapper; weights are loaded lazily on the first request
model_wrapper = ModelWrapper()
def generate_text(prompt, max_length=256, temperature=0.5, top_p=0.85):
"""Gradio interface function"""
try:
if not prompt.strip():
return "Please enter a prompt."
response, gen_time = model_wrapper.generate_response(
prompt,
max_length=max_length,
temperature=temperature,
top_p=top_p
)
return f"Generated in {gen_time:.2f} seconds:\n\n{response}"
except Exception as e:
print(f"Error in generate_text: {str(e)}")
return f"Error generating response: {str(e)}\nPlease try again with a shorter prompt."
# Create a very lightweight Gradio interface
demo = gr.Interface(
fn=generate_text,
inputs=[
gr.Textbox(
label="Enter your prompt",
placeholder="Type your prompt here...",
lines=3
),
gr.Slider(
minimum=64,
maximum=256,
value=192,
step=32,
label="Maximum Length",
info="Keep this low for CPU"
),
gr.Slider(
minimum=0.1,
maximum=0.7,
value=0.4,
step=0.1,
label="Temperature",
info="Lower is better for CPU"
),
gr.Slider(
minimum=0.5,
maximum=0.9,
value=0.8,
step=0.1,
label="Top P",
info="Controls diversity"
),
],
outputs=gr.Textbox(label="Generated Response", lines=6),
title="Phi-2 QLoRA Assistant (CPU-Optimized)",
description="""This is a lightweight CPU version of the fine-tuned Phi-2 model.
Tips:
- Keep prompts short and specific
- Use lower maximum length (128-192) for faster responses
- Use lower temperature (0.3-0.5) for more reliable responses
""",
examples=[
[
"Write a Python function to calculate factorial",
192,
0.4,
0.8
],
[
"Explain machine learning simply",
192,
0.4,
0.8
],
[
"Write a short email to schedule a meeting",
192,
0.4,
0.8
]
],
cache_examples=False,
    concurrency_limit=1  # serve one request at a time to keep memory bounded
)
if __name__ == "__main__":
    # Gradio 4.x enables request queuing by default, so no explicit .queue() call is needed
demo.launch(max_threads=1) # Limit the number of worker threads