import os

# fairseq2 wheels are built against a specific torch/CUDA combination, so the
# matching wheel (torch 2.6.0 / cu124) is installed when the Space starts.
os.system("pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124 -q")
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
import requests
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import torch.nn as nn
import torch.nn.functional as F
from io import BytesIO
from transformers.image_utils import load_image
cos = nn.CosineSimilarity()

# Download the fine-tuned SONAR image-encoder checkpoint from the Hub.
model_path = hf_hub_download(
    repo_id="Sibgat-Ul/SONAR-Image_enc",
    filename="best_sonar.pth",
    repo_type="model"
)
language_mapping = {
    "English": "eng_Latn",
    "Bengali": "ben_Beng",
    "French": "fra_Latn"
}
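
# The values above are FLORES-200 language codes, which is what SONAR's text
# encoder expects as `source_lang`. Further languages could be exposed the same
# way; for example (an assumption, not part of the original Space):
#
#   language_mapping["German"] = "deu_Latn"
#   language_mapping["Spanish"] = "spa_Latn"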
# -------- Load Image Encoder --------
class SonarImageEnc(nn.Module):
    def __init__(self, path="google/siglip2-base-patch16-384", initial_temperature=0.07):
        super().__init__()
        # Frozen SigLIP vision backbone; only the projection head and the
        # logit scale below are trainable.
        self.model = SiglipVisionModel.from_pretrained(path, torch_dtype=torch.float32)
        for param in self.model.parameters():
            param.requires_grad = False

        # Projection head mapping pooled SigLIP features into the
        # 1024-dimensional SONAR sentence-embedding space.
        self.projection = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 2048),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 1024),
            nn.LayerNorm(1024, eps=1e-5),
        )

        # CLIP-style learnable logit scale, initialised to log(1 / temperature).
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1.0) / initial_temperature))
    def forward(self, pixel_values):
        with torch.no_grad():
            vision_outputs = self.model(pixel_values=pixel_values)
            pooled_output = vision_outputs.pooler_output

        embeddings = self.projection(pooled_output)

        # Keep the effective temperature within [0.001, 100]
        # (the original bounds were inverted, which would collapse the scale).
        self.logit_scale.data.clamp_(
            min=torch.log(torch.tensor(1.0) / torch.tensor(100.0)),
            max=torch.log(torch.tensor(1.0) / torch.tensor(0.001))
        )

        return embeddings, torch.exp(self.logit_scale)
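
# Standalone usage sketch (illustrative only, not called by the app): the frozen
# SigLIP backbone pools an image into a single feature vector and the projection
# head maps it into SONAR's 1024-dimensional sentence-embedding space, so a
# single PIL image `img` could be embedded like this, given the `processor`,
# `device` and `img_encoder` defined below:
#
#   px = processor(img, return_tensors="pt").pixel_values.to(device)
#   with torch.no_grad():
#       emb, scale = img_encoder(px)   # emb.shape == (1, 1024)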
# Load processor and models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = SiglipImageProcessor.from_pretrained("google/siglip2-base-patch16-384")

# SONAR text encoder: maps a sentence in a supported language to a 1024-d embedding.
t2t_model_emb = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
    dtype=torch.float16,
)

img_encoder = SonarImageEnc().to(device).eval()
img_encoder.load_state_dict(torch.load(model_path, map_location=device))
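
# Note: load_state_dict uses the default strict=True, so a checkpoint that does
# not match the architecture defined above fails loudly at startup instead of
# producing silently wrong similarity scores.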
# -------- Similarity Scoring --------
def compute_similarity(
    image, image_url,
    option_a, option_b, option_c, option_d,
    lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
):
    # Fall back to the URL if no image was uploaded.
    if image is None:
        try:
            headers = {
                "User-Agent": "Mozilla/5.0"
            }
            response = requests.get(image_url, headers=headers)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return None, {"Error": f"Image could not be loaded: {str(e)}"}

    # Preprocess the image and embed it with the SONAR image encoder.
    inputs = processor(image, return_tensors="pt").to(device)

    with torch.no_grad():
        image_emb, _ = img_encoder(inputs.pixel_values)
        image_emb = image_emb.to(device, torch.float16)
    # Map UI language names to FLORES-200 codes.
    lang_codes = [
        language_mapping[lang_opt_a],
        language_mapping[lang_opt_b],
        language_mapping[lang_opt_c],
        language_mapping[lang_opt_d],
    ]
    texts = [option_a, option_b, option_c, option_d]

    # Embed each option with its corresponding language code.
    text_embeddings = []
    for text, lang in zip(texts, lang_codes):
        emb = t2t_model_emb.predict([text], source_lang=lang)
        text_embeddings.append(emb)

    text_embeddings = torch.cat(text_embeddings, dim=0).to(device)

    # Cosine similarity between the image embedding and each text embedding.
    scores = cos(image_emb, text_embeddings)

    results = {
        f"Option {chr(65 + i)}": round(score.item(), 3)
        for i, score in enumerate(scores)
    }
    # Sort by score (highest first) and format as percentages.
    results = {
        k: f"{round(v * 100, 2)}%"
        for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
    }

    return image, results
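
# Example of calling the scorer outside the UI (illustrative sketch; the URL is
# the COCO sample used as the default below, and the option/language arguments
# mirror the UI defaults):
#
#   _, scores = compute_similarity(
#       None, "http://images.cocodataset.org/val2017/000000039769.jpg",
#       "A cat with two remotes.", "Two cats with two remotes.",
#       "Two remotes.", "Two cats.",
#       "English", "English", "English", "English",
#   )
#   print(scores)   # {"Option A": "...%", ...}, sorted by score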
# -------- Gradio UI --------
with gr.Blocks() as demo:
    gr.Markdown("## SONAR: Image-Text Similarity Scorer")
    gr.Markdown("#### Upload an image or provide a URL.")

    with gr.Row():
        with gr.Column():
            image_url = gr.Textbox(label="Image URL", value="http://images.cocodataset.org/val2017/000000039769.jpg")

            with gr.Row():
                option_a = gr.Textbox(label="Option A", value="A cat with two remotes.")
                lang_opt_a = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_b = gr.Textbox(label="Option B", value="Two cats with two remotes.")
                lang_opt_b = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_c = gr.Textbox(label="Option C", value="Two remotes.")
                lang_opt_c = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")

                option_d = gr.Textbox(label="Option D", value="Two cats.")
                lang_opt_d = gr.Dropdown(choices=list(language_mapping.keys()), value="English", label="Language")
        with gr.Column():
            image_input = gr.Image(label="Upload an image", type="pil")
            btn = gr.Button("Compute Similarity")

    with gr.Row():
        img_output = gr.Image(label="Input Image", type="pil", width=300, height=300)
        result_output = gr.JSON(label="Similarity Scores")
    btn.click(
        fn=compute_similarity,
        inputs=[
            image_input, image_url,
            option_a, option_b, option_c, option_d,
            lang_opt_a, lang_opt_b, lang_opt_c, lang_opt_d
        ],
        outputs=[img_output, result_output]
    )

demo.launch()