Spaces:
Runtime error
Runtime error
import gzip | |
import io | |
import json | |
import random | |
import re | |
import tempfile | |
from typing import Dict, List, Optional | |
from PIL import Image | |
import requests | |
import streamlit as st | |
# Single requests session shared by the download helpers of this module.
http_session = requests.Session()

# Open Food Facts endpoints.
API_URL = "https://world.openfoodfacts.org/api/v0"
PRODUCT_URL = f"{API_URL}/product"
OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products"

# Captures three 3-character chunks followed by the remainder of a barcode;
# describes how barcodes map to product image folders.
BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")
def load_nn_data(
    url: str,
    *,
    session: Optional["requests.Session"] = None,
    timeout: float = 60.0,
) -> Dict[int, Dict]:
    """Download a gzipped JSON nearest-neighbour file and index it by logo ID.

    Args:
        url: URL of a ``.json.gz`` file containing one JSON object whose keys
            are logo IDs.
        session: HTTP session to use; defaults to the module-level
            ``http_session``.
        timeout: request timeout in seconds, forwarded to ``session.get``.

    Returns:
        The decoded object with its keys converted to ``int``.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (raised by ``raise_for_status``).
    """
    if session is None:
        session = http_session
    r = session.get(url, timeout=timeout)
    # Fail loudly on HTTP errors instead of handing an error page to gzip.
    r.raise_for_status()
    # Decode explicitly as UTF-8: text-mode gzip otherwise uses the locale encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        data = json.load(f)
    # JSON object keys are always strings; the rest of the app looks logos up by int.
    return {int(key): value for key, value in data.items()}
def load_logo_data(
    url: str,
    *,
    session: Optional["requests.Session"] = None,
    timeout: float = 60.0,
) -> Dict[int, Dict]:
    """Download a gzipped JSONL logo-annotation file and index it by logo ID.

    Args:
        url: URL of a ``.jsonl.gz`` file with one JSON object per line, each
            carrying an ``"id"`` field.
        session: HTTP session to use; defaults to the module-level
            ``http_session``.
        timeout: request timeout in seconds, forwarded to ``session.get``.

    Returns:
        Mapping from ``int(item["id"])`` to the decoded line object.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (raised by ``raise_for_status``).
    """
    if session is None:
        session = http_session
    r = session.get(url, timeout=timeout)
    # Fail loudly on HTTP errors instead of handing an error page to gzip.
    r.raise_for_status()
    # Decode explicitly as UTF-8: text-mode gzip otherwise uses the locale encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line): json.loads("") raises.
        items = (json.loads(line) for line in map(str.strip, f) if line)
        return {int(item["id"]): item for item in items}
def get_image_from_url(
    image_url: str,
    error_raise: bool = False,
    session: Optional[requests.Session] = None,
    timeout: float = 30.0,
) -> Optional[Image.Image]:
    """Download an image and return it as a decoded PIL image.

    Args:
        image_url: URL of the image to fetch.
        error_raise: if True, an HTTP error status raises (via
            ``raise_for_status``) instead of returning None.
        session: session used for the request; a plain ``requests.get`` is
            used when None.
        timeout: request timeout in seconds, forwarded to requests.

    Returns:
        The decoded image, or None when the response status is not 200.
    """
    # Use the session the caller actually passed (previously the global one
    # was used regardless of which session was given).
    fetch = session.get if session is not None else requests.get
    r = fetch(image_url, timeout=timeout)
    if error_raise:
        r.raise_for_status()
    if r.status_code != 200:
        return None
    # Decode straight from memory. The former temp-file approach re-read the
    # file by name before the write buffer was flushed, and deleted the file
    # while the lazily-loaded image could still need it.
    image = Image.open(io.BytesIO(r.content))
    # Force decoding now so corrupt data fails here rather than later in crop().
    image.load()
    return image
def split_barcode(barcode: str) -> List[str]:
    """Split a barcode into the folder components of its product image path.

    A barcode of at least 9 characters becomes three 3-character chunks plus
    whatever remains (the remainder is omitted when empty). Shorter barcodes
    are returned unsplit as a single component.

    Raises:
        ValueError: if ``barcode`` is not made only of digits (the empty
            string included).
    """
    if not barcode.isdigit():
        raise ValueError("unknown barcode format: {}".format(barcode))
    # Same split as the pattern ^(...)(...)(...)(.*)$, expressed with slicing.
    if len(barcode) < 9:
        return [barcode]
    chunks = [barcode[start:start + 3] for start in (0, 3, 6)]
    remainder = barcode[9:]
    return chunks + [remainder] if remainder else chunks
