sanjanatule committed on
Commit
d7298ca
1 Parent(s): ba218d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -31
app.py CHANGED
@@ -3,6 +3,7 @@ import peft
3
  from peft import LoraConfig
4
  from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
5
  import torch
 
6
 
7
  clip_model_name = "openai/clip-vit-base-patch32"
8
  phi_model_name = "microsoft/phi-2"
@@ -17,39 +18,11 @@ phi_embed = 2560
17
  # models
18
  clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
19
  projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
20
- bnb_config = BitsAndBytesConfig(
21
- load_in_4bit=True,
22
- bnb_4bit_quant_type="nf4",
23
- bnb_4bit_compute_dtype=torch.float16,)
24
-
25
- phi_model = AutoModelForCausalLM.from_pretrained(
26
- phi_model_name,
27
- torch_dtype=torch.float32,
28
- quantization_config=bnb_config,
29
- trust_remote_code=True
30
- )
31
- lora_alpha = 16
32
- lora_dropout = 0.1
33
- lora_r = 64
34
- peft_config = LoraConfig(
35
- lora_alpha=lora_alpha,
36
- lora_dropout=lora_dropout,
37
- r=lora_r,
38
- bias="none",
39
- task_type="CAUSAL_LM",
40
- target_modules=[
41
- "q_proj",
42
- 'k_proj',
43
- 'v_proj',
44
- 'fc1',
45
- 'fc2'
46
- ]
47
- )
48
- peft_model = peft.get_peft_model(phi_model, peft_config).to(device)
49
 
50
  # load weights
51
- model_to_merge = peft_model.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
52
- merged_model = model_to_merge.merge_and_unload()
53
  projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth'))
54
 
55
  def model_generate_ans(img,val_q):
 
3
  from peft import LoraConfig
4
  from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
5
  import torch
6
+ from peft import PeftModel
7
 
8
  clip_model_name = "openai/clip-vit-base-patch32"
9
  phi_model_name = "microsoft/phi-2"
 
18
  # models
19
  clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
20
  projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
21
+ phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # load weights
24
+ model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
25
+ merged_model = model_to_merge.merge_and_unload()
26
  projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth'))
27
 
28
  def model_generate_ans(img,val_q):