chatbot / lv_recipe_chatbot /ingredient_vision.py
Evan Lesmez
Cleanup notebooks to be nbdev_test friendly
5f3a430
raw
history blame
No virus
4.29 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_ingredient_vision.ipynb.
# %% auto 0
# Public names exported by this module (autogenerated from the notebook).
__all__ = [
    "SAMPLE_IMG_DIR",
    "format_image",
    "BlipImageCaptioning",
    "BlipVQA",
    "VeganIngredientFinder",
]
# %% ../nbs/03_ingredient_vision.ipynb 3
import imghdr
import os
import time
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from transformers import (
BlipForConditionalGeneration,
BlipForQuestionAnswering,
BlipProcessor,
pipeline,
)
import constants
# %% ../nbs/03_ingredient_vision.ipynb 7
# fmt: off
def format_image(
    image: str,  # Image file path
    *,
    max_side: int = 512,  # Longest side (pixels) the image is scaled to before snapping
    multiple: int = 64,  # Output width and height are snapped to a multiple of this
) -> "Image.Image":  # RGB image resized for the vision models
    # fmt: on
    """Load an image, scale it so its longer side is `max_side`, and snap
    both sides to the nearest multiple of `multiple`.

    Small images are upscaled as well as large ones downscaled. Each side is
    clamped to at least `multiple` so that very elongated images do not
    round a side down to zero.
    """
    # Use a context manager so the underlying file handle is closed;
    # `resize` loads the pixel data, so the result outlives the file.
    with Image.open(image) as img:
        width, height = img.size
        # The smaller ratio makes the longer side land exactly on `max_side`.
        ratio = min(max_side / width, max_side / height)
        width_new = round(width * ratio)
        height_new = round(height * ratio)
        # Snap to the nearest multiple, but never below one multiple: a very
        # elongated image would otherwise round a side to 0, which is not a
        # usable image size.
        width_new = max(multiple, int(np.round(width_new / multiple)) * multiple)
        height_new = max(multiple, int(np.round(height_new / multiple)) * multiple)
        resized = img.resize((width_new, height_new))
    return resized.convert("RGB")
# %% ../nbs/03_ingredient_vision.ipynb 8
class BlipImageCaptioning:
    """
    Useful when you want to know what is inside the photo.
    """
    # Wraps the "Salesforce/blip-image-captioning-base" checkpoint. The
    # processor and model are loaded once in __init__ and reused by every
    # `inference` call.

    # fmt: off
    def __init__(self,
        device: str,  # PyTorch device to run the model on, e.g. "cpu", "cuda", "cuda:0", "cuda:1", ...
    ):
        # fmt: on
        self.device = device
        # Half precision on any CUDA device, full precision otherwise.
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    # fmt: off
    def inference(self,
        image: "Image.Image",  # Image to caption
        max_new_tokens: int = 50,  # Upper bound on the caption length, in generated tokens
    ) -> str:  # Caption for the image
        # fmt: on
        # Put the processed inputs on the model's device and in its dtype.
        inputs = self.processor(image, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.processor.decode(out[0], skip_special_tokens=True)
# %% ../nbs/03_ingredient_vision.ipynb 10
class BlipVQA:
    """
    BLIP Visual Question Answering
    Useful when you need an answer for a question based on an image.
    Examples:
    what is the background color of this image, how many cats are in this figure, what is in this figure?
    """
    # Wraps the "Salesforce/blip-vqa-base" checkpoint. The processor and model
    # are loaded once in __init__ and reused by every `inference` call.

    # fmt: off
    def __init__(self,
        device: str,  # PyTorch device to run the model on, e.g. "cpu", "cuda", "cuda:0", ...
    ):
        # fmt: on
        # Half precision on any CUDA device, full precision otherwise.
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
        self.model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    # fmt: off
    def inference(self,
        image: "Image.Image",  # Image the question is about
        question: str,  # Natural-language question about the image
        max_new_tokens: int = 100,  # Upper bound on the answer length, in generated tokens
    ) -> str:  # Answer to the query on the image
        # fmt: on
        # Normalise to 3-channel RGB before handing the image to the processor.
        image = image.convert("RGB")
        # Put the processed inputs on the model's device and in its dtype.
        inputs = self.processor(image, question, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.processor.decode(out[0], skip_special_tokens=True)
# %% ../nbs/03_ingredient_vision.ipynb 12
# Sample images of vegan ingredients under the project's assets folder.
# Built with pathlib joins rather than string formatting so separators are
# handled by Path (and an empty ROOT_DIR stays relative instead of becoming "/assets/...").
SAMPLE_IMG_DIR = Path(constants.ROOT_DIR) / "assets" / "images" / "vegan_ingredients"
# %% ../nbs/03_ingredient_vision.ipynb 19
class VeganIngredientFinder:
    """List ingredients seen in a photo by asking BLIP VQA a fixed set of questions."""

    # fmt: off
    def __init__(self,
        device: str = "cpu",  # PyTorch device for the VQA model, e.g. "cpu" or "cuda:0"
    ):
        # fmt: on
        self.vqa = BlipVQA(device)

    # fmt: off
    def list_ingredients(self,
        img: str,  # Image file path
    ) -> str:  # Newline-separated answers from the VQA model
        # fmt: on
        image = format_image(img)
        questions = (
            "What are three of the vegetables seen in the image if any?",
            "What are three of the fruits seen in the image if any?",
            "What grains and starches are in the image if any?",
        )
        lines = [self.vqa.inference(image, question) for question in questions]
        milk_answer = self.vqa.inference(
            image, "Is there plant-based milk in the image?"
        )
        # The yes/no question contributes a fixed label rather than its raw answer.
        if "yes" in milk_answer.lower():
            lines.append("plant-based milk")
        return "\n".join(lines)