logo-clip-demo / app.py
Raphaël Bournhonesque
add approximate method
5c3d937
import gzip
import io
import json
import random
import re
import tempfile
from typing import Dict, List, Optional
from PIL import Image
import requests
import streamlit as st
# Shared HTTP session, reused by the download helpers below.
http_session = requests.Session()

# Open Food Facts product API (v0) endpoints.
API_URL = "https://world.openfoodfacts.org/api/v0"
PRODUCT_URL = f"{API_URL}/product"

# Base URL that image paths built by `generate_image_path` are appended to.
OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products"

# Cuts a barcode into three 3-character groups plus the remainder; used by
# `split_barcode` to build the image folder path.
BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")
@st.cache(allow_output_mutation=True)
def load_nn_data(url: str) -> Dict[int, Dict]:
    """Download and parse a gzip-compressed JSON nearest-neighbor file.

    The file must contain a single JSON object. Its keys are logo IDs, which
    are strings because JSON object keys always are; they are converted back
    to ``int``. Values are returned unchanged. ``display_predictions`` reads
    their "ids", "distances" and "annotation" entries.

    :param url: URL of the ``.json.gz`` file.
    :return: mapping of int logo ID -> nearest-neighbor record.
    :raises requests.HTTPError: if the download returns an error status.
    """
    r = http_session.get(url)
    # Fail with a clear HTTP error rather than a confusing decompression or
    # JSON error when the server returns an error response.
    r.raise_for_status()
    # Decode explicitly as UTF-8 instead of the platform's locale encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        data = json.load(f)
    return {int(key): value for key, value in data.items()}
@st.cache(allow_output_mutation=True)
def load_logo_data(url: str) -> Dict[int, Dict]:
    """Download and parse a gzip-compressed JSON Lines logo annotation file.

    Each non-blank line must be a JSON object with an "id" field. The result
    maps ``int(item["id"])`` to the whole item; later lines win on duplicate
    ids. ``display_predictions`` and ``get_cropped_image`` read each item's
    "barcode", "image_id" and "bounding_box" entries.

    :param url: URL of the ``.jsonl.gz`` file.
    :return: mapping of int logo ID -> annotation item.
    :raises requests.HTTPError: if the download returns an error status.
    """
    r = http_session.get(url)
    # Fail with a clear HTTP error rather than a confusing decompression or
    # JSON error when the server returns an error response.
    r.raise_for_status()
    # Decode explicitly as UTF-8 instead of the platform's locale encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        return {
            int(item["id"]): item
            # Skip blank lines (e.g. a trailing empty line): json.loads("")
            # would raise.
            for item in (json.loads(line) for line in f if line.strip())
        }
def get_image_from_url(
    image_url: str,
    error_raise: bool = False,
    session: Optional[requests.Session] = None,
) -> Optional[Image.Image]:
    """Download an image and open it with PIL.

    :param image_url: URL of the image to fetch.
    :param error_raise: if True, call ``raise_for_status`` on the response, so
        error statuses raise ``requests.HTTPError`` instead of returning None.
    :param session: session used to perform the request. Plain
        ``requests.get`` is used when omitted.
    :return: the opened image, or None if the response status is not 200.
    """
    # Honour the caller-supplied session when one is given.
    get = session.get if session is not None else requests.get
    r = get(image_url)
    if error_raise:
        r.raise_for_status()
    if r.status_code != 200:
        return None
    # Decode straight from memory rather than through a temporary file.
    # Reading an unflushed temp file back by name can truncate the data, and
    # PIL loads pixel data lazily after Image.open returns. The BytesIO buffer
    # stays alive as long as the image references it.
    return Image.open(io.BytesIO(r.content))
def split_barcode(barcode: str) -> List[str]:
    """Split a barcode into the folder segments of its image path.

    Barcodes that fully match ``BARCODE_PATH_REGEX`` (with the current
    pattern: 9 characters or more) are cut into three 3-character groups plus
    the remainder, and an empty remainder is dropped. Any other barcode is
    returned as a single segment.

    :raises ValueError: if ``barcode`` is empty or ``str.isdigit`` rejects it.
    """
    if barcode.isdigit():
        parts = BARCODE_PATH_REGEX.fullmatch(barcode)
        if parts is None:
            return [barcode]
        return list(filter(None, parts.groups()))
    raise ValueError("unknown barcode format: {}".format(barcode))
