File size: 3,034 Bytes
dcafc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608e54d
dcafc9b
 
 
 
 
 
 
 
 
 
 
 
 
cae220d
dcafc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d008b16
 
dcafc9b
d008b16
dcafc9b
 
 
 
 
 
 
 
 
608e54d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import torch
import uuid
from feat_ext import VitLaionFeatureExtractor
import shutil
from queue import Queue, Full
from utils import HFPetDatasetManager, load_enc_cls_model
import os

# Model components are assigned in the __main__ block below; they remain None
# when this module is imported without launching the app.
model_cls = None        # classifier head, decrypted via load_enc_cls_model
feat_extractor = None   # VitLaionFeatureExtractor instance
processor = None        # preprocessing transform taken from feat_extractor.transforms
ds_manager = None       # HFPetDatasetManager background thread (started in __main__)
# Credentials / configuration come from the environment (all may be None if unset).
HF_API_TOKEN = os.getenv('HF_API_TOKEN')
ENC_KEY = os.getenv('ENC_KEY')
dataset_name = os.getenv('DATASET_NAME')
# Single-slot queue: push_files_async uses put_nowait so at most one upload
# request can be pending for the dataset-manager thread at any time.
ds_manager_queue = Queue(maxsize=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def push_files_async():
    """Request a background upload of the collected images.

    Drops a token onto the single-slot ``ds_manager_queue`` consumed by the
    dataset-manager thread. If a request is already pending, the queue is
    full and this new request is silently coalesced with the pending one.
    """
    try:
        ds_manager_queue.put_nowait('Ok')
    except Full:
        print('Pull already started!')
    else:
        print('DS upload requested!')


def predict_diff(img_a, img_b):
    """Score whether two images show the same dog.

    Runs both images through the feature extractor and classifier head,
    persists the pair on disk via save_image_pairs (side effect), and returns
    a (label-confidence dict, base filename) pair for the Gradio outputs.
    """
    global model_cls, feat_extractor, processor
    # Preprocess each image and add a batch dimension before moving to device.
    batch = tuple(processor(im).unsqueeze(dim=0).to(device) for im in (img_a, img_b))
    emb_a, emb_b = feat_extractor(batch)
    proba = torch.sigmoid(model_cls((emb_a, emb_b))).item()
    # Confidence of the predicted class: proba when rounding up, 1-proba otherwise.
    winner_conf = round(proba) * proba + round(1 - proba) * (1 - proba)
    score_str = "{:.2f}".format(winner_conf)
    # uuid4().hex is the dashless hex form of the UUID.
    base_name = f"{uuid.uuid4().hex}-{score_str}"
    save_image_pairs(img_a, img_b, proba, base_name)
    return {'Same': proba, 'Different': 1 - proba}, base_name


def save_image_pairs(img_a, img_b, proba, base_name):
    """Persist a predicted image pair under collected/normal/{same,different}.

    Args:
        img_a, img_b: the images as received by predict_diff (objects with a
            PIL-style ``save(path)`` method).
        proba: model probability that both images show the same dog; decides
            the target sub-directory.
        base_name: unique filename stem shared by both saved files.
    """
    sub_dir = 'same' if proba > 0.5 else 'different'
    target_dir = f'collected/normal/{sub_dir}'
    # Create the directory on first use instead of failing with
    # FileNotFoundError when the collected/ tree does not exist yet.
    os.makedirs(target_dir, exist_ok=True)
    img_a.save(f'{target_dir}/{base_name}_a.png')
    img_b.save(f'{target_dir}/{base_name}_b.png')
    push_files_async()


def move_to_flagged(base_name: str, label: str):
    """Relocate a previously saved pair into the mistakes tree after flagging.

    Args:
        base_name: filename stem produced by predict_diff.
        label: predicted label ('Same' or 'Different'); lowercased it names
            the sub-directory the pair was originally saved under.
    """
    sub_dir = label.lower()
    destination = f'collected/mistakes/{sub_dir}/'
    # shutil.move fails when the destination directory does not exist;
    # create it on first use.
    os.makedirs(destination, exist_ok=True)
    shutil.move(f'collected/normal/{sub_dir}/{base_name}_a.png', destination)
    shutil.move(f'collected/normal/{sub_dir}/{base_name}_b.png', destination)
    push_files_async()


class PetFlaggingCallback(gr.FlaggingCallback):
    """Flagging hook that relocates mispredicted image pairs on disk."""

    def setup(self, components, flagging_dir: str):
        # No per-session state or flagging directory is needed.
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        # flag_data layout: [image_a, image_b, label_output, base_name].
        label_output = flag_data[2]
        base_name = flag_data[3]
        move_to_flagged(base_name, label_output['label'])


# Gradio UI: two image inputs, a label output, and a hidden text field that
# carries the saved pair's base filename through to the flagging callback.
demo = gr.Interface(
    title="Dog Recognition",
    # Fixed grammar in the user-facing description ("identify if the belong").
    description="Model that compares two images and identifies whether they belong to the same or a different dog.",
    fn=predict_diff,
    inputs=[gr.Image(label="Image A", type="pil"), gr.Image(label="Image B", type="pil")],
    outputs=["label", gr.Text(visible=False)],
    flagging_callback=PetFlaggingCallback()
)

if __name__ == "__main__":
    # map_location lets checkpoints serialized on a GPU machine load on
    # CPU-only hosts (and vice versa) instead of raising a deserialization error.
    # NOTE(review): torch.load unpickles arbitrary objects — these checkpoint
    # files must come from a trusted source.
    vit_model = torch.load('vit_model_complete.pt', map_location=device)
    vit_processor = torch.load('vit_processor_complete.pt', map_location=device)
    model_cls = load_enc_cls_model('model_scripted.pt_enc', ENC_KEY)
    feat_extractor = VitLaionFeatureExtractor(vit_model, vit_processor)
    processor = feat_extractor.transforms
    # Daemon thread: a pending dataset upload never blocks interpreter shutdown.
    ds_manager = HFPetDatasetManager(dataset_name, hf_token=HF_API_TOKEN, queue=ds_manager_queue)
    ds_manager.daemon = True
    ds_manager.start()
    model_cls.to(device)
    feat_extractor.to(device)
    # Inference mode: disable dropout/batch-norm training behavior.
    model_cls.eval()
    feat_extractor.eval()
    demo.queue()
    demo.launch()