visheratin committed on
Commit
d91b3a3
1 Parent(s): 21146f0

Update nllb_mrl.py

Browse files
Files changed (1) hide show
  1. nllb_mrl.py +10 -7
nllb_mrl.py CHANGED
@@ -3,8 +3,8 @@ from typing import List, Union
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
- from huggingface_hub import PyTorchModelHubMixin
7
- from open_clip import create_model_and_transforms, get_tokenizer
8
  from PIL import Image
9
  from transformers import PretrainedConfig, PreTrainedModel
10
 
@@ -13,14 +13,12 @@ class MatryoshkaNllbClipConfig(PretrainedConfig):
13
  def __init__(
14
  self,
15
  clip_model_name: str = "",
16
- clip_model_version: str = "",
17
  target_resolution: int = -1,
18
  mrl_resolutions: List[int] = [],
19
  **kwargs,
20
  ):
21
  super().__init__(**kwargs)
22
  self.clip_model_name = clip_model_name
23
- self.clip_model_version = clip_model_version
24
  self.target_resolution = target_resolution
25
  self.mrl_resolutions = mrl_resolutions
26
 
@@ -46,14 +44,19 @@ class MatryoshkaLayer(nn.Module):
46
 
47
  class MatryoshkaNllbClip(PreTrainedModel):
48
  config_class = MatryoshkaNllbClipConfig
49
-
50
  def __init__(self, config: MatryoshkaNllbClipConfig, device):
51
  super().__init__(config)
52
  if isinstance(device, str):
53
  device = torch.device(device)
54
  self.config = config
55
- self.model, _, self.transform = create_model_and_transforms(
56
- config.clip_model_name, config.clip_model_version, output_dict=True
 
 
 
 
 
57
  )
58
  self._device = device
59
  self.model.to(device)
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
+ from open_clip import create_model, get_tokenizer
7
+ from open_clip.transform import PreprocessCfg, image_transform_v2
8
  from PIL import Image
9
  from transformers import PretrainedConfig, PreTrainedModel
10
 
 
13
  def __init__(
14
  self,
15
  clip_model_name: str = "",
 
16
  target_resolution: int = -1,
17
  mrl_resolutions: List[int] = [],
18
  **kwargs,
19
  ):
20
  super().__init__(**kwargs)
21
  self.clip_model_name = clip_model_name
 
22
  self.target_resolution = target_resolution
23
  self.mrl_resolutions = mrl_resolutions
24
 
 
44
 
45
  class MatryoshkaNllbClip(PreTrainedModel):
46
  config_class = MatryoshkaNllbClipConfig
47
+
48
  def __init__(self, config: MatryoshkaNllbClipConfig, device):
49
  super().__init__(config)
50
  if isinstance(device, str):
51
  device = torch.device(device)
52
  self.config = config
53
+ self.model = create_model(
54
+ config.clip_model_name, output_dict=True
55
+ )
56
+ pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
57
+ self.transform = image_transform_v2(
58
+ pp_cfg,
59
+ is_train=False,
60
  )
61
  self._device = device
62
  self.model.to(device)