import os
import shutil
import uuid
from queue import Queue, Full

import gradio as gr
import torch

from feat_ext import VitLaionFeatureExtractor
from utils import HFPetDatasetManager, load_enc_cls_model

# Model components are populated in the __main__ block below.
model_cls = None
feat_extractor = None
processor = None
ds_manager = None

HF_API_TOKEN = os.getenv('HF_API_TOKEN')
ENC_KEY = os.getenv('ENC_KEY')
dataset_name = os.getenv('DATASET_NAME')

# Single-slot queue used to signal the dataset manager thread that newly
# collected files are ready to be pushed to the Hugging Face dataset.
ds_manager_queue = Queue(maxsize=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def push_files_async():
    """Request an asynchronous dataset upload; no-op if a request is already pending."""
    try:
        ds_manager_queue.put_nowait('Ok')
        print('DS upload requested!')
    except Full:
        print('Upload already requested!')


def predict_diff(img_a, img_b):
    """Compare two images and return the probability that they show the same dog."""
    global model_cls, feat_extractor, processor
    with torch.no_grad():
        x = (processor(img_a).unsqueeze(dim=0).to(device),
             processor(img_b).unsqueeze(dim=0).to(device))
        a, b = feat_extractor(x)
        proba = torch.sigmoid(model_cls((a, b))).item()
    # Confidence of the predicted class: proba if 'Same' wins, 1 - proba otherwise.
    score_str = "{:.2f}".format(round(proba) * proba + round(1 - proba) * (1 - proba))
    base_name = f"{str(uuid.uuid4()).replace('-', '')}-{score_str}"
    save_image_pairs(img_a, img_b, proba, base_name)
    return {'Same': proba, 'Different': 1 - proba}, base_name


def save_image_pairs(img_a, img_b, proba, base_name):
    """Save the image pair under the predicted class directory and trigger an upload.

    The 'collected/normal/{same,different}' directories are expected to exist.
    """
    sub_dir = 'same' if proba > 0.5 else 'different'
    img_a.save(f'collected/normal/{sub_dir}/{base_name}_a.png')
    img_b.save(f'collected/normal/{sub_dir}/{base_name}_b.png')
    push_files_async()


def move_to_flagged(base_name: str, label: str):
    """Move a flagged (misclassified) pair from 'collected/normal' to 'collected/mistakes'."""
    sub_dir = label.lower()
    destination = f'collected/mistakes/{sub_dir}/'
    shutil.move(f'collected/normal/{sub_dir}/{base_name}_a.png', destination)
    shutil.move(f'collected/normal/{sub_dir}/{base_name}_b.png', destination)
    push_files_async()


class PetFlaggingCallback(gr.FlaggingCallback):
    """Custom flagging callback: relocate flagged pairs instead of writing a CSV log."""

    def setup(self, components, flagging_dir: str):
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        # flag_data holds [img_a, img_b, label_output, hidden_base_name].
        _, _, label, base_name = flag_data
        move_to_flagged(base_name, label['label'])


demo = gr.Interface(
    title="Dog Recognition",
    description="A model that compares two images and predicts whether they show the same dog or different dogs.",
    fn=predict_diff,
    inputs=[gr.Image(label="Image A", type="pil"), gr.Image(label="Image B", type="pil")],
    # The hidden text output carries the base file name so the flagging callback
    # can locate the saved image pair on disk.
    outputs=["label", gr.Text(visible=False)],
    flagging_callback=PetFlaggingCallback()
)

if __name__ == "__main__":
    vit_model = torch.load('vit_model_complete.pt')
    vit_processor = torch.load('vit_processor_complete.pt')
    # Load the encrypted classifier head, decrypted with ENC_KEY.
    model_cls = load_enc_cls_model('model_scripted.pt_enc', ENC_KEY)
    feat_extractor = VitLaionFeatureExtractor(vit_model, vit_processor)
    processor = feat_extractor.transforms

    # Background thread that pushes collected images to the Hugging Face dataset
    # whenever a request is placed on ds_manager_queue.
    ds_manager = HFPetDatasetManager(dataset_name, hf_token=HF_API_TOKEN, queue=ds_manager_queue)
    ds_manager.daemon = True
    ds_manager.start()

    model_cls.to(device)
    feat_extractor.to(device)
    model_cls.eval()
    feat_extractor.eval()

    demo.queue()
    demo.launch()