visheratin commited on
Commit
c63ab68
1 Parent(s): 1d9c898

Update nllb_mrl.py

Browse files
Files changed (1) hide show
  1. nllb_mrl.py +3 -3
nllb_mrl.py CHANGED
@@ -16,12 +16,14 @@ class MatryoshkaNllbClipConfig(PretrainedConfig):
16
  clip_model_name: str = "",
17
  target_resolution: int = -1,
18
  mrl_resolutions: List[int] = [],
 
19
  **kwargs,
20
  ):
21
  super().__init__(**kwargs)
22
  self.clip_model_name = clip_model_name
23
  self.target_resolution = target_resolution
24
  self.mrl_resolutions = mrl_resolutions
 
25
 
26
 
27
  class MatryoshkaLayer(nn.Module):
@@ -54,10 +56,8 @@ class MatryoshkaNllbClip(PreTrainedModel):
54
  self.model = create_model(
55
  config.clip_model_name, output_dict=True
56
  )
57
- preprocess_cfg = get_pretrained_cfg(config.clip_model_name, "v1")
58
- pp_cfg = PreprocessCfg(preprocess_cfg)
59
  self.transform = image_transform_v2(
60
- pp_cfg,
61
  is_train=False,
62
  )
63
  self._device = device
 
16
  clip_model_name: str = "",
17
  target_resolution: int = -1,
18
  mrl_resolutions: List[int] = [],
19
+ preprocess_cfg: Union[PreprocessCfg, None] = None,
20
  **kwargs,
21
  ):
22
  super().__init__(**kwargs)
23
  self.clip_model_name = clip_model_name
24
  self.target_resolution = target_resolution
25
  self.mrl_resolutions = mrl_resolutions
26
+ self.preprocess_cfg = preprocess_cfg
27
 
28
 
29
  class MatryoshkaLayer(nn.Module):
 
56
  self.model = create_model(
57
  config.clip_model_name, output_dict=True
58
  )
 
 
59
  self.transform = image_transform_v2(
60
+ config.preprocess_cfg,
61
  is_train=False,
62
  )
63
  self._device = device