Spaces:
Running
Running
Commit
·
3783c34
1
Parent(s):
67e3cab
feat: add adapter model with self attention
Browse files
data_search/adapter_utils.py
CHANGED
@@ -2,18 +2,30 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def get_adapter_model(in_shape, out_shape):
|
6 |
-
model =
|
7 |
-
nn.Linear(in_shape, 1024),
|
8 |
-
nn.ReLU(),
|
9 |
-
nn.Linear(1024, 1024),
|
10 |
-
nn.ReLU(),
|
11 |
-
nn.Linear(1024, out_shape)
|
12 |
-
)
|
13 |
return model
|
14 |
|
15 |
|
16 |
def load_adapter_model():
|
17 |
model = get_adapter_model(512, 384)
|
18 |
-
model.load_state_dict(torch.load("./weights/
|
19 |
return model
|
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
|
5 |
+
class Adapter(nn.Module):
|
6 |
+
def __init__(self, img_dim, txt_dim, embed_dim=1024, num_heads=8):
|
7 |
+
super().__init__()
|
8 |
+
self.adapter = nn.Sequential(
|
9 |
+
nn.Linear(img_dim, embed_dim),
|
10 |
+
nn.ReLU(),
|
11 |
+
nn.Linear(embed_dim, embed_dim),
|
12 |
+
nn.ReLU(),
|
13 |
+
nn.Linear(embed_dim, txt_dim)
|
14 |
+
)
|
15 |
+
self.self_attention = nn.MultiheadAttention(embed_dim=txt_dim, num_heads=num_heads)
|
16 |
+
|
17 |
+
def forward(self, img_emb):
|
18 |
+
img_emb = self.adapter(img_emb).unsqueeze(0)
|
19 |
+
attn_output, _= self.self_attention(img_emb, img_emb, img_emb)
|
20 |
+
return attn_output.squeeze(0)
|
21 |
+
|
22 |
+
|
23 |
def get_adapter_model(in_shape, out_shape):
|
24 |
+
model = Adapter(in_shape, out_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
return model
|
26 |
|
27 |
|
28 |
def load_adapter_model():
|
29 |
model = get_adapter_model(512, 384)
|
30 |
+
model.load_state_dict(torch.load("./weights/adapter_model_with_attention.pt", map_location=torch.device('cpu')))
|
31 |
return model
|
data_upload/data_upload_page.py
CHANGED
@@ -9,6 +9,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
9 |
|
10 |
def data_upload(clip_model, preprocess, text_embedding_model):
|
11 |
st.title("Data Upload")
|
|
|
12 |
upload_choice = st.selectbox(options=["Upload Image", "Add Image from URL / Link", "Upload PDF", "Website Link"], label="Select Upload Type")
|
13 |
if upload_choice == "Upload Image":
|
14 |
image_util.upload_image(clip_model, preprocess)
|
|
|
9 |
|
10 |
def data_upload(clip_model, preprocess, text_embedding_model):
|
11 |
st.title("Data Upload")
|
12 |
+
st.warning("Please note that this is a public application. Make sure you are not uploading any sensitive data.")
|
13 |
upload_choice = st.selectbox(options=["Upload Image", "Add Image from URL / Link", "Upload PDF", "Website Link"], label="Select Upload Type")
|
14 |
if upload_choice == "Upload Image":
|
15 |
image_util.upload_image(clip_model, preprocess)
|
weights/adapter_model_with_attention.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:32785c4b2a93e21c358e57465021938c9c85e5cfc4bf477aff33afc72ad62cb2
|
3 |
+
size 10243596
|