EmTpro01 committed
Commit bd1c7d4 · verified · 1 Parent(s): 746a37e

Update app.py

Files changed (1): app.py +103 -40
app.py CHANGED
@@ -1,58 +1,116 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from peft import PeftModel, PeftConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from peft import PeftModel
+import logging
+import os
+from huggingface_hub import snapshot_download
 
-def load_model_with_lora(base_model_name, lora_path):
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def download_lora_weights():
+    """Download LoRA weights from Hugging Face"""
+    return snapshot_download(
+        repo_id="EmTpro01/Llama-3.2-3B-peft",
+        allow_patterns=["adapter_config.json", "adapter_model.bin"],
+    )
+
+def load_model_with_lora():
     """
-    Load base model and merge it with LoRA adapter
+    Load Llama model and merge it with LoRA adapter
     """
-    # Load base model
-    base_model = AutoModelForCausalLM.from_pretrained(
-        base_model_name,
-        torch_dtype=torch.float16,
-        device_map="auto"
-    )
-
-    # Load and merge LoRA adapter
-    model = PeftModel.from_pretrained(base_model, lora_path)
-    model = model.merge_and_unload()  # Merge adapter weights with base model
-
-    return model
+    try:
+        # Configure quantization
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_compute_dtype=torch.float16
+        )
+
+        # Load base model
+        base_model = AutoModelForCausalLM.from_pretrained(
+            "unsloth/llama-3.2-3b-bnb-4bit",
+            quantization_config=bnb_config,
+            device_map="auto",
+            trust_remote_code=True
+        )
+        logger.info("Successfully loaded base model")
+
+        # Download and load LoRA adapter
+        lora_path = download_lora_weights()
+        logger.info(f"Downloaded LoRA weights to: {lora_path}")
+
+        # Load and merge LoRA adapter
+        model = PeftModel.from_pretrained(base_model, lora_path)
+        logger.info("Successfully loaded LoRA adapter")
+
+        # For inference, we can merge the LoRA weights with the base model
+        model = model.merge_and_unload()
+        logger.info("Successfully merged LoRA weights with base model")
+
+        return model
+
+    except Exception as e:
+        logger.error(f"Error loading model: {str(e)}")
+        raise RuntimeError(f"Failed to load model: {str(e)}")
 
-def load_tokenizer(base_model_name):
+def load_tokenizer():
     """
-    Load tokenizer for the base model
+    Load tokenizer for the Llama model
     """
-    return AutoTokenizer.from_pretrained(base_model_name)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3.2-3b-bnb-4bit")
+        logger.info("Successfully loaded tokenizer")
+        return tokenizer
+    except Exception as e:
+        logger.error(f"Error loading tokenizer: {str(e)}")
+        raise RuntimeError(f"Failed to load tokenizer: {str(e)}")
 
 def generate_code(prompt, model, tokenizer, max_length=512, temperature=0.7):
     """
     Generate code based on the prompt
     """
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-    outputs = model.generate(
-        **inputs,
-        max_length=max_length,
-        temperature=temperature,
-        do_sample=True,
-        pad_token_id=tokenizer.eos_token_id
-    )
-
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+    try:
+        # Add any specific prompt template if needed
+        formatted_prompt = f"### Instruction: Write code for the following task:\n{prompt}\n\n### Response:"
+
+        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+        outputs = model.generate(
+            **inputs,
+            max_length=max_length,
+            temperature=temperature,
+            do_sample=True,
+            top_p=0.95,
+            top_k=50,
+            repetition_penalty=1.1,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Extract only the response part
+        response = generated_text.split("### Response:")[-1].strip()
+        return response
+    except Exception as e:
+        logger.error(f"Error during code generation: {str(e)}")
+        return f"Error generating code: {str(e)}"
 
 # Initialize model and tokenizer
-BASE_MODEL_NAME = "unsloth/Llama-3.2-3B-bnb-4bit"  # Replace with your base model name
-LORA_PATH = "EmTpro01/Llama-3.2-3B-peft"  # Replace with your LoRA adapter path
-
-model = load_model_with_lora(BASE_MODEL_NAME, LORA_PATH)
-tokenizer = load_tokenizer(BASE_MODEL_NAME)
+logger.info("Starting model initialization...")
+model = load_model_with_lora()
+tokenizer = load_tokenizer()
+logger.info("Model initialization completed successfully")
 
-# Create Gradio interface
+# Create Gradio interface with error handling
 def gradio_generate(prompt, temperature, max_length):
-    return generate_code(prompt, model, tokenizer, max_length, temperature)
+    try:
+        return generate_code(prompt, model, tokenizer, max_length, temperature)
+    except Exception as e:
+        return f"Error: {str(e)}"
 
+# Create the Gradio interface
 demo = gr.Interface(
     fn=gradio_generate,
     inputs=[
@@ -76,9 +134,14 @@ demo = gr.Interface(
             label="Max Length"
         )
     ],
-    outputs=gr.Code(language="python", label="Generated Code"),
-    title="Code Generation with LoRA",
-    description="Enter a prompt to generate code using a fine-tuned model with LoRA adapters",
+    outputs=gr.Code(label="Generated Code"),
+    title="Llama Code Generation with LoRA",
+    description="Enter a prompt to generate code using Llama 3.2 3B model fine-tuned with LoRA",
+    examples=[
+        ["Write a Python function to sort a list of numbers in ascending order"],
+        ["Create a simple REST API using FastAPI that handles GET and POST requests"],
+        ["Write a function to check if a string is a palindrome"]
+    ]
 )
 
 if __name__ == "__main__":
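
For a quick smoke test of the updated script outside the Gradio UI, a minimal sketch along these lines could be used; it assumes app.py is on the import path and that a CUDA GPU with enough memory for the 4-bit Llama 3.2 3B model is available. Importing the module runs its top-level initialization, so the model and tokenizer are loaded once at import time; the prompt below is one of the examples added in this commit.

# Minimal smoke test (sketch): reuse the module-level model/tokenizer created on import.
import app

# One of the example prompts added in this commit.
prompt = "Write a function to check if a string is a palindrome"
print(app.generate_code(prompt, app.model, app.tokenizer, max_length=512, temperature=0.7))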