Upload EVPDepth
Browse files- model.py +2 -2
- model.safetensors +1 -1
model.py
CHANGED
@@ -327,7 +327,7 @@ class EVPDepthEncoder(nn.Module):
|
|
327 |
param.requires_grad = True
|
328 |
|
329 |
self.text_adapter = TextAdapterRefer(text_dim=text_dim)
|
330 |
-
self.
|
331 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
332 |
|
333 |
if caption_aggregation:
|
@@ -398,7 +398,7 @@ class EVPDepthEncoder(nn.Module):
|
|
398 |
else:
|
399 |
class_embeddings = self.class_embeddings
|
400 |
|
401 |
-
c_crossattn = self.text_adapter(latents, class_embeddings, self.
|
402 |
t = torch.ones((x.shape[0],), device=x.device).long()
|
403 |
|
404 |
#if self.dataset == 'kitti':
|
|
|
327 |
param.requires_grad = True
|
328 |
|
329 |
self.text_adapter = TextAdapterRefer(text_dim=text_dim)
|
330 |
+
self.alpha = nn.Parameter(torch.ones(text_dim) * 1e-4)
|
331 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
332 |
|
333 |
if caption_aggregation:
|
|
|
398 |
else:
|
399 |
class_embeddings = self.class_embeddings
|
400 |
|
401 |
+
c_crossattn = self.text_adapter(latents, class_embeddings, self.alpha)
|
402 |
t = torch.ones((x.shape[0],), device=x.device).long()
|
403 |
|
404 |
#if self.dataset == 'kitti':
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 3735516436
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83540a8d86d764c873c5d7e032d4c753751bde4165021cfe78b4478aa983af33
|
3 |
size 3735516436
|