import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import logging
from huggingface_hub import snapshot_download
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def download_lora_weights():
"""Download LoRA weights from Hugging Face"""
return snapshot_download(
repo_id="EmTpro01/Llama-3.2-3B-peft",
        # Allow both the legacy .bin and the newer .safetensors adapter weight formats
        allow_patterns=["adapter_config.json", "adapter_model.bin", "adapter_model.safetensors"],
)
def load_model_with_lora():
"""
Load Llama model and merge it with LoRA adapter
"""
try:
        # Configure 4-bit quantization (bitsandbytes generally requires a CUDA-capable GPU)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.float16
)
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
"unsloth/llama-3.2-3b-bnb-4bit",
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
logger.info("Successfully loaded base model")
# Download and load LoRA adapter
lora_path = download_lora_weights()
logger.info(f"Downloaded LoRA weights to: {lora_path}")
# Load and merge LoRA adapter
model = PeftModel.from_pretrained(base_model, lora_path)
logger.info("Successfully loaded LoRA adapter")
# For inference, we can merge the LoRA weights with the base model
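        # (Assumption: merging into a 4-bit bitsandbytes base relies on a recent
        # peft release that dequantizes the affected layers before merging;
        # older versions may refuse to merge quantized weights.)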
model = model.merge_and_unload()
logger.info("Successfully merged LoRA weights with base model")
return model
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
raise RuntimeError(f"Failed to load model: {str(e)}")
def load_tokenizer():
"""
Load tokenizer for the Llama model
"""
try:
tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3.2-3b-bnb-4bit")
logger.info("Successfully loaded tokenizer")
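        # Minimal safeguard (an addition, not part of the original checkpoint config):
        # Llama-style tokenizers often ship without a pad token, so fall back to
        # the EOS token to keep padded generation from erroring out.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token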
return tokenizer
except Exception as e:
logger.error(f"Error loading tokenizer: {str(e)}")
raise RuntimeError(f"Failed to load tokenizer: {str(e)}")
def generate_code(prompt, model, tokenizer, max_length=512, temperature=0.7):
"""
Generate code based on the prompt
"""
try:
# Add any specific prompt template if needed
formatted_prompt = f"### Instruction: Write code for the following task:\n{prompt}\n\n### Response:"
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
            max_new_tokens=max_length,  # budget applies to newly generated tokens, not the prompt
temperature=temperature,
do_sample=True,
top_p=0.95,
top_k=50,
repetition_penalty=1.1,
pad_token_id=tokenizer.eos_token_id
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the response part
response = generated_text.split("### Response:")[-1].strip()
return response
except Exception as e:
logger.error(f"Error during code generation: {str(e)}")
return f"Error generating code: {str(e)}"
# Initialize model and tokenizer
logger.info("Starting model initialization...")
model = load_model_with_lora()
tokenizer = load_tokenizer()
logger.info("Model initialization completed successfully")
# Create Gradio interface with error handling
def gradio_generate(prompt, temperature, max_length):
try:
return generate_code(prompt, model, tokenizer, max_length, temperature)
except Exception as e:
return f"Error: {str(e)}"
# Create the Gradio interface
demo = gr.Interface(
fn=gradio_generate,
inputs=[
gr.Textbox(
lines=5,
placeholder="Enter your code generation prompt here...",
label="Prompt"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature"
),
gr.Slider(
minimum=64,
maximum=2048,
value=512,
step=64,
label="Max Length"
)
],
outputs=gr.Code(label="Generated Code"),
title="Llama Code Generation with LoRA",
description="Enter a prompt to generate code using Llama 3.2 3B model fine-tuned with LoRA",
examples=[
["Write a Python function to sort a list of numbers in ascending order"],
["Create a simple REST API using FastAPI that handles GET and POST requests"],
["Write a function to check if a string is a palindrome"]
]
)
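# Optional sketch (not part of the original app): enabling Gradio's request queue
# can help on hosted Spaces when generation runs longer than the HTTP timeout.
# demo.queue()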
if __name__ == "__main__":
demo.launch()