File size: 1,886 Bytes
bd52cb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
from transformers import CLIPModel
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    """Three-layer perceptron head.

    Layout: ``fc1 -> ReLU -> Dropout -> fc2 -> ReLU -> Dropout -> fc3``.
    A single ``nn.Dropout`` module is shared by both hidden stages.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim1: Width of the first hidden layer.
        hidden_dim2: Width of the second hidden layer.
        output_dim: Number of values produced per sample by ``fc3``.
        dropout_rate: Drop probability of the shared dropout module.
    """

    def __init__(self, input_dim=768, hidden_dim1=512, hidden_dim2=256, output_dim=8, dropout_rate=0.5):
        super().__init__()
        # Attribute names and creation order are significant: they determine the
        # state_dict keys and the order in which seeded random init is consumed.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        """Run ``x`` through both hidden stages, then the output projection."""
        stages = ((self.fc1, self.relu1), (self.fc2, self.relu2))
        for linear, activation in stages:
            # Same dropout module after every hidden activation.
            x = self.dropout(activation(linear(x)))
        return self.fc3(x)
    
class clip_lora_model(nn.Module):
    """CLIP vision encoder with LoRA adapters, a frozen visual projection and an MLP head.

    Pipeline in ``forward``:
    1. The CLIP vision transformer, wrapped by PEFT with LoRA adapters on q/k/v.
    2. CLIP's visual projection, with parameters frozen here.
    3. The :class:`MLP` head.

    Args:
        input_dim: Input size of the MLP head. It must equal the output size of
            the CLIP visual projection, which is 768 for the default
            ``openai/clip-vit-large-patch14`` checkpoint.
        hidden_dim1: First hidden width of the MLP head.
        hidden_dim2: Second hidden width of the MLP head.
        output_dim: Number of outputs of the MLP head. Also stored as
            ``self.output_dim``.
        dropout_rate: Dropout probability inside the MLP head.
        r: LoRA rank. It is cast with ``int()`` before use.
        lora_alpha: LoRA ``lora_alpha`` scaling value.
        model_name: Hugging Face checkpoint passed to ``CLIPModel.from_pretrained``.
        lora_dropout: Dropout probability inside the LoRA adapters.

    Raises:
        ValueError: If ``input_dim`` differs from the visual projection's
            ``out_features``.
    """

    def __init__(
        self,
        input_dim=768,
        hidden_dim1=512,
        hidden_dim2=256,
        output_dim=8,
        dropout_rate=0.5,
        r=16,
        lora_alpha=8,
        model_name='openai/clip-vit-large-patch14',
        lora_dropout=0.1,
    ):
        super(clip_lora_model, self).__init__()
        self.output_dim = output_dim
        # The head is built before the CLIP weights are loaded, so its seeded
        # random initialisation does not depend on the checkpoint.
        self.mlp = MLP(input_dim, hidden_dim1, hidden_dim2, output_dim, dropout_rate)

        clip = CLIPModel.from_pretrained(model_name)

        self.proj = clip.visual_projection
        # Fail fast instead of with a shape error on the first forward pass.
        if self.proj.out_features != input_dim:
            raise ValueError(
                f"input_dim={input_dim} does not match the output size of the "
                f"visual projection of {model_name!r} ({self.proj.out_features})"
            )
        # The projection is not part of the PEFT-wrapped encoder below, so it
        # is frozen explicitly here.
        for param in self.proj.parameters():
            param.requires_grad = False

        config = LoraConfig(
            r=int(r),
            lora_alpha=lora_alpha,
            # Names of the attention projections inside the CLIP vision encoder.
            target_modules=["k_proj", "v_proj", "q_proj"],
            lora_dropout=lora_dropout,
            bias="none",
        )
        self.model = get_peft_model(clip.vision_model, config)

    def forward(self, x):
        """Encode images and return the MLP head's outputs.

        Args:
            x: Passed positionally to the vision encoder as its first input.
                NOTE(review): presumably ``pixel_values`` of shape
                (batch, 3, H, W) from a CLIP image processor; confirm with callers.

        Returns:
            The MLP head's output tensor, with ``output_dim`` values per sample.
        """
        encoder_outputs = self.model(x)
        # Element 1 is the vision transformer's pooled output. Indexing (rather
        # than ``.pooler_output``) works for both ModelOutput and tuple returns.
        pooled = encoder_outputs[1]
        image_embeds = self.proj(pooled)
        return self.mlp(image_embeds)