Mehdi Cherti commited on
Commit
2f12ebe
1 Parent(s): b6996c9

load clip encoder from path if possible

Browse files
Files changed (1) hide show
  1. clip_encoder.py +7 -1
clip_encoder.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  import open_clip
4
  from einops import rearrange
5
-
6
 
7
  def exists(val):
8
  return val is not None
@@ -11,6 +11,12 @@ class CLIPEncoder(nn.Module):
11
 
12
  def __init__(self, model, pretrained):
13
  super().__init__()
 
 
 
 
 
 
14
  self.model = model
15
  self.pretrained = pretrained
16
  self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
 
2
  import torch.nn as nn
3
  import open_clip
4
  from einops import rearrange
5
+ import os
6
 
7
  def exists(val):
8
  return val is not None
 
11
 
12
  def __init__(self, model, pretrained):
13
  super().__init__()
14
+ #ViT_H_14_laion2b_s32b_b79k
15
+ fname = "models/" + model.replace("-", "_") + "_" + pretrained + ".pt"
16
+
17
+ if os.path.exists(fname):
18
+ print(fname)
19
+ pretrained = fname
20
  self.model = model
21
  self.pretrained = pretrained
22
  self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)