# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_ingredient_vision.ipynb.
# %% auto 0
__all__ = ['SAMPLE_IMG_DIR', 'format_image', 'BlipImageCaptioning', 'BlipVQA', 'VeganIngredientFinder']
# %% ../nbs/03_ingredient_vision.ipynb 3
import imghdr
import os
import time
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
    pipeline,
)
import constants
# %% ../nbs/03_ingredient_vision.ipynb 7
# fmt: off
def format_image(
    image: str,  # Image file path
    max_size: int = 512  # Longest allowed side before snapping to a multiple of 64
):  # Resized RGB PIL image
    # fmt: on
    """Load `image` from disk, shrink it to fit in a `max_size` box while
    preserving aspect ratio, snap both sides to multiples of 64 (a common
    requirement for vision models), and return it converted to RGB."""
    img = Image.open(image)
    width, height = img.size
    # Uniform scale factor so the larger side fits within max_size.
    ratio = min(max_size / width, max_size / height)
    width_new, height_new = (round(width * ratio), round(height * ratio))
    # Snap to the nearest multiple of 64, but never collapse a side to 0px —
    # plain nearest-multiple rounding would do so for very small inputs.
    width_new = max(64, int(np.round(width_new / 64.0)) * 64)
    height_new = max(64, int(np.round(height_new / 64.0)) * 64)
    img = img.resize((width_new, height_new))
    return img.convert("RGB")
# %% ../nbs/03_ingredient_vision.ipynb 8
class BlipImageCaptioning:
    """
    Useful when you want to know what is inside the photo.

    Wraps the Salesforce BLIP image-captioning base model from the
    Hugging Face hub.
    """
    # fmt: off
    def __init__(self,
        device: str  # pytorch hardware identifier to run model on options: "cpu, cuda_0, cuda_1 ..., cuda_n"
    ):
        # fmt: on
        self.device = device
        # Half precision on GPU saves memory; CPU inference needs float32.
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        # NOTE: weights are downloaded from the Hugging Face hub on first construction.
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    def inference(self,
        image: Image.Image  # Photo to describe
    ) -> str:  # Caption for the image
        """Generate a short natural-language caption describing `image`."""
        inputs = self.processor(image, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        # max_new_tokens bounds caption length.
        out = self.model.generate(**inputs, max_new_tokens=50)
        captions = self.processor.decode(out[0], skip_special_tokens=True)
        return captions
# %% ../nbs/03_ingredient_vision.ipynb 10
class BlipVQA:
    # fmt: off
    """
    BLIP Visual Question Answering
    Useful when you need an answer for a question based on an image.
    Examples:
        what is the background color of this image, how many cats are in this figure, what is in this figure?
    """
    # fmt: on
    def __init__(self, device: str):  # device: pytorch hardware identifier, e.g. "cpu", "cuda:0"
        # Half precision on GPU saves memory; CPU inference needs float32.
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.device = device
        # NOTE: weights are downloaded from the Hugging Face hub on first construction.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
        self.model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    # fmt: off
    def inference(self,
        image: Image.Image,  # Image to interrogate
        question: str  # Natural-language question about the image
    ) -> str:  # Answer to the query on the image
        # fmt: on
        """Answer `question` about `image` using the BLIP VQA model."""
        image = image.convert("RGB")  # model expects 3-channel input
        inputs = self.processor(image, question, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        # max_new_tokens bounds answer length.
        out = self.model.generate(**inputs, max_new_tokens=100)
        answer = self.processor.decode(out[0], skip_special_tokens=True)
        return answer
# %% ../nbs/03_ingredient_vision.ipynb 12
# Directory holding the sample vegan-ingredient photos shipped with the repo.
SAMPLE_IMG_DIR = Path(constants.ROOT_DIR) / "assets" / "images" / "vegan_ingredients"
# %% ../nbs/03_ingredient_vision.ipynb 19
class VeganIngredientFinder:
    """Find vegan ingredients visible in a photo by asking a BLIP VQA model
    a fixed battery of questions (vegetables, fruits, grains/starches,
    plant-based milk)."""

    # fmt: off
    def __init__(self,
        device: str = "cpu"  # pytorch hardware identifier to run model on, e.g. "cpu" or "cuda:0"
    ):
        # fmt: on
        self.vqa = BlipVQA(device)

    # fmt: off
    def list_ingredients(self,
        img: str  # Image file path
    ) -> str:  # Newline-separated answers listing detected ingredients
        # fmt: on
        """Ask the VQA model each ingredient question about `img` and join
        the answers, one per line."""
        image = format_image(img)
        # Plain literals: the original f-strings had no placeholders (F541).
        questions = (
            "What are three of the vegetables seen in the image if any?",
            "What are three of the fruits seen in the image if any?",
            "What grains and starches are in the image if any?",
        )
        answer = "\n".join(self.vqa.inference(image, q) for q in questions)
        # Yes/no probe; the model returns free text, so look for an
        # affirmative substring rather than comparing exactly.
        if "yes" in self.vqa.inference(
            image, "Is there plant-based milk in the image?"
        ).lower():
            answer += "\nplant-based milk"
        return answer