florentgbelidji (HF staff) committed
Commit 2a79ef4
1 Parent(s): acef01f

Creating captioning pipeline with nucleus sampling

Files changed (1):
  1. pipeline.py +18 -13
pipeline.py CHANGED
@@ -2,21 +2,28 @@ from typing import Dict, List, Any
 from PIL import Image
 import requests
 import torch
+from blip import blip_decoder
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-from transformers import pipeline, AutoTokenizer
-
 
 class PreTrainedPipeline():
     def __init__(self, path=""):
         # load the optimized model
-        model = ORTModelForSequenceClassification.from_pretrained(path)
-        tokenizer = AutoTokenizer.from_pretrained(path)
-        # create inference pipeline
-        self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
+        self.model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+        self.model = blip_decoder(pretrained=self.model_url, image_size=384, vit='large')
+        self.model.eval()
+        self.model = self.model.to(device)
+
+        image_size = 384
+        self.transform = transforms.Compose([
+            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ])
+
 
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
@@ -29,13 +36,11 @@ class PreTrainedPipeline():
         - "label": A string representing what the label/class is. There can be multiple labels.
         - "score": A score between 0 and 1 describing how confident the model is for this label/class.
         """
-        inputs = data.pop("inputs", data)
+        image = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
 
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
+        image = self.transform(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            caption = self.model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
         # postprocess the prediction
-        return prediction
+        return caption
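
For reference, a local smoke test of the new captioning pipeline might look like the sketch below. It assumes pipeline.py and the vendored blip module are importable from the repository root; the image file and the path="" argument are placeholders, not part of the commit. Because generate is called with sample=True and top_p=0.9, captions are drawn by nucleus sampling and can differ between runs.

# Sketch: exercising PreTrainedPipeline locally (placeholder paths, not part of the commit)
from PIL import Image
from pipeline import PreTrainedPipeline

pipe = PreTrainedPipeline(path="")
image = Image.open("example.jpg").convert("RGB")  # placeholder image file
caption = pipe({"inputs": image})
print(caption)  # a list containing the generated caption string(s)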
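The sample=True, top_p=0.9 arguments select nucleus (top-p) sampling inside BLIP's generate. The standalone sketch below, using a made-up logits tensor, illustrates the idea: at each decoding step the model samples only from the smallest set of tokens whose cumulative probability reaches top_p, rather than from the full vocabulary.

# Illustration of nucleus (top-p) sampling for a single decoding step; `logits` is hypothetical
import torch

def nucleus_sample(logits: torch.Tensor, top_p: float = 0.9) -> int:
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # keep the smallest prefix of tokens whose mass reaches top_p (the top token is always kept)
    keep = cumulative - sorted_probs < top_p
    sorted_probs = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    sorted_probs = sorted_probs / sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(sorted_idx[choice])

logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])  # toy next-token scores
print(nucleus_sample(logits))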