florentgbelidji committed
Commit 0546112
1 Parent(s): db4cf02

Modified handler to load BLIP directly from transformers

Files changed (2):
  1. handler.py +21 -26
  2. requirements.txt +1 -5
handler.py CHANGED
@@ -1,3 +1,4 @@
+# +
 from typing import Dict, List, Any
 from PIL import Image
 import requests
@@ -5,32 +6,26 @@ import torch
 import base64
 import os
 from io import BytesIO
+
+from transformers import BlipForConditionalGeneration, BlipProcessor
 from models.blip_decoder import blip_decoder
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
+# -
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 class EndpointHandler():
     def __init__(self, path=""):
         # load the optimized model
-        self.model_path = os.path.join(path,'model_large_caption.pth')
-        self.model = blip_decoder(
-            pretrained=self.model_path,
-            image_size=384,
-            vit='large',
-            med_config=os.path.join(path, 'configs/med_config.json')
-        )
+
+        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+        self.model = BlipForConditionalGeneration.from_pretrained(
+            "Salesforce/blip-image-captioning-base"
+        ).to(device)
         self.model.eval()
         self.model = self.model.to(device)
 
-        image_size = 384
-        self.transform = transforms.Compose([
-            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])
-
 
 
     def __call__(self, data: Any) -> Dict[str, Any]:
@@ -39,22 +34,22 @@ class EndpointHandler():
             data (:obj:):
                 includes the input data and the parameters for the inference.
         Return:
-            A :obj:`dict`:. The object returned should be a dict of one list like {"caption": ["A hugging face at the office"]} containing :
+            A :obj:`dict`:. The object returned should be a dict of one list like {"captions": ["A hugging face at the office"]} containing :
                 - "caption": A string corresponding to the generated caption.
         """
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
-
-
-        image = Image.open(BytesIO(inputs))
-        image = self.transform(image).unsqueeze(0).to(device)
+
+        raw_images = [Image.open(BytesIO(_img)) for _img in inputs]
+
+        processed_image = self.processor(images=raw_images, return_tensors="pt")
+        processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
+        processed_image = {**processed_image, **parameters}
+
         with torch.no_grad():
-            caption = self.model.generate(
-                image,
-                sample=parameters.get('sample',True),
-                top_p=parameters.get('top_p',0.9),
-                max_length=parameters.get('max_length',20),
-                min_length=parameters.get('min_length',5)
+            out = self.model.generate(
+                **processed_image
             )
+        captions = self.processor.batch_decode(out, skip_special_tokens=True)
         # postprocess the prediction
-        return {"caption": caption}
+        return {"captions": captions}
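With this change the handler expects "inputs" to be a list of raw image bytes and forwards anything under "parameters" directly into model.generate(). A minimal local smoke test could look like the sketch below (the file name and generation parameters are illustrative assumptions, not part of the commit):

# Sketch: exercising the updated handler locally from the repo root.
# "cat.jpg" and the parameter values are illustrative assumptions.
from handler import EndpointHandler

handler = EndpointHandler(path=".")

with open("cat.jpg", "rb") as f:
    image_bytes = f.read()

payload = {
    # the handler iterates over "inputs", so it must be a list of image bytes
    "inputs": [image_bytes],
    # everything under "parameters" is merged into the model.generate() kwargs
    "parameters": {"max_length": 20, "min_length": 5},
}

print(handler(payload))  # e.g. {"captions": ["a cat sitting on a couch"]}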
requirements.txt CHANGED
@@ -1,5 +1 @@
-timm==0.4.12
-transformers==4.15.0
-fairscale==0.4.4
-requests
-Pillow
+git+https://github.com/huggingface/transformers.git@main
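The slimmed-down requirements pull transformers from the main branch, presumably because BLIP support had not yet shipped in a stable release, and rely on the endpoint's base image to provide torch, torchvision, Pillow, and requests, which handler.py still imports. A quick sanity check for the target environment (a sketch under those assumptions):

# Sketch: confirm the environment satisfies the new handler's imports.
# torch / torchvision / PIL / requests are assumed to come from the base
# image; transformers comes from the git dependency above.
import torch
import requests
from PIL import Image
from torchvision import transforms
from transformers import BlipForConditionalGeneration, BlipProcessor

print("torch:", torch.__version__)
print("BLIP available:", BlipForConditionalGeneration.__name__)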