def get_cropped_image(barcode: str, image_id: str, bounding_box):
    """Fetch a product image and crop it to a relative bounding box.

    :param barcode: product barcode, used to build the image path.
    :param image_id: identifier of the image within the product folder.
    :param bounding_box: ``(ymin, xmin, ymax, xmax)``. Each coordinate is
        multiplied by the image height/width, so it is a fraction of the image
        size.
    :return: the cropped image, or None when the image download did not return
        status 200.
    """
    url = OFF_IMAGE_BASE_URL + generate_image_path(barcode, image_id)
    image = get_image_from_url(url, session=http_session)
    if image is None:
        return None

    ymin, xmin, ymax, xmax = bounding_box
    width, height = image.width, image.height
    # PIL expects the crop box as (left, top, right, bottom).
    box = (xmin * width, ymin * height, xmax * width, ymax * height)
    return image.crop(box)
def generate_image_path(barcode: str, image_id: str) -> str:
    """Build the path of one product image, to be appended to OFF_IMAGE_BASE_URL.

    Example: barcode "3017620422003" with image_id "1" gives
    "/301/762/042/2003/1.jpg".
    """
    segments = "/".join(split_barcode(barcode))
    return f"/{segments}/{image_id}.jpg"
def display_predictions(
    logo_data: Dict,
    nn_data: Dict,
    logo_id: Optional[int] = None,
):
    """Render a logo and its nearest neighbors on the Streamlit page.

    :param logo_data: logo ID -> annotation item with "barcode", "image_id"
        and "bounding_box" entries.
    :param nn_data: logo ID -> record with "ids", "distances" and "annotation"
        entries.
    :param logo_id: logo to display. A random ID from ``nn_data`` is picked
        when it is falsy (None or 0).
    """
    if not logo_id:
        if not nn_data:
            st.error("No nearest-neighbor data available.")
            return
        logo_id = random.choice(list(nn_data.keys()))

    st.write(f"Logo ID: {logo_id}")

    # The caller may pass any integer (e.g. typed into a number input), so the
    # ID is not guaranteed to exist in the loaded data.
    if logo_id not in nn_data or logo_id not in logo_data:
        st.error(f"Unknown logo ID: {logo_id}")
        return

    logo = logo_data[logo_id]
    logo_nn_data = nn_data[logo_id]
    nn_ids = logo_nn_data["ids"]
    nn_distances = logo_nn_data["distances"]
    annotation = logo_nn_data["annotation"]

    cropped_image = get_cropped_image(
        logo["barcode"], logo["image_id"], logo["bounding_box"]
    )
    if cropped_image is None:
        return
    st.image(cropped_image, annotation, width=200)

    cropped_images: List[Image.Image] = []
    captions: List[str] = []
    progress_bar = st.progress(0)
    for done, (closest_id, distance) in enumerate(
        zip(nn_ids, nn_distances), start=1
    ):
        progress_bar.progress(done / len(nn_ids))
        closest_logo = logo_data.get(closest_id)
        if closest_logo is None:
            # Neighbor without annotation data: skip it instead of crashing.
            continue
        cropped_image = get_cropped_image(
            closest_logo["barcode"],
            closest_logo["image_id"],
            closest_logo["bounding_box"],
        )
        if cropped_image is None:
            continue
        if cropped_image.height > cropped_image.width:
            # expand=True grows the canvas to fit the rotated logo. Without it
            # PIL keeps the original size, clipping the logo and padding it
            # with black.
            cropped_image = cropped_image.rotate(90, expand=True)
        cropped_images.append(cropped_image)
        captions.append(f"distance: {distance}")

    if cropped_images:
        st.image(cropped_images, captions, width=200)
st.sidebar.title("Logo Nearest Neighbors Demo")
st.sidebar.write(
    "Get the nearest neighbors (first 100 with the brute-force approach, "
    "50 with ANN) for a random annotated logo.\n\n"
    "CLIP model is used to generate embeddings, and nearest neighbors "
    "are computed either using a brute-force approach or with ANN."
)

# A falsy input (0) becomes None, which display_predictions treats as
# "pick a random logo".
logo_id = st.sidebar.number_input("logo ID", step=1) or None
approximate = st.sidebar.checkbox(
    "ANN (HNSW)",
    value=False,
    help="Display approximate neighbors (instead of real "
    "neighbors computed using brute-force approach)",
)

# Nearest-neighbor file for the selected method (approximate HNSW vs exact).
nn_file_name = (
    "hnsw_50_closest_neighbours" if approximate else "exact_100_neighbours"
)
nn_data = load_nn_data(
    f"https://static.openfoodfacts.org/data/logos/{nn_file_name}.json.gz"
)
logo_data = load_logo_data(
    "https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz"
)

if approximate:
    st.write("Using approximate nearest neighbors method")
else:
    st.write("Using exact (brute-force) nearest neighbors method")
display_predictions(logo_data=logo_data, nn_data=nn_data, logo_id=logo_id)