def get_cropped_image(barcode: str, image_id: str, bounding_box):
    """Fetch a product image and crop it to a logo's bounding box.

    ``bounding_box`` is unpacked as ``(ymin, xmin, ymax, xmax)``. Its values
    are multiplied by the image width/height, i.e. treated as fractions of
    the image size.

    Returns:
        The cropped image, or None when the source image could not be fetched.
    """
    source = get_image_from_url(
        OFF_IMAGE_BASE_URL + generate_image_path(barcode, image_id),
        session=http_session,
    )
    if source is None:
        return None

    y_min, x_min, y_max, x_max = bounding_box
    width, height = source.width, source.height
    # PIL expects (left, upper, right, lower) in pixels.
    pixel_box = (x_min * width, y_min * height, x_max * width, y_max * height)
    return source.crop(pixel_box)
def generate_image_path(barcode: str, image_id: str) -> str:
    """Return the relative image path ``/<barcode folders>/<image_id>.jpg``.

    The folder part is the output of ``split_barcode`` joined with "/".
    """
    folders = "/".join(split_barcode(barcode))
    return f"/{folders}/{image_id}.jpg"
def display_predictions(
    logo_data: Dict,
    nn_data: Dict,
    logo_id: Optional[int] = None,
):
    """Render a query logo crop followed by the crops of its nearest neighbours.

    Args:
        logo_data: mapping logo ID -> record read for its "barcode",
            "image_id" and "bounding_box" keys.
        nn_data: mapping logo ID -> record with "ids", "distances" and
            "annotation" keys.
        logo_id: logo to display; when falsy (None or 0) a random key of
            ``nn_data`` is used.
    """
    if not logo_id:
        if not nn_data:
            st.error("No nearest-neighbour data available.")
            return
        logo_id = random.choice(list(nn_data))

    st.write(f"Logo ID: {logo_id}")
    # A user-typed ID may not exist in either dataset; report it instead of
    # crashing the app with a KeyError.
    if logo_id not in logo_data or logo_id not in nn_data:
        st.error(f"Unknown logo ID: {logo_id}")
        return

    logo = logo_data[logo_id]
    logo_nn_data = nn_data[logo_id]
    nn_ids = logo_nn_data["ids"]
    nn_distances = logo_nn_data["distances"]
    annotation = logo_nn_data["annotation"]

    cropped_image = get_cropped_image(
        logo["barcode"], logo["image_id"], logo["bounding_box"]
    )
    if cropped_image is None:
        st.warning(f"Could not fetch the image of logo {logo_id}.")
        return
    st.image(cropped_image, annotation, width=200)

    cropped_images: List[Image.Image] = []
    captions: List[str] = []
    progress_bar = st.progress(0)
    total = len(nn_ids)
    for step, (closest_id, distance) in enumerate(zip(nn_ids, nn_distances), start=1):
        progress_bar.progress(step / total)
        closest_logo = logo_data.get(closest_id)
        if closest_logo is None:
            # Neighbour without annotation data: skip it rather than crash.
            continue
        cropped_image = get_cropped_image(
            closest_logo["barcode"],
            closest_logo["image_id"],
            closest_logo["bounding_box"],
        )
        if cropped_image is None:
            continue
        if cropped_image.height > cropped_image.width:
            # expand=True enlarges the canvas; without it a non-square crop is
            # clipped to its original width/height after rotation.
            cropped_image = cropped_image.rotate(90, expand=True)
        cropped_images.append(cropped_image)
        captions.append(f"distance: {distance}")

    if cropped_images:
        st.image(cropped_images, captions, width=200)
# --- Streamlit page: sidebar controls, data loading and rendering ---
st.sidebar.title("Logo Nearest Neighbors Demo")
st.sidebar.write(
    "Get the nearest neighbors of an annotated logo (random unless a logo ID "
    "is given): first 100 with the exact method, first 50 with ANN.\n\n"
    "CLIP model is used to generate embeddings, and nearest neighbors "
    "are computed either using a brute-force approach or with ANN."
)
# `or None` maps the falsy default of the widget (presumably 0) to None so
# that display_predictions picks a random logo.
logo_id = st.sidebar.number_input("logo ID", step=1) or None
approximate = st.sidebar.checkbox(
    "ANN (HNSW)",
    value=False,
    help="Display approximate neighbors (instead of real "
    "neighbors computed using brute-force approach)",
)

LOGO_DATA_BASE_URL = "https://static.openfoodfacts.org/data/logos"
nn_file_name = (
    "hnsw_50_closest_neighbours" if approximate else "exact_100_neighbours"
)
nn_data = load_nn_data(f"{LOGO_DATA_BASE_URL}/{nn_file_name}.json.gz")
logo_data = load_logo_data(f"{LOGO_DATA_BASE_URL}/logo_annotations.jsonl.gz")

if approximate:
    st.write("Using approximate nearest neighbors method")
else:
    st.write("Using exact (brute-force) nearest neighbors method")

display_predictions(logo_data=logo_data, nn_data=nn_data, logo_id=logo_id)