# Page-header residue from the web capture (not program text): Spaces: Runtime error
import gradio as gr | |
import torch | |
import uuid | |
from feat_ext import VitLaionFeatureExtractor | |
import shutil | |
from queue import Queue, Full | |
from utils import HFPetDatasetManager, load_enc_cls_model | |
import os | |
# Settings taken from the environment; os.getenv yields None when a
# variable is not set.
HF_API_TOKEN = os.getenv('HF_API_TOKEN')
ENC_KEY = os.getenv('ENC_KEY')
dataset_name = os.getenv('DATASET_NAME')

# Placeholders that the ``__main__`` section of this file fills in before
# the interface is launched.
model_cls = feat_extractor = processor = ds_manager = None

# Capacity of one: at most a single upload request can be pending.
ds_manager_queue = Queue(maxsize=1)

# Use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def push_files_async():
    """Request a dataset upload without blocking the caller.

    A token is placed on ``ds_manager_queue``. When the queue is already
    full the request is dropped and a notice is printed instead.
    """
    try:
        ds_manager_queue.put('Ok', block=False)
    except Full:
        print('Pull already started!')
    else:
        print('DS upload requested!')
def predict_diff(img_a, img_b):
    """Score whether two images show the same dog and archive the pair.

    Args:
        img_a: First image. It is passed to ``processor`` and later saved
            with its ``save`` method (the interface supplies PIL images).
        img_b: Second image, handled the same way.

    Returns:
        A 2-tuple ``(scores, base_name)``: ``scores`` maps ``'Same'`` and
        ``'Different'`` to complementary probabilities, and ``base_name``
        is the file-name stem under which the pair was saved.
    """
    # Pure inference: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        # Each image is preprocessed, given a leading batch axis and moved
        # to the compute device.
        batch = (
            processor(img_a).unsqueeze(dim=0).to(device),
            processor(img_b).unsqueeze(dim=0).to(device),
        )
        a, b = feat_extractor(batch)
        proba = torch.sigmoid(model_cls(a, b)).item()

    # Confidence of the winning class. max() is used instead of the
    # round()-weighted sum because round(0.5) == 0 in Python, which made
    # an exact 0.5 probability come out as a confidence of 0.00.
    confidence = max(proba, 1 - proba)
    score_str = f"{confidence:.2f}"
    base_name = f"{uuid.uuid4().hex}-{score_str}"
    save_image_pairs(img_a, img_b, proba, base_name)
    return {'Same': proba, 'Different': 1 - proba}, base_name
def save_image_pairs(img_a, img_b, proba, base_name):
    """Save both images under the folder matching the prediction.

    Args:
        img_a: First image; must provide ``save(path)``.
        img_b: Second image; must provide ``save(path)``.
        proba: Probability that the images match. Values above 0.5 go to
            the ``same`` folder, everything else to ``different``.
        base_name: File-name stem; ``_a.png`` / ``_b.png`` are appended.
    """
    sub_dir = 'same' if proba > 0.5 else 'different'
    target_dir = f'collected/normal/{sub_dir}'
    # Make sure the folder exists so saving does not fail on a fresh
    # checkout or an empty collection.
    os.makedirs(target_dir, exist_ok=True)
    img_a.save(f'{target_dir}/{base_name}_a.png')
    img_b.save(f'{target_dir}/{base_name}_b.png')
    push_files_async()
def move_to_flagged(base_name: str, label: str):
    """Move a saved image pair from the normal folder to the mistakes folder.

    Args:
        base_name: File-name stem of the pair (without ``_a.png`` /
            ``_b.png``).
        label: Label of the pair; its lower-cased form selects the
            sub-folder on both the source and destination side.
    """
    sub_dir = label.lower()
    source_dir = f'collected/normal/{sub_dir}'
    destination = f'collected/mistakes/{sub_dir}/'
    # shutil.move only moves *into* the destination when it is an existing
    # directory, so create it up front.
    os.makedirs(destination, exist_ok=True)
    for suffix in ('a', 'b'):
        shutil.move(f'{source_dir}/{base_name}_{suffix}.png', destination)
    push_files_async()
class PetFlaggingCallback(gr.FlaggingCallback):
    """Flagging callback that relocates a flagged image pair on disk."""

    def setup(self, components, flagging_dir: str):
        """Nothing to prepare; both arguments are ignored."""

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        """Move the flagged pair to the mistakes folder.

        ``flag_data`` must hold exactly four entries; only the last two
        are used. Presumably the first two are the input images of the
        interface -- confirm against the Gradio version in use.
        """
        _first, _second, prediction, base_name = flag_data
        predicted_label = prediction['label']
        move_to_flagged(base_name, predicted_label)
# Two-image comparison UI. The second output is a hidden text box that
# carries the saved pair's base name, which the flagging callback reads.
demo = gr.Interface(
    fn=predict_diff,
    inputs=[
        gr.Image(label="Image A", type="pil"),
        gr.Image(label="Image B", type="pil"),
    ],
    outputs=["label", gr.Text(visible=False)],
    flagging_callback=PetFlaggingCallback(),
    title="Dog Recognition",
    description="Model that compares two images and identify if the belong to the same or different dog.",
)
if __name__ == "__main__":
    # On a host without CUDA, remap every stored tensor to the CPU so a
    # checkpoint saved on a GPU can still be loaded. On a CUDA host keep
    # torch.load's default placement.
    _map_location = None if device.type == 'cuda' else device

    # NOTE(security): torch.load unpickles arbitrary objects; these files
    # must come from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects fully pickled modules -- confirm the pinned version.
    vit_model = torch.load('vit_model_complete.pt', map_location=_map_location)
    vit_processor = torch.load('vit_processor_complete.pt', map_location=_map_location)

    model_cls = load_enc_cls_model('model_scripted.pt_enc', ENC_KEY)
    feat_extractor = VitLaionFeatureExtractor(vit_model, vit_processor)
    processor = feat_extractor.transforms

    # Background worker that receives upload requests through the shared
    # queue; marked as a daemon so it does not keep the process alive.
    ds_manager = HFPetDatasetManager(dataset_name, hf_token=HF_API_TOKEN, queue=ds_manager_queue)
    ds_manager.daemon = True
    ds_manager.start()

    # Move both models to the compute device and switch to eval mode.
    model_cls.to(device)
    feat_extractor.to(device)
    model_cls.eval()
    feat_extractor.eval()

    demo.queue()
    demo.launch()