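"""Gradio app for code generation with a Llama 3.2 3B model fine-tuned via LoRA.

Loads the 4-bit quantized base model, downloads and merges the LoRA adapter
from the Hugging Face Hub, and serves a simple text-to-code interface.
"""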
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import logging
from huggingface_hub import snapshot_download

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def download_lora_weights():
    """Download LoRA weights from Hugging Face"""
    # Include both .bin and .safetensors adapter files; newer PEFT versions
    # save the adapter weights as safetensors.
    return snapshot_download(
        repo_id="EmTpro01/Llama-3.2-3B-peft",
        allow_patterns=["adapter_config.json", "adapter_model.bin", "adapter_model.safetensors"],
    )

def load_model_with_lora():
    """
    Load Llama model and merge it with LoRA adapter
    """
    try:
        # Configure quantization
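        # 4-bit loading with double quantization (the quantization constants are
        # themselves quantized) keeps GPU memory usage low; the unsloth checkpoint
        # loaded below is already stored in bitsandbytes 4-bit format.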
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        
        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            "unsloth/llama-3.2-3b-bnb-4bit",
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )
        logger.info("Successfully loaded base model")
        
        # Download and load LoRA adapter
        lora_path = download_lora_weights()
        logger.info(f"Downloaded LoRA weights to: {lora_path}")
        
        # Load and merge LoRA adapter
        model = PeftModel.from_pretrained(base_model, lora_path)
        logger.info("Successfully loaded LoRA adapter")
        
        # For inference, we can merge the LoRA weights with the base model
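        # Merging into a 4-bit base dequantizes and re-quantizes the affected
        # layers, which may introduce small rounding differences and requires a
        # reasonably recent PEFT release.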
        model = model.merge_and_unload()
        logger.info("Successfully merged LoRA weights with base model")
        
        return model
        
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise RuntimeError(f"Failed to load model: {str(e)}")

def load_tokenizer():
    """
    Load tokenizer for the Llama model
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3.2-3b-bnb-4bit")
        logger.info("Successfully loaded tokenizer")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer: {str(e)}")
        raise RuntimeError(f"Failed to load tokenizer: {str(e)}")

def generate_code(prompt, model, tokenizer, max_length=512, temperature=0.7):
    """
    Generate code based on the prompt
    """
    try:
        # Wrap the prompt in an instruction/response template; this should match
        # the format the adapter was fine-tuned on.
        formatted_prompt = f"### Instruction: Write code for the following task:\n{prompt}\n\n### Response:"
        
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        
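        # Note: max_length counts the prompt tokens as well as the generated ones;
        # max_new_tokens would bound only the newly generated text.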
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id
        )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract only the response part
        response = generated_text.split("### Response:")[-1].strip()
        return response
    except Exception as e:
        logger.error(f"Error during code generation: {str(e)}")
        return f"Error generating code: {str(e)}"

# Initialize model and tokenizer
logger.info("Starting model initialization...")
model = load_model_with_lora()
tokenizer = load_tokenizer()
logger.info("Model initialization completed successfully")

# Create Gradio interface with error handling
def gradio_generate(prompt, temperature, max_length):
    try:
        # Sliders may return floats; cast to int before passing as a token length
        return generate_code(prompt, model, tokenizer, int(max_length), temperature)
    except Exception as e:
        return f"Error: {str(e)}"

# Create the Gradio interface
demo = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            lines=5,
            placeholder="Enter your code generation prompt here...",
            label="Prompt"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=64,
            maximum=2048,
            value=512,
            step=64,
            label="Max Length"
        )
    ],
    outputs=gr.Code(label="Generated Code"),
    title="Llama Code Generation with LoRA",
    description="Enter a prompt to generate code using a Llama 3.2 3B model fine-tuned with LoRA",
    examples=[
        # Each example supplies one value per input component: prompt, temperature, max length
        ["Write a Python function to sort a list of numbers in ascending order", 0.7, 512],
        ["Create a simple REST API using FastAPI that handles GET and POST requests", 0.7, 1024],
        ["Write a function to check if a string is a palindrome", 0.7, 512]
    ]
)

if __name__ == "__main__":
    demo.launch()