stupidog04 committed on
Commit
151724b
1 Parent(s): 45c38e5

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +27 -0
pipeline.py CHANGED
@@ -1,8 +1,35 @@
1
  from torchvision import transforms
 
2
  from transformers import ImageClassificationPipeline
3
  import torch
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  class PairClassificationPipeline(ImageClassificationPipeline):
7
  pipe_to_tensor = transforms.ToTensor()
8
  pipe_to_pil = transforms.ToPILImage()
 
1
  from torchvision import transforms
2
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
3
  from transformers import ImageClassificationPipeline
4
  import torch
5
 
6
 
7
class PreTrainedPipeline():
    """Callable wrapper exposing a fine-tuned ViT image classifier.

    Loads the feature extractor for the fixed pretraining checkpoint and the
    fine-tuned ``ViTForImageClassification`` weights from ``path``, then
    delegates inference to a ``PairClassificationPipeline`` instance.
    """

    def __init__(self, path):
        """
        Initialize model.

        Args:
            path (:obj:`str`):
                Local directory or Hub id of the fine-tuned
                ``ViTForImageClassification`` weights to load.
        """
        # The feature extractor is tied to the pretraining checkpoint,
        # not to the fine-tuned weights at `path`.
        model_flag = 'google/vit-base-patch16-224-in21k'
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
        self.model = ViTForImageClassification.from_pretrained(path)
        # PairClassificationPipeline is defined later in this module.
        self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)

    def __call__(self, inputs):
        """
        Classify the given image(s).

        Args:
            inputs:
                Image input accepted by ``ImageClassificationPipeline`` — a
                PIL image, a path/URL string, or a list of these.

        Return:
            The pipeline's classification output: a list of
            ``{"label": ..., "score": ...}`` dicts (a list of such lists
            for batched input).
        """
        return self.pipe(inputs)
31
+
32
+
33
  class PairClassificationPipeline(ImageClassificationPipeline):
34
  pipe_to_tensor = transforms.ToTensor()
35
  pipe_to_pil = transforms.ToPILImage()