m3 committed on
Commit
8e8cbdc
1 Parent(s): e97054d

chore: add readme

Browse files
Files changed (3) hide show
  1. README.md +22 -0
  2. src/demo.py +17 -0
  3. src/model.py +47 -16
README.md CHANGED
@@ -3,3 +3,25 @@ license: apache-2.0
3
  ---
4
 
5
  refer: https://github.com/facebookresearch/sscd-copy-detection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
4
 
5
  refer: https://github.com/facebookresearch/sscd-copy-detection
6
+
7
+
8
+ ```python
9
+ # code in src/demo.py
10
+ import model
11
+ from transformers import pipeline
12
+ from transformers.image_utils import load_image
13
+
14
+ pipe = pipeline(
15
+ task='sscd-copy-detection',
16
+ model='m3/sscd-copy-detection',
17
+ batch_size=10,
18
+ device='cpu',
19
+ )
20
+
21
+ vec1 = pipe(load_image("http://images.cocodataset.org/val2017/000000039769.jpg"))
22
+ vec2 = pipe(load_image("http://images.cocodataset.org/val2017/000000039769.jpg"))
23
+
24
+ import torch.nn.functional as F
25
+ cos_sim = F.cosine_similarity(vec1, vec2, dim=0)
26
+ print('similarity:', round(cos_sim.item(), 3))
27
+ ```
src/demo.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Demo: embed an image twice with the SSCD pipeline and compare the results."""
import torch.nn.functional as F
from transformers import pipeline
from transformers.image_utils import load_image

# Imported only for its side effects: the module registers the
# 'sscd-copy-detection' config, model, image processor and pipeline task.
import model  # noqa: F401

IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"

pipe = pipeline(
    task='sscd-copy-detection',
    model='m3/sscd-copy-detection',
    batch_size=10,
    device='cpu',
)

# Both calls load the same URL, so the two vectors describe the same image.
vec1 = pipe(load_image(IMAGE_URL))
vec2 = pipe(load_image(IMAGE_URL))

cos_sim = F.cosine_similarity(vec1, vec2, dim=0)
print('similarity:', round(cos_sim.item(), 3))
src/model.py CHANGED
@@ -1,13 +1,15 @@
1
  from typing import List, Optional, Union
2
  from torchvision import transforms
3
  from PIL import Image
4
-
5
  from transformers.image_processing_utils import BaseImageProcessor
6
- from transformers import PreTrainedModel, PretrainedConfig
7
  import os
8
  from huggingface_hub import hf_hub_download
9
  import torch
10
  import torch.nn as nn
 
 
 
11
  class SscdImageProcessor(BaseImageProcessor):
12
  def __init__(
13
  self,
@@ -52,40 +54,69 @@ class SscdImageProcessor(BaseImageProcessor):
52
  image = image.convert('RGB')
53
  return preprocess(image).unsqueeze(0)
54
 
 
55
  class SscdConfig(PretrainedConfig):
56
  model_type = 'sscd-copy-detection'
 
57
  def __init__(self, model_path: str = None, **kwargs):
58
  if model_path is None:
59
  model_path = 'sscd_disc_mixup.torchscript.pt'
60
  super().__init__(model_path=model_path, **kwargs)
61
 
 
62
  class SscdModel(PreTrainedModel):
63
  config_class = SscdConfig
64
 
65
- def __init__(self, config):
66
  super().__init__(config)
67
  self.dummy_param = nn.Parameter(torch.zeros(0))
68
-
69
- print("______", config.name_or_path)
70
-
71
  is_local = os.path.isdir(config.name_or_path)
72
  if is_local:
73
  config.base_path = config.name_or_path
74
  else:
75
- config_path = hf_hub_download(repo_id=config.name_or_path, filename='config.json')
76
- config.base_path = os.path.dirname(config_path)
77
- model_path = config.base_path + '/' + config.model_path
78
- print("___model_path___", model_path)
 
 
 
 
 
79
 
80
  def forward(self, inputs):
81
- return self.model(inputs)
 
 
 
 
 
 
 
 
82
 
83
- sscd_processor = SscdImageProcessor()
84
- sscd_processor.save_pretrained('new_model')
85
- sscd_config = SscdConfig(model_path='sscd_disc_mixup.torchscript.pt')
86
- sscd_config.save_pretrained('new_model')
 
 
 
 
87
 
88
- model = SscdModel.from_pretrained('new_model')
 
89
 
90
 
 
 
 
 
91
 
 
 
 
 
 
 
1
  from typing import List, Optional, Union
2
  from torchvision import transforms
3
  from PIL import Image
 
4
  from transformers.image_processing_utils import BaseImageProcessor
5
+ from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoImageProcessor, AutoModel
6
  import os
7
  from huggingface_hub import hf_hub_download
8
  import torch
9
  import torch.nn as nn
10
+ from transformers.pipelines import PIPELINE_REGISTRY
11
+ from transformers.utils import add_end_docstrings
12
+ from transformers.pipelines.base import Pipeline, build_pipeline_init_args
13
  class SscdImageProcessor(BaseImageProcessor):
14
  def __init__(
15
  self,
 
54
  image = image.convert('RGB')
55
  return preprocess(image).unsqueeze(0)
56
 
57
class SscdConfig(PretrainedConfig):
    """Configuration for the 'sscd-copy-detection' model type.

    Carries a single extra field, ``model_path``: the file name of the
    checkpoint that the model loads with ``torch.jit.load``.
    """

    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: str = None, **kwargs):
        """Store ``model_path``, substituting the default file name for None.

        Args:
            model_path: checkpoint file name; ``None`` selects
                'sscd_disc_mixup.torchscript.pt'.
            **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        checkpoint_name = (
            model_path if model_path is not None
            else 'sscd_disc_mixup.torchscript.pt'
        )
        super().__init__(model_path=checkpoint_name, **kwargs)
