# SONAR-Image / app.py
import os
os.system("pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124 -q")
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
import requests
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import torch.nn as nn
import torch.nn.functional as F
from io import BytesIO
from transformers.image_utils import load_image
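# Cosine similarity over the embedding dimension (dim=1), used to score image-text pairs.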
cos = nn.CosineSimilarity()
model_path = hf_hub_download(
    repo_id="Sibgat-Ul/SONAR-Image_enc",
    filename="best_sonar.pth",
    repo_type="model",
)
language_mapping = {
    "English": "eng_Latn",
    "Bengali": "ben_Beng",
    "French": "fra_Latn",
}
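# SONAR's text encoder expects FLORES-200-style language codes such as "eng_Latn";
# adding more languages to the UI only requires extending this mapping.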
# -------- Load Image Encoder --------
class SonarImageEnc(nn.Module):
    def __init__(self, path="google/siglip2-base-patch16-384", initial_temperature=0.07):
        super().__init__()
        self.model = SiglipVisionModel.from_pretrained(path, torch_dtype=torch.float32)

        # Freeze the SigLIP backbone; only the projection head is trainable.
        for param in self.model.parameters():
            param.requires_grad = False

        # Project SigLIP's pooled output into SONAR's 1024-dim embedding space.
        self.projection = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 2048),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 1024),
            nn.LayerNorm(1024, eps=1e-5),
        )

        # Learnable temperature for contrastive training, stored in log space.
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1.0) / initial_temperature))

    def forward(self, pixel_values):
        with torch.no_grad():
            vision_outputs = self.model(pixel_values=pixel_values)
            pooled_output = vision_outputs.pooler_output

        embeddings = self.projection(pooled_output)

        # Keep the temperature (1 / exp(logit_scale)) within [0.001, 100].
        self.logit_scale.data.clamp_(
            min=torch.log(torch.tensor(1.0) / torch.tensor(100.0)),
            max=torch.log(torch.tensor(1.0) / torch.tensor(0.001)),
        )

        return embeddings, torch.exp(self.logit_scale)
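
# Illustrative usage (comments only, not executed here): given pixel_values from the
# SigLIP processor defined below, the encoder returns a (batch, 1024) embedding in
# SONAR's text embedding space along with the learned logit scale, e.g.
#   emb, scale = img_encoder(inputs.pixel_values)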
# Load processor and models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = SiglipImageProcessor.from_pretrained("google/siglip2-base-patch16-384")

t2t_model_emb = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
    dtype=torch.float16,
)

img_encoder = SonarImageEnc().to(device).eval()
img_encoder.load_state_dict(torch.load(model_path, map_location=device))
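
# Note: the image encoder runs in float32; its output is cast to float16 inside
# compute_similarity to match the dtype of the SONAR text pipeline above.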
# -------- Similarity Scoring --------
def compute_similarity(
    image, image_url,
    option_a, option_b, option_c, option_d,
    lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d,
):
    # Fall back to the URL if no image was uploaded.
    if image is None:
        try:
            headers = {"User-Agent": "Mozilla/5.0"}
            response = requests.get(image_url, headers=headers, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return None, {"Error": f"Image could not be loaded: {str(e)}"}

    # Preprocess the image and embed it into SONAR space.
    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_emb, _ = img_encoder(inputs.pixel_values)
        image_emb = image_emb.to(device, torch.float16)

    # Map language names to SONAR language codes.
    lang_codes = [
        language_mapping[lang_opt_a],
        language_mapping[lang_opt_b],
        language_mapping[lang_opt_c],
        language_mapping[lang_opt_d],
    ]
    texts = [option_a, option_b, option_c, option_d]

    # Get embeddings per option with its corresponding language.
    text_embeddings = []
    for text, lang in zip(texts, lang_codes):
        emb = t2t_model_emb.predict([text], source_lang=lang)
        text_embeddings.append(emb)
    text_embeddings = torch.cat(text_embeddings, dim=0).to(device)

    # Cosine similarity between the image embedding and each text option.
    scores = cos(image_emb, text_embeddings)

    results = {
        f"Option {chr(65 + i)}": round(score.item(), 3)
        for i, score in enumerate(scores)
    }
    results = {
        k: f"{round(v * 100, 2)}%"
        for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
    }

    return image, results
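
# Example call outside the UI (illustrative values only):
#   img, scores = compute_similarity(
#       None, "http://images.cocodataset.org/val2017/000000039769.jpg",
#       "A cat with two remotes.", "Two cats with two remotes.", "Two remotes.", "Two cats.",
#       "English", "English", "English", "English",
#   )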
# -------- Gradio UI --------
with gr.Blocks() as demo:
    gr.Markdown("## 🔍 SONAR: Image-Text Similarity Scorer")
    gr.Markdown("#### Upload an image or provide a URL.")

    with gr.Row():
        with gr.Column():
            image_url = gr.Textbox(label="Image URL", value="http://images.cocodataset.org/val2017/000000039769.jpg")

            with gr.Row():
                option_a = gr.Textbox(label="Option A", value="A cat with two remotes.")
                lang_opt_a = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_b = gr.Textbox(label="Option B", value="Two cats with two remotes.")
                lang_opt_b = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_c = gr.Textbox(label="Option C", value="Two remotes.")
                lang_opt_c = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_d = gr.Textbox(label="Option D", value="Two cats.")
                lang_opt_d = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

            # Global language selector (currently not wired into compute_similarity).
            language = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Select Language")

        with gr.Column():
            image_input = gr.Image(label="Upload an image", type="pil")
            btn = gr.Button("Done")

    with gr.Row():
        img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
        result_output = gr.JSON(label="Similarity Scores")

    btn.click(
        fn=compute_similarity,
        inputs=[
            image_input, image_url,
            option_a, option_b, option_c, option_d,
            lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d,
        ],
        outputs=[img_output, result_output],
    )

demo.launch()