Batching doesn't work with mask-generation task

#3
by abdullahalrafib - opened

Code snippet:

    import torch
    from transformers import SamProcessor, pipeline

    # transformers == 4.32.0
    # torch == 2.0.1
    # torchaudio == 2.0.2
    # torchvision == 0.15.2

    SAM_NAME = "facebook/sam-vit-base"
    processor = SamProcessor.from_pretrained(SAM_NAME)
    # image_list is a list of PIL images
    generator = pipeline(
        "mask-generation",
        model=SAM_NAME,
        device=torch.device("cuda"),
        batch_size=len(image_list),
    )
    outputs = generator(image_list)

Another approach:

    SAM_NAME = "facebook/sam-vit-base"
    processor = SamProcessor.from_pretrained(SAM_NAME)
    generator = pipeline(
        "mask-generation", model=SAM_NAME, device=torch.device("cuda"))
    outputs = generator(image_list, batch_size=len(image_list))

Error traceback:

    outputs = generator(image_list)
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/pipelines/mask_generation.py", line 173, in __call__
    return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1110, in __call__
    outputs = list(final_iterator)
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/pipelines/pt_utils.py", line 125, in __next__
    processed = self.infer(item, **self.params)
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/pipelines/mask_generation.py", line 276, in postprocess
    output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/models/sam/image_processing_sam.py", line 554, in post_process_for_mask_generation
    return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/models/sam/image_processing_sam.py", line 1259, in _postprocess_for_mg
    masks = [_rle_to_mask(rle) for rle in rle_masks]
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/models/sam/image_processing_sam.py", line 1259, in <listcomp>
    masks = [_rle_to_mask(rle) for rle in rle_masks]
  File "/opt/anaconda3/envs/venv_yolo/lib/python3.8/site-packages/transformers/models/sam/image_processing_sam.py", line 1223, in _rle_to_mask
    height, width = rle["size"]
TypeError: string indices must be integers
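
For readers unfamiliar with this message: Python raises it when a string is indexed with a string key. The sketch below reproduces the same exception type in isolation, under the assumption that `rle` arrived as a `str` rather than a dict. It only illustrates the error and is not the transformers code; `read_size` and `good_rle` are hypothetical names.

```python
def read_size(rle):
    # Mirrors the failing statement in the traceback:
    # it expects a dict that has a "size" key.
    height, width = rle["size"]
    return height, width


good_rle = {"size": [2, 3], "counts": [0, 6]}
print(read_size(good_rle))

# Iterating over a dict yields its keys, which are strings.
# Indexing one of those strings with "size" raises TypeError.
error_message = None
try:
    for rle in good_rle:
        read_size(rle)
except TypeError as exc:
    error_message = str(exc)
    print("TypeError:", error_message)
```

If something like this is happening inside the pipeline, the batched outputs are probably not being split back into per-image lists of RLE dicts before post-processing, but that is a guess from the traceback alone.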

Hi @abdullahalrafib
Thanks for the issue!
Can you try with a more recent version of transformers? You can upgrade with: pip install -U transformers
The script below:

    from PIL import Image
    import requests
    from transformers import pipeline

    generator = pipeline("mask-generation", device=0, points_per_batch=256)

    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    # A list of image URLs
    outputs = generator([img_url, img_url], points_per_batch=64)

    # A list of PIL images
    outputs = generator([raw_image, raw_image], points_per_batch=64)

Worked fine on my end. You can pass either a list of PIL images or a list of image URLs.
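
To inspect what comes back for a list of inputs, a small helper along these lines may be useful. It is a sketch that assumes each element of `outputs` is a dict with `"masks"` and `"scores"` entries of equal length; check this against your installed version. `count_masks` and `fake_outputs` are hypothetical names, and the stand-in data replaces real model output so the snippet runs without a GPU.

```python
def count_masks(outputs, min_score=0.0):
    """Return, per image, how many masks have a score of at least min_score."""
    counts = []
    for result in outputs:
        # Scores may be tensors or floats; float() handles both scalar cases.
        scores = [float(s) for s in result["scores"]]
        counts.append(sum(1 for s in scores if s >= min_score))
    return counts


# Stand-in data with the assumed shape; real masks would be 2D boolean arrays.
fake_outputs = [
    {"masks": ["m0", "m1", "m2"], "scores": [0.98, 0.91, 0.72]},
    {"masks": ["m0"], "scores": [0.88]},
]
print(count_masks(fake_outputs))
print(count_masks(fake_outputs, min_score=0.9))
```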

Hi @ybelkada, thank you for your response. I tried points_per_batch with both transformers==4.32.0 and transformers==4.34.1, and it worked fine with both. However, I couldn't find any documentation for it. Could you please share a link?


Also, I am getting the following warning, and the average inference time is 1.58 s per image (which is a lot) for a list of 6 PIL images.

/opt/anaconda3/envs/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py:645: UserWarning: Length of IterableDataset <transformers.pipelines.pt_utils.PipelineChunkIterator object at 0x7f6d4594ca60> was reported to be 6 (when accessing len(dataloader)), but 96 samples have been fetched. 
  warnings.warn(warn_msg)
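
The count of 96 in this warning is consistent with the pipeline splitting each image's point grid into chunks. The arithmetic below is a sketch under assumptions not stated in the thread: a default grid of 32 points per side, a single crop layer, and points_per_batch=64. `expected_chunks` is a hypothetical helper, not part of transformers.

```python
import math


def expected_chunks(num_images, points_per_side=32, points_per_batch=64):
    # Each image gets a points_per_side x points_per_side grid of prompts.
    points_per_image = points_per_side ** 2
    # The grid is processed in chunks of points_per_batch points.
    chunks_per_image = math.ceil(points_per_image / points_per_batch)
    return num_images * chunks_per_image


print(expected_chunks(6))  # 6 images * (1024 / 64) chunks = 96
```

Under these assumptions the dataloader reports a length of 6 (one per image) but yields 16 chunks per image, so the mismatch looks like a side effect of chunking rather than a sign of wrong results.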

Hi @abdullahalrafib
The documentation for that pipeline was indeed missing. You can now find it here: https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.MaskGenerationPipeline
