# Copied from the Hugging Face Space file view:
# author EmTpro01, commit 8eb4dc2 "Update app.py" (verified), file size 2.47 kB.
import os

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def load_model_with_lora(base_model_name, lora_path, *, torch_dtype=torch.float16, device_map="auto"):
    """Load a base causal LM, apply a LoRA adapter to it, and merge the adapter in.

    Args:
        base_model_name: Hub id or local path of the base model, passed to
            ``AutoModelForCausalLM.from_pretrained``.
        lora_path: Hub id or local path of the PEFT adapter, passed to
            ``PeftModel.from_pretrained``.
        torch_dtype: dtype forwarded to ``from_pretrained``. Defaults to
            ``torch.float16``, the value this function previously hard-coded.
        device_map: Device placement forwarded to ``from_pretrained``. Defaults to
            ``"auto"`` (the previously hard-coded value), which lets transformers
            choose the device placement.

    Returns:
        The model returned by ``merge_and_unload()``: the adapter weights are
        merged into the base weights and the PEFT wrapper is removed.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    # Wrap the base model with the adapter, then fold the adapter weights into it.
    peft_model = PeftModel.from_pretrained(base_model, lora_path)
    # NOTE(review): if the base checkpoint is quantized (e.g. bitsandbytes 4-bit),
    # merging may be lossy or unsupported depending on the installed PEFT
    # version. Confirm this, or load a full-precision base before merging.
    return peft_model.merge_and_unload()
def load_tokenizer(base_model_name):
    """Return the tokenizer that belongs to ``base_model_name``.

    This is a thin wrapper around ``AutoTokenizer.from_pretrained`` with no
    extra options.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    return tokenizer
def generate_code(prompt, model, tokenizer, max_length=512, temperature=0.7):
    """Generate text for ``prompt`` with ``model`` and return it decoded.

    Args:
        prompt: Raw text passed straight to ``tokenizer``. No chat template is
            applied.
        model: Causal LM exposing ``.device`` and ``.generate(...)``.
        tokenizer: Tokenizer matching ``model``. Its ``eos_token_id`` is reused
            as ``pad_token_id`` for generation.
        max_length: Passed to ``generate(max_length=...)``, a cap on the total
            sequence length in tokens, including the prompt. Coerced to ``int``
            because UI widgets may deliver floats. ``None`` is passed through
            unchanged.
        temperature: Sampling temperature, coerced to ``float``. A value <= 0
            selects greedy decoding (``do_sample=False``) instead of being
            handed to the sampler, which rejects non-positive temperatures.
            ``None`` is passed through unchanged with sampling enabled.

    Returns:
        ``tokenizer.decode`` of the first returned sequence, with special
        tokens skipped. For decoder-only models ``generate`` returns the prompt
        tokens followed by the new ones, so the text usually echoes the prompt.
    """
    if max_length is not None:
        max_length = int(max_length)
    if temperature is not None:
        # transformers checks that a sampling temperature is a strictly
        # positive *float*; an int such as 1 or 0 would otherwise raise.
        temperature = float(temperature)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_length": max_length,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature <= 0:
        # Treat a zero or negative temperature as a request for greedy
        # decoding rather than letting generate() raise a ValueError.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature

    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Initialize model and tokenizer.
# Each identifier can be overridden by an environment variable of the same name
# (e.g. a Space variable). When the variable is unset, the default below is used.
# NOTE(review): the default base model name suggests a pre-quantized
# bitsandbytes 4-bit checkpoint. load_model_with_lora merges the LoRA adapter
# into it, which may be lossy depending on the PEFT version. Confirm this is
# intended.
BASE_MODEL_NAME = os.environ.get("BASE_MODEL_NAME", "unsloth/Llama-3.2-3B-bnb-4bit")
LORA_PATH = os.environ.get("LORA_PATH", "EmTpro01/Llama-3.2-3B-peft")
# Both loads run at import time, so the model is ready before the UI starts.
model = load_model_with_lora(BASE_MODEL_NAME, LORA_PATH)
tokenizer = load_tokenizer(BASE_MODEL_NAME)
# Create Gradio interface
def gradio_generate(prompt, temperature, max_length):
    """Gradio callback that forwards UI values to ``generate_code``.

    Gradio supplies the values in the order of the interface's input
    components: prompt, temperature, max_length. ``generate_code`` takes
    ``max_length`` before ``temperature``, so both are forwarded by keyword to
    make the reordering explicit. The module-level ``model`` and ``tokenizer``
    are used.
    """
    return generate_code(
        prompt,
        model,
        tokenizer,
        max_length=max_length,
        temperature=temperature,
    )
def _build_interface():
    """Assemble the Gradio ``Interface`` that wraps ``gradio_generate``.

    The input components are listed in the same order as the parameters of
    ``gradio_generate`` (prompt, temperature, max_length), because their values
    are passed to the callback in that order.
    """
    prompt_input = gr.Textbox(
        label="Prompt",
        lines=5,
        placeholder="Enter your code generation prompt here...",
    )
    temperature_input = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=1.0,
        step=0.1,
        value=0.7,
    )
    length_input = gr.Slider(
        label="Max Length",
        minimum=64,
        maximum=2048,
        step=64,
        value=512,
    )
    code_output = gr.Code(language="python", label="Generated Code")
    return gr.Interface(
        fn=gradio_generate,
        inputs=[prompt_input, temperature_input, length_input],
        outputs=code_output,
        title="Code Generation with LoRA",
        description="Enter a prompt to generate code using a fine-tuned model with LoRA adapters",
    )


demo = _build_interface()

if __name__ == "__main__":
    # Start the web server only when this file is executed directly.
    demo.launch()