radames committed
Commit c0ca149
1 Parent(s): 84a499f

Update pipeline.py

Files changed (1)
  1. pipeline.py +24 -16
pipeline.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Any
+from typing import Dict, List, Any, Union
 from PIL import Image
 import requests
 import torch
@@ -11,29 +11,30 @@ from torchvision.transforms.functional import InterpolationMode
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+
 class PreTrainedPipeline():
     def __init__(self, path=""):
         # load the optimized model
-        self.model_path = os.path.join(path,'model_large_retrieval_coco.pth')
+        self.model_path = os.path.join(path, 'model_large_retrieval_coco.pth')
         self.model = blip_feature_extractor(
-            pretrained=self.model_path,
-            image_size=384,
+            pretrained=self.model_path,
+            image_size=384,
             vit='large',
             med_config=os.path.join(path, 'configs/med_config.json')
         )
         self.model.eval()
         self.model = self.model.to(device)
-
+
         image_size = 384
         self.transform = transforms.Compose([
-            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.Resize((image_size, image_size),
+                              interpolation=InterpolationMode.BICUBIC),
             transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])
-
-
+            transforms.Normalize(
+                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ])
 
-    def __call__(self, inputs: str) -> List[float]:
+    def __call__(self, inputs: Union[str, "Image.Image"]) -> List[float]:
         """
         Args:
             data (:obj:):
@@ -43,11 +44,18 @@ class PreTrainedPipeline():
             - "feature_vector": A list of floats corresponding to the image embedding.
         """
         parameters = {"mode": "image"}
-        # decode base64 image to PIL
-        image = Image.open(BytesIO(base64.b64decode(inputs))).convert("RGB")
-        image = self.transform(image).unsqueeze(0).to(device)
-        text=""
+        if isinstance(inputs, str):
+            # decode base64 image to PIL
+            image = Image.open(
+                BytesIO(base64.b64decode(inputs))).convert("RGB")
+        elif isinstance(inputs, Image.Image):
+            image = inputs.convert("RGB")
+
+        image = self.transform(image).unsqueeze(0).to(device)
+
+        text = ""
         with torch.no_grad():
-            feature_vector = self.model(image, text, mode=parameters["mode"])[0,0].tolist()
+            feature_vector = self.model(image, text, mode=parameters["mode"])[
+                0, 0].tolist()
         # postprocess the prediction
         return feature_vector
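
A minimal usage sketch (not part of the commit) showing both input types the updated __call__ accepts. It assumes the BLIP checkpoint model_large_retrieval_coco.pth and configs/med_config.json are present under the given path, and example.jpg is a placeholder filename:

import base64
from PIL import Image
from pipeline import PreTrainedPipeline

# Assumes the checkpoint and configs referenced in __init__ exist under this path.
pipe = PreTrainedPipeline(path=".")

# New in this commit: pass a PIL image directly.
vec_from_image = pipe(Image.open("example.jpg"))

# Original behavior: pass a base64-encoded image string.
with open("example.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")
vec_from_b64 = pipe(b64)

# Both paths should produce embeddings of the same dimensionality.
assert len(vec_from_image) == len(vec_from_b64)

Note that if inputs is neither a str nor a PIL.Image.Image, neither isinstance branch assigns image, so the subsequent self.transform(image) call raises UnboundLocalError.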