NotShrirang committed
Commit 3783c34 · 1 Parent(s): 67e3cab

feat: add adapter model with self attention

data_search/adapter_utils.py CHANGED
@@ -2,18 +2,30 @@ import torch
 import torch.nn as nn
 
 
+class Adapter(nn.Module):
+    def __init__(self, img_dim, txt_dim, embed_dim=1024, num_heads=8):
+        super().__init__()
+        self.adapter = nn.Sequential(
+            nn.Linear(img_dim, embed_dim),
+            nn.ReLU(),
+            nn.Linear(embed_dim, embed_dim),
+            nn.ReLU(),
+            nn.Linear(embed_dim, txt_dim)
+        )
+        self.self_attention = nn.MultiheadAttention(embed_dim=txt_dim, num_heads=num_heads)
+
+    def forward(self, img_emb):
+        img_emb = self.adapter(img_emb).unsqueeze(0)
+        attn_output, _ = self.self_attention(img_emb, img_emb, img_emb)
+        return attn_output.squeeze(0)
+
+
 def get_adapter_model(in_shape, out_shape):
-    model = nn.Sequential(
-        nn.Linear(in_shape, 1024),
-        nn.ReLU(),
-        nn.Linear(1024, 1024),
-        nn.ReLU(),
-        nn.Linear(1024, out_shape)
-    )
+    model = Adapter(in_shape, out_shape)
     return model
 
 
 def load_adapter_model():
     model = get_adapter_model(512, 384)
-    model.load_state_dict(torch.load("./weights/adapter_model.pt", map_location=torch.device('cpu')))
+    model.load_state_dict(torch.load("./weights/adapter_model_with_attention.pt", map_location=torch.device('cpu')))
     return model
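
For reference, a minimal usage sketch of the updated adapter (not part of the commit). It assumes the script runs from the repository root so that `./weights/adapter_model_with_attention.pt` resolves and `data_search` is importable, and that inputs are a batch of 512-dimensional image embeddings (e.g. from a CLIP image encoder):

```python
import torch
from data_search.adapter_utils import load_adapter_model  # assumes repo root on sys.path

# Load the Adapter (512 -> 384) with its self-attention head and the new weights.
model = load_adapter_model()
model.eval()

# Dummy batch of 4 image embeddings; in practice these come from the CLIP model.
img_emb = torch.randn(4, 512)

with torch.no_grad():
    txt_aligned = model(img_emb)  # MLP projection to 384 dims, then self-attention

print(txt_aligned.shape)  # torch.Size([4, 384])
```

Note that `forward` adds a length-1 sequence axis with `unsqueeze(0)` before `nn.MultiheadAttention` (which defaults to `batch_first=False`), so the attention is applied per embedding and the output keeps the input's `(batch, 384)` shape.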
data_upload/data_upload_page.py CHANGED
@@ -9,6 +9,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
 def data_upload(clip_model, preprocess, text_embedding_model):
     st.title("Data Upload")
+    st.warning("Please note that this is a public application. Make sure you are not uploading any sensitive data.")
     upload_choice = st.selectbox(options=["Upload Image", "Add Image from URL / Link", "Upload PDF", "Website Link"], label="Select Upload Type")
     if upload_choice == "Upload Image":
         image_util.upload_image(clip_model, preprocess)
weights/adapter_model_with_attention.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32785c4b2a93e21c358e57465021938c9c85e5cfc4bf477aff33afc72ad62cb2
+size 10243596