65
 
66
class SscdModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a TorchScript SSCD checkpoint."""

    config_class = SscdConfig

    def __init__(self, config, model_path: str = None):
        """Locate the checkpoint file and load it with ``torch.jit.load``.

        Args:
            config: model configuration. ``config.name_or_path`` is treated as
                a local directory when it exists on disk, otherwise as a Hub
                repo id. ``config.base_path`` is set to the directory that
                holds the checkpoint.
            model_path: checkpoint file name relative to that location;
                defaults to ``config.model_path``.
        """
        super().__init__(config)
        # Zero-element parameter; presumably kept so the wrapper owns at least
        # one parameter for device/dtype lookups -- TODO confirm.
        self.dummy_param = nn.Parameter(torch.zeros(0))
        if model_path is None:
            model_path = config.model_path

        if os.path.isdir(config.name_or_path):
            config.base_path = config.name_or_path
            checkpoint_path = os.path.join(config.base_path, model_path)
        else:
            # hf_hub_download returns the path of the downloaded file itself,
            # so it is used directly; re-appending ``model_path`` to its
            # directory would duplicate any subfolder in ``model_path``.
            checkpoint_path = hf_hub_download(
                repo_id=config.name_or_path, filename=model_path
            )
            config.base_path = os.path.dirname(checkpoint_path)
        self.model = torch.jit.load(checkpoint_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from its configuration only.

        The weights are read from the TorchScript file in ``__init__``, so
        only the config is loaded here. ``model_args`` are ignored and
        ``kwargs`` are forwarded to ``AutoConfig.from_pretrained``.
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(config)

    def forward(self, inputs):
        """Run the TorchScript module and return row 0 of its output."""
        return self.model(inputs)[0, :]
91
+
92
+
93
+
94
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class SscdPipeline(Pipeline):
    """Pipeline that passes one image through the image processor and model.

    The model output is returned as-is; no post-processing is applied.
    """

    def __init__(self, model, **kwargs):
        """Remember the requested device and defer to ``Pipeline.__init__``.

        Args:
            model: the model instance the pipeline runs.
            **kwargs: forwarded unchanged to ``Pipeline.__init__``.
        """
        # ``device`` is optional, so use .get(): indexing raised KeyError
        # whenever the pipeline was built without an explicit device.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess/forward/postprocess parameter dicts."""
        return {}, {}, {}

    def preprocess(self, input):
        """Convert the raw input with the pipeline's image processor."""
        return self.image_processor.preprocess(input)

    def _forward(self, inputs):
        """Call the model on the preprocessed inputs."""
        return self.model(inputs)

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs
111
 
112
 
113
# Make the custom classes discoverable through the transformers Auto* factories.
AutoConfig.register(SscdConfig.model_type, SscdConfig)
AutoModel.register(SscdConfig, SscdModel)
AutoImageProcessor.register(
    SscdConfig,
    slow_image_processor_class=SscdImageProcessor,
)

# NOTE(review): this loads the model every time the module is imported
# (Hub download when not cached). Kept because ``models`` is a module-level
# name; confirm whether anything actually relies on it.
models = AutoModel.from_pretrained('m3/sscd-copy-detection')

# Expose the pipeline under the same task name that src/demo.py requests.
PIPELINE_REGISTRY.register_pipeline(
    task='sscd-copy-detection',
    pipeline_class=SscdPipeline,
    pt_model=SscdModel,
)