import collections
import heapq
import io
import json
import logging
import os
from pathlib import Path

import faiss
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from open_clip import create_model, get_tokenizer
from PIL import Image
from torchvision import transforms

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger(__name__)

# Optional Hugging Face token; only needed if the asset repos are private.
hf_token = os.getenv("HF_TOKEN")

model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Precomputed BioCLIP text embeddings, one per species label.
txt_emb_npy = hf_hub_download(
    repo_id="pyesonekyaw/biome_lfs",
    filename="txt_emb_species.npy",
    repo_type="dataset",
    token=hf_token,
)
txt_names_json = "txt_emb_species.json"

k = 5  # number of top predictions shown in the UI

ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Standard CLIP preprocessing; the mean/std are the normalization constants BioCLIP inherits from CLIP.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

MIN_PROB = 1e-9  # ignore labels whose softmax probability is effectively zero
TOP_K_CANDIDATES = 250  # zero-shot candidates kept for retrieval re-ranking
TOP_N_SIMILAR = 22  # nearest neighbours requested from the FAISS index
SIMILARITY_BOOST = 0.2  # score added per retrieval vote
VOTE_THRESHOLD = 3  # votes needed before a retrieved species is added outright
SIMILARITY_THRESHOLD = 0.99  # neighbours above this similarity are treated as near-duplicates
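
# How these interact (see open_domain_classification): each zero-shot candidate's
# probability is bumped by SIMILARITY_BOOST per vote from matching retrieved images,
# and a species retrieved at least VOTE_THRESHOLD times enters the candidate set
# outright with a score of SIMILARITY_BOOST * votes.
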
PHOTO_LOOKUP_PATH = "./photo_lookup.json"
SPECIES_LOOKUP_PATH = "./species_lookup.json"

theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
).set(
    button_primary_background_fill="#114A56",
    button_primary_background_fill_hover="#114A56",
    block_title_text_weight="600",
    block_label_text_weight="600",
    block_label_text_size="*text_md",
)

EXAMPLES_DIR = Path("examples")
example_images = sorted(str(p) for p in EXAMPLES_DIR.glob("*.jpg"))


def indexed(lst, indices):
    return [lst[i] for i in indices]


def format_name(taxon, common):
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"


def combine_duplicate_predictions(predictions):
    """Combine predictions where one name is contained within another."""
    combined = {}
    used = set()

    # Longest names first, so the most descriptive label absorbs its substrings.
    items = sorted(predictions.items(), key=lambda x: (-len(x[0]), -x[1]))

    for name1, prob1 in items:
        if name1 in used:
            continue

        total_prob = prob1
        used.add(name1)

        for name2, prob2 in predictions.items():
            if name2 in used:
                continue

            name1_lower = name1.lower()
            name2_lower = name2.lower()

            if name1_lower in name2_lower or name2_lower in name1_lower:
                total_prob += prob2
                used.add(name2)

        combined[name1] = total_prob

    # Renormalize so the merged probabilities still sum to 1.
    total = sum(combined.values())
    return {name: prob / total for name, prob in combined.items()}
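

# Illustrative example: {"Macaca fascicularis (crab-eating macaque)": 0.6, "macaca fascicularis": 0.3}
# merges into {"Macaca fascicularis (crab-eating macaque)": 1.0} after renormalization.

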
@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """
    Predicts from the entire tree of life, augmenting BioCLIP's zero-shot
    scores with votes from visually similar images (retrieval augmentation).
    """
    logger.info(f"Starting open domain classification for rank: {rank}")
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Zero-shot: one logit per species text embedding.
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    # Retrieval: nearest neighbours in the FAISS index vote for their species.
    species_votes, similar_images = get_similar_images_metadata(
        img_features, faiss_index, id_mapping, name_mapping
    )

    if rank + 1 == len(ranks):
        # Species level: re-rank the top zero-shot candidates with retrieval votes.
        topk = probs.topk(TOP_K_CANDIDATES)
        predictions = {
            format_name(*txt_names[i]): prob.item()
            for i, prob in zip(topk.indices, topk.values)
        }

        augmented_predictions = predictions.copy()
        for pred_name in predictions:
            pred_name_lower = pred_name.lower()
            for voted_species, vote_count in species_votes.items():
                if voted_species in pred_name_lower or pred_name_lower in voted_species:
                    augmented_predictions[pred_name] += SIMILARITY_BOOST * vote_count
                elif vote_count >= VOTE_THRESHOLD:
                    # A frequently retrieved species enters the candidate set even
                    # if zero-shot classification missed it entirely.
                    augmented_predictions[voted_species] = vote_count * SIMILARITY_BOOST

        sorted_predictions = dict(sorted(
            augmented_predictions.items(),
            key=lambda x: x[1],
            reverse=True,
        )[:k])

        total = sum(sorted_predictions.values())
        sorted_predictions = {name: prob / total for name, prob in sorted_predictions.items()}
        sorted_predictions = combine_duplicate_predictions(sorted_predictions)

        logger.info(f"Top K predictions after combining duplicates: {sorted_predictions}")
        return sorted_predictions, similar_images

    # Higher ranks: marginalize species probabilities up the taxonomy by summing
    # everything that shares the same prefix (e.g. the same Family).
    output = collections.defaultdict(float)
    for i in torch.nonzero(probs > MIN_PROB).view(-1):
        output[" ".join(txt_names[i][0][: rank + 1])] += probs[i].item()

    for species, vote_count in species_votes.items():
        try:
            # Credit each retrieval vote to the ancestor taxon of the voted species.
            for taxonomy, _ in txt_names:
                if species in " ".join(taxonomy).lower():
                    higher_rank = " ".join(taxonomy[: rank + 1])
                    output[higher_rank] += SIMILARITY_BOOST * vote_count
                    break
        except Exception as e:
            logger.error(f"Error processing vote for species {species}: {e}")

    topk_names = heapq.nlargest(k, output, key=output.get)
    prediction_dict = {name: output[name] for name in topk_names}

    total = sum(prediction_dict.values())
    prediction_dict = {name: prob / total for name, prob in prediction_dict.items()}
    prediction_dict = combine_duplicate_predictions(prediction_dict)

    logger.info(f"Prediction dictionary after combining duplicates: {prediction_dict}")

    return prediction_dict, similar_images


def change_output(choice):
    return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)


def get_cache_paths(name="demo"):
    """Get paths for the cached FAISS index and ID mapping."""
    return {
        'index': hf_hub_download(repo_id="pyesonekyaw/biome_lfs", filename=f'cache/faiss_cache_{name}.index', repo_type="dataset"),
        'mapping': hf_hub_download(repo_id="pyesonekyaw/biome_lfs", filename=f'cache/faiss_cache_{name}_mapping.json', repo_type="dataset"),
    }


def build_name_mapping(txt_names):
    """Build a mapping from scientific and common names to a canonical (scientific, common) pair."""
    name_mapping = {}
    for taxonomy, common_name in txt_names:
        if not common_name:
            continue
        if len(taxonomy) >= 2:
            # Genus and species are the last two entries of the taxonomy tuple.
            scientific_name = f"{taxonomy[-2]} {taxonomy[-1]}".lower()
            common_name = common_name.lower()
            name_mapping[scientific_name] = (scientific_name, common_name)
            name_mapping[common_name] = (scientific_name, common_name)
    return name_mapping
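

# Illustrative entry: a txt_names record ((..., "Pycnonotus", "goiavier"), "Yellow-vented Bulbul")
# yields both "pycnonotus goiavier" and "yellow-vented bulbul" keys pointing at
# ("pycnonotus goiavier", "yellow-vented bulbul").

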
def load_faiss_index():
    """Load the FAISS index and its row-to-image-ID mapping from cache."""
    cache_paths = get_cache_paths()
    logger.info("Loading FAISS index from cache...")
    index = faiss.read_index(cache_paths['index'])
    with open(cache_paths['mapping'], 'r') as f:
        id_mapping = json.load(f)
    return index, id_mapping
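

# For reference, a minimal sketch of how a compatible index could be built. This is
# an assumption based on how search results are used below (raw scores compared
# against SIMILARITY_THRESHOLD as similarities), not the actual cache-building code:
#
#   index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product = cosine on unit vectors
#   index.add(embeddings.astype(np.float32))        # embeddings: (n_images, embed_dim), L2-normalized
#   faiss.write_index(index, "cache/faiss_cache_demo.index")

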
def get_similar_images_metadata(img_embedding, faiss_index, id_mapping, name_mapping):
    """Get metadata for similar images using FAISS search."""
    img_embedding_np = img_embedding.cpu().numpy()
    if img_embedding_np.ndim == 1:
        img_embedding_np = img_embedding_np.reshape(1, -1)

    # Over-fetch so that near-duplicates can be filtered out below.
    distances, indices = faiss_index.search(img_embedding_np, TOP_N_SIMILAR * 2)

    valid_indices = []
    valid_distances = []

    for dist, idx in zip(distances[0], indices[0]):
        # The raw score is treated as a similarity; anything above the threshold
        # is assumed to be a near-duplicate of the query and skipped.
        similarity = dist
        if similarity > SIMILARITY_THRESHOLD:
            continue

        valid_indices.append(idx)
        valid_distances.append(similarity)

        if len(valid_indices) >= TOP_N_SIMILAR:
            break

    species_votes = {}
    similar_images = []

    # Only the five best remaining neighbours are shown and allowed to vote.
    for idx, similarity in zip(valid_indices[:5], valid_distances[:5]):
        similar_img_id = id_mapping[idx]

        try:
            species_names = id_to_species_info.get(similar_img_id, [])
            species_names = [name for name in species_names if name]

            # Collapse aliases (scientific or common) to one canonical scientific name.
            processed_names = set()
            for species in species_names:
                name_tuple = name_mapping.get(species)
                if name_tuple:
                    processed_names.add(name_tuple[0])
                else:
                    processed_names.add(species)

            for species in processed_names:
                species_votes[species] = species_votes.get(species, 0) + 1

            similar_images.append({
                'id': similar_img_id,
                'species': next(iter(processed_names)) if processed_names else 'Unknown',
                'common_name': species_names[-1] if species_names else '',
                'similarity': similarity,
            })

        except Exception as e:
            logger.error(f"Error processing metadata for image {similar_img_id}: {e}")
            continue

    return species_votes, similar_images
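

# Shapes of the return values (illustrative values):
#   species_votes:  {"pycnonotus goiavier": 3, "copsychus saularis": 1}
#   similar_images: [{'id': ..., 'species': ..., 'common_name': ..., 'similarity': 0.87}, ...]

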
if __name__ == "__main__":
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")

    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    with open(PHOTO_LOOKUP_PATH) as f:
        id_to_photo_url = json.load(f)
    with open(SPECIES_LOOKUP_PATH) as f:
        id_to_species_info = json.load(f)
    logger.info(f"Loaded {len(id_to_photo_url)} photo mappings")
    logger.info(f"Loaded {len(id_to_species_info)} species mappings")

    # txt_emb is stored as (embed_dim, num_species), so img_features @ txt_emb
    # yields one logit per species label.
    txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
    with open(txt_names_json) as fd:
        txt_names = json.load(fd)

    name_mapping = build_name_mapping(txt_names)

    faiss_index, id_mapping = load_faiss_index()

    def process_output(img, rank):
        predictions, similar_imgs = open_domain_classification(img, rank)

        logger.info(f"Number of similar images found: {len(similar_imgs)}")

        images = []
        labels = []

        for img_info in similar_imgs:
            img_id = img_info['id']
            img_url = id_to_photo_url.get(img_id)
            if img_url is None:
                # No photo on record for this observation; keep the slots aligned.
                images.append(None)
                labels.append("")
                continue
            # Swap the square thumbnail for the small rendition.
            img_url = img_url.replace("square", "small")
            logger.info(f"Processing image URL: {img_url}")

            try:
                response = requests.get(img_url, timeout=10)
                if response.status_code == 200:
                    try:
                        img = Image.open(io.BytesIO(response.content))
                        images.append(img)
                    except Exception as e:
                        logger.info(f"Failed to load image from URL: {e}")
                        images.append(None)
                else:
                    logger.info(f"Failed to fetch image from URL: {response}")
                    images.append(None)

                label = f"**{img_info['species']}**"
                if img_info['common_name']:
                    label += f" ({img_info['common_name']})"
                label += f"\nSimilarity: {img_info['similarity']:.3f}"
                label += f"\n[View on iNaturalist](https://www.inaturalist.org/observations/{img_id})"
                labels.append(label)

            except Exception as e:
                logger.error(f"Error processing image {img_id}: {e}")
                images.append(None)
                labels.append("")

        # Pad to exactly five slots so every fixed gallery component receives a value.
        images += [None] * (5 - len(images))
        labels += [""] * (5 - len(labels))

        logger.info(f"Final number of images: {len(images)}")
        logger.info(f"Final number of labels: {len(labels)}")

        return [predictions] + images + labels

    with gr.Blocks(theme=theme) as app:
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                gr.Image("image.jpg", elem_id="logo-img", show_label=False)
            with gr.Column(scale=30):
                gr.Markdown("""Biome is a vision foundation model-powered tool customized to identify Singapore's local biodiversity.
                <br/> <br/>
                **Developed by**: Pye Sone Kyaw - AI Engineer @ Multimodal AI Team - AI Practice - GovTech SG
                <br/> <br/>
                Under the hood, Biome uses [BioCLIP](https://github.com/Imageomics/BioCLIP) augmented with multimodal search and retrieval to enhance its Singapore-specific biodiversity classification capabilities.
                <br/> <br/>
                Biome works best when the organism is clearly visible and takes up a substantial part of the image.
                """)

        with gr.Row(variant="panel", elem_id="images_panel"):
            img_input = gr.Image(height=400, sources=["upload"], type="pil")

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr.Examples(
                        examples=example_images,
                        inputs=img_input,
                        label="Example Images",
                    )
                rank_dropdown = gr.Dropdown(
                    label="Taxonomic Rank",
                    info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                    choices=ranks,
                    value="Species",
                    type="index",
                )
                open_domain_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                open_domain_output = gr.Label(
                    num_top_classes=k,
                    label="Prediction",
                    show_label=True,
                    value=None,
                )

        with gr.Row(variant="panel"):
            with gr.Column():
                gr.Markdown("### Most Similar Images from Database")

                with gr.Row():
                    similar_images = [
                        gr.Image(label=f"Similar Image {i}", height=200, show_label=True)
                        for i in range(1, 6)
                    ]

                with gr.Row():
                    similar_labels = [gr.Markdown(f"Species {i}") for i in range(1, 6)]

        rank_dropdown.change(
            fn=change_output,
            inputs=rank_dropdown,
            outputs=[open_domain_output],
        )

        open_domain_btn.click(
            fn=process_output,
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output] + similar_images + similar_labels,
        )

        with gr.Row(variant="panel"):
            gr.Markdown("""
            **Disclaimer**: This is a proof-of-concept demo for non-commercial purposes. No data is stored or used for any form of training, and all data used for retrieval comes from [iNaturalist](https://inaturalist.org/).
            The adage of garbage in, garbage out applies here: uploading images that are not biodiversity-related will yield unpredictable results.
            """)

    app.queue(max_size=20)
    app.launch(share=False, enable_monitoring=False, allowed_paths=["/app/"])