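"""CPU-only Gradio demo for a Phi-2 model fine-tuned with QLoRA.

The app lazily loads the microsoft/phi-2 base model together with the
pradeep6kumar2024/phi2-qlora-assistant LoRA adapter on the first request,
generates text with conservative sampling settings, and exposes a small
gr.Interface with prompt, maximum-length, temperature, and top-p controls.
"""
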
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import time
import gc
import os
import psutil

# Configuration
BASE_MODEL = "microsoft/phi-2"
ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"

# Memory monitoring
def get_memory_usage():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB
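

# Lazy-loading wrapper: the tokenizer, base model, and LoRA adapter are only
# loaded on the first generation request, so the Space starts quickly and
# stays light in memory while idle.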
class ModelWrapper:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.loaded = False

    def load_model(self):
        if not self.loaded:
            try:
                # Force CPU usage
                os.environ["CUDA_VISIBLE_DEVICES"] = ""
                device = torch.device("cpu")

                # Clear memory
                gc.collect()
                print(f"Memory before loading: {get_memory_usage():.2f} MB")

                print("Loading tokenizer...")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    BASE_MODEL,
                    trust_remote_code=True,
                    padding_side="left"
                )
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print(f"Memory after tokenizer: {get_memory_usage():.2f} MB")

                print("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL,
                    torch_dtype=torch.float32,
                    device_map="cpu",
                    trust_remote_code=True,
                    use_flash_attention_2=False,
                    low_cpu_mem_usage=True,
                    offload_folder="offload"
                )
                print(f"Memory after base model: {get_memory_usage():.2f} MB")

                print("Loading LoRA adapter...")
                self.model = PeftModel.from_pretrained(
                    base_model,
                    ADAPTER_MODEL,
                    torch_dtype=torch.float32,
                    device_map="cpu"
                )

                # Drop the local reference; the PeftModel above still holds the base weights
                del base_model
                gc.collect()
                print(f"Memory after adapter: {get_memory_usage():.2f} MB")

                self.model.eval()
                print("Model loading complete!")
                self.loaded = True
            except Exception as e:
                print(f"Error during model loading: {str(e)}")
                raise

    def generate_response(self, prompt, max_length=256, temperature=0.7, top_p=0.9):
        if not self.loaded:
            self.load_model()
        try:
            # Use shorter prompts to save memory
            if "function" in prompt.lower() and "python" in prompt.lower():
                enhanced_prompt = f"Write Python function: {prompt}"
            elif any(word in prompt.lower() for word in ["explain", "what is", "how does", "describe"]):
                enhanced_prompt = f"Explain briefly: {prompt}"
            else:
                enhanced_prompt = prompt

            print(f"Enhanced prompt: {enhanced_prompt}")

            # Tokenize input with a shorter max length
            inputs = self.tokenizer(
                enhanced_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=256,  # Reduced for memory
                padding=True
            ).to("cpu")

            # Generate with minimal parameters
            start_time = time.time()
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=min(max_length, 256),  # Strict limit
                    min_length=10,  # Reduced minimum
                    temperature=min(0.5, temperature),
                    top_p=min(0.85, top_p),
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    num_return_sequences=1,
                    early_stopping=True,
                    num_beams=1,  # Single sampled sequence (no beam search) to save memory
                    length_penalty=0.6
                )

            # Decode response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Clean up the response
            if response.startswith(enhanced_prompt):
                response = response[len(enhanced_prompt):].strip()

            # Basic cleanup only
            response = response.replace("Human:", "").replace("Assistant:", "")

            # Ensure code examples are properly formatted
            if "```python" not in response and "def " in response:
                response = "```python\n" + response + "\n```"

            # Simple validation
            if len(response.strip()) < 10:
                if "function" in prompt.lower():
                    fallback_response = """```python
def add_numbers(a, b):
    return a + b
```"""
                else:
                    fallback_response = "I apologize, but I couldn't generate a response. Please try with a simpler prompt."
                response = fallback_response

            # Clear memory after generation
            gc.collect()

            generation_time = time.time() - start_time
            return response, generation_time
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            raise


# Initialize model wrapper
model_wrapper = ModelWrapper()


def generate_text(prompt, max_length=256, temperature=0.5, top_p=0.85):
    """Gradio interface function"""
    try:
        if not prompt.strip():
            return "Please enter a prompt."

        response, gen_time = model_wrapper.generate_response(
            prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p
        )
        return f"Generated in {gen_time:.2f} seconds:\n\n{response}"
    except Exception as e:
        print(f"Error in generate_text: {str(e)}")
        return f"Error generating response: {str(e)}\nPlease try again with a shorter prompt."


# Create a very lightweight Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your prompt here...",
            lines=3
        ),
        gr.Slider(
            minimum=64,
            maximum=256,
            value=192,
            step=32,
            label="Maximum Length",
            info="Keep this low for CPU"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=0.7,
            value=0.4,
            step=0.1,
            label="Temperature",
            info="Lower is better for CPU"
        ),
        gr.Slider(
            minimum=0.5,
            maximum=0.9,
            value=0.8,
            step=0.1,
            label="Top P",
            info="Controls diversity"
        ),
    ],
    outputs=gr.Textbox(label="Generated Response", lines=6),
    title="Phi-2 QLoRA Assistant (CPU-Optimized)",
    description="""This is a lightweight CPU version of the fine-tuned Phi-2 model.

Tips:
- Keep prompts short and specific
- Use a lower maximum length (128-192) for faster responses
- Use a lower temperature (0.3-0.5) for more reliable responses
""",
    examples=[
        ["Write a Python function to calculate factorial", 192, 0.4, 0.8],
        ["Explain machine learning simply", 192, 0.4, 0.8],
        ["Write a short email to schedule a meeting", 192, 0.4, 0.8],
    ],
    cache_examples=False,
    concurrency_limit=1  # Allow only one generation request at a time
)

if __name__ == "__main__":
    demo.launch(max_threads=1)  # Limit the number of worker threads
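
# To run outside Hugging Face Spaces, install the imported packages
# (gradio, torch, transformers, peft, psutil) and run this script with Python;
# demo.launch() starts a local Gradio server.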