Daniel Bustamante Ospina committed on
Commit
d008b16
1 Parent(s): dcafc9b

Changes in the model loading

Browse files
Files changed (5) hide show
  1. .gitignore +4 -0
  2. app.py +5 -3
  3. feat_ext.py +5 -6
  4. vit_model_complete.pt +3 -0
  5. vit_processor_complete.pt +3 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ collected/
2
+ __pycache__/
3
+ .idea/
4
+ tests.ipynb
app.py CHANGED
@@ -15,8 +15,8 @@ HF_API_TOKEN = os.getenv('HF_API_TOKEN')
15
  ENC_KEY = os.getenv('ENC_KEY')
16
  dataset_name = os.getenv('DATASET_NAME')
17
  ds_manager_queue = Queue(maxsize=1)
18
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
-
20
 
21
  def push_files_async():
22
  try:
@@ -72,8 +72,10 @@ demo = gr.Interface(
72
  )
73
 
74
  if __name__ == "__main__":
 
 
75
  model_cls = load_enc_cls_model('model_scripted.pt_enc', ENC_KEY)
76
- feat_extractor = VitLaionFeatureExtractor()
77
  processor = feat_extractor.transforms
78
  ds_manager = HFPetDatasetManager(dataset_name, hf_token=HF_API_TOKEN, queue=ds_manager_queue)
79
  ds_manager.daemon = True
 
15
  ENC_KEY = os.getenv('ENC_KEY')
16
  dataset_name = os.getenv('DATASET_NAME')
17
  ds_manager_queue = Queue(maxsize=1)
18
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ device = torch.device('cpu')
20
 
21
  def push_files_async():
22
  try:
 
72
  )
73
 
74
  if __name__ == "__main__":
75
+ vit_model = torch.load('vit_model_complete.pt')
76
+ vit_processor = torch.load('vit_processor_complete.pt')
77
  model_cls = load_enc_cls_model('model_scripted.pt_enc', ENC_KEY)
78
+ feat_extractor = VitLaionFeatureExtractor(vit_model, vit_processor)
79
  processor = feat_extractor.transforms
80
  ds_manager = HFPetDatasetManager(dataset_name, hf_token=HF_API_TOKEN, queue=ds_manager_queue)
81
  ds_manager.daemon = True
feat_ext.py CHANGED
@@ -1,12 +1,11 @@
1
  import torch
2
- from transformers import AutoModel, AutoProcessor
3
 
4
 
5
  class VitLaionPreProcess(torch.nn.Module):
6
 
7
- def __init__(self):
8
  super().__init__()
9
- self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
10
 
11
  def forward(self, img):
12
  out = self.processor(images=img, return_tensors="pt")
@@ -14,10 +13,10 @@ class VitLaionPreProcess(torch.nn.Module):
14
 
15
 
16
  class VitLaionFeatureExtractor(torch.nn.Module):
17
- def __init__(self):
18
  super().__init__()
19
- self.vit_model = AutoModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
20
- self.transforms = VitLaionPreProcess()
21
 
22
  def forward(self, x):
23
  img_a, img_b = x
 
1
  import torch
 
2
 
3
 
4
  class VitLaionPreProcess(torch.nn.Module):
5
 
6
+ def __init__(self, processor):
7
  super().__init__()
8
+ self.processor = processor
9
 
10
  def forward(self, img):
11
  out = self.processor(images=img, return_tensors="pt")
 
13
 
14
 
15
  class VitLaionFeatureExtractor(torch.nn.Module):
16
+ def __init__(self, model, processor):
17
  super().__init__()
18
+ self.vit_model = model
19
+ self.transforms = VitLaionPreProcess(processor)
20
 
21
  def forward(self, x):
22
  img_a, img_b = x
vit_model_complete.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ed3c7ad4dffd291c889f4536d2f9481e5d18cf17a8c86f028e4e028e4959997
3
+ size 10158847471
vit_processor_complete.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2e919db039f0c50ee555d843fc1195f8e985f80ef9a32658e02df12277bdb02
3
+ size 1526861