import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import logging
from huggingface_hub import snapshot_download

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def download_lora_weights():
    """Download the LoRA adapter weights from the Hugging Face Hub."""
    return snapshot_download(
        repo_id="EmTpro01/Llama-3.2-3B-peft",
        # Depending on the PEFT version used for training, the adapter weights
        # may be stored as .bin or .safetensors, so allow both.
        allow_patterns=[
            "adapter_config.json",
            "adapter_model.bin",
            "adapter_model.safetensors",
        ],
    )
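# If the download ever fails with missing files, it helps to check which
# adapter files the repo actually ships (a debugging sketch, not called by
# the app; list_repo_files is part of huggingface_hub):
#
#   from huggingface_hub import list_repo_files
#   print(list_repo_files("EmTpro01/Llama-3.2-3B-peft"))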
def load_model_with_lora():
    """
    Load the Llama base model and apply the LoRA adapter.
    """
    try:
        # Configure 4-bit quantization. The base checkpoint below is already a
        # bnb-4bit export, so this config mainly pins the compute dtype and
        # double-quantization settings.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16
        )

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            "unsloth/llama-3.2-3b-bnb-4bit",
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )
        logger.info("Successfully loaded base model")

        # Download and load LoRA adapter
        lora_path = download_lora_weights()
        logger.info(f"Downloaded LoRA weights to: {lora_path}")

        # Load the LoRA adapter on top of the base model
        model = PeftModel.from_pretrained(base_model, lora_path)
        logger.info("Successfully loaded LoRA adapter")

        # For inference we can merge the LoRA weights into the base model,
        # removing the adapter indirection on every forward pass
        model = model.merge_and_unload()
        logger.info("Successfully merged LoRA weights with base model")

        return model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise RuntimeError(f"Failed to load model: {str(e)}")
def load_tokenizer():
    """
    Load the tokenizer for the Llama model.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3.2-3b-bnb-4bit")
        # Llama tokenizers ship without a pad token; fall back to EOS so that
        # padding (if ever used) is well defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        logger.info("Successfully loaded tokenizer")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer: {str(e)}")
        raise RuntimeError(f"Failed to load tokenizer: {str(e)}")
def generate_code(prompt, model, tokenizer, max_length=512, temperature=0.7):
    """
    Generate code based on the prompt.
    """
    try:
        # Wrap the raw prompt in the instruction template the adapter was
        # fine-tuned on
        formatted_prompt = f"### Instruction: Write code for the following task:\n{prompt}\n\n### Response:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            # max_new_tokens counts only generated tokens; plain max_length
            # would also count the prompt and could cut long replies short
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the response part
        response = generated_text.split("### Response:")[-1].strip()
        return response
    except Exception as e:
        logger.error(f"Error during code generation: {str(e)}")
        return f"Error generating code: {str(e)}"
# Initialize model and tokenizer
logger.info("Starting model initialization...")
model = load_model_with_lora()
tokenizer = load_tokenizer()
logger.info("Model initialization completed successfully")
# Gradio wrapper with error handling
def gradio_generate(prompt, temperature, max_length):
    try:
        return generate_code(prompt, model, tokenizer, max_length, temperature)
    except Exception as e:
        return f"Error: {str(e)}"
# Create the Gradio interface
demo = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            lines=5,
            placeholder="Enter your code generation prompt here...",
            label="Prompt"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=64,
            maximum=2048,
            value=512,
            step=64,
            label="Max Length"
        )
    ],
    outputs=gr.Code(label="Generated Code"),
    title="Llama Code Generation with LoRA",
    description="Enter a prompt to generate code using a Llama 3.2 3B model fine-tuned with LoRA",
    # Each example must supply a value for every input component
    examples=[
        ["Write a Python function to sort a list of numbers in ascending order", 0.7, 512],
        ["Create a simple REST API using FastAPI that handles GET and POST requests", 0.7, 1024],
        ["Write a function to check if a string is a palindrome", 0.7, 256]
    ]
)
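# Optional (a sketch): enable Gradio's request queue so concurrent users are
# processed one at a time instead of contending for the single GPU. queue()
# is a standard Interface/Blocks method:
#
#   demo.queue()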
if __name__ == "__main__":
    demo.launch()
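# On a Hugging Face Space, launch() needs no arguments. When running this file
# locally and you want the app reachable from other machines, launch() accepts
# a server_name argument (a sketch): demo.launch(server_name="0.0.0.0")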