Daniel Bustamante Ospina commited on
Commit
dcafc9b
1 Parent(s): f97e9e6

App for dog recognition (pet similarity)

Browse files
Files changed (6) hide show
  1. .idea/.gitignore +8 -0
  2. app.py +86 -0
  3. feat_ext.py +25 -0
  4. model_scripted.pt_enc +0 -0
  5. requirements.txt +3 -0
  6. utils.py +57 -0
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import uuid
4
+ from feat_ext import VitLaionFeatureExtractor
5
+ import shutil
6
+ from queue import Queue, Full
7
+ from utils import HFPetDatasetManager, load_enc_cls_model
8
+ import os
9
+
10
+ model_cls = None
11
+ feat_extractor = None
12
+ processor = None
13
+ ds_manager = None
14
+ HF_API_TOKEN = os.getenv('HF_API_TOKEN')
15
+ ENC_KEY = os.getenv('ENC_KEY')
16
+ dataset_name = os.getenv('DATASET_NAME')
17
+ ds_manager_queue = Queue(maxsize=1)
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+
20
+
21
+ def push_files_async():
22
+ try:
23
+ ds_manager_queue.put_nowait('Ok')
24
+ print('DS upload requested!')
25
+ except Full:
26
+ print('Pull already started!')
27
+
28
+
29
+ def predict_diff(img_a, img_b):
30
+ global model_cls, feat_extractor, processor
31
+ x = processor(img_a).unsqueeze(dim=0).to(device), processor(img_b).unsqueeze(dim=0).to(device)
32
+ a, b = feat_extractor(x)
33
+ proba = torch.sigmoid(model_cls(a, b)).item()
34
+ score_str = "{:.2f}".format(round(proba) * proba + round(1 - proba) * (1 - proba))
35
+ base_name = f"{str(uuid.uuid4()).replace('-', '')}-{score_str}"
36
+ save_image_pairs(img_a, img_b, proba, base_name)
37
+ return {'Same': proba, 'Different': 1 - proba}, base_name
38
+
39
+
40
+ def save_image_pairs(img_a, img_b, proba, base_name):
41
+ sub_dir = 'same' if proba > 0.5 else 'different'
42
+ img_a.save(f'collected/normal/{sub_dir}/{base_name}_a.png')
43
+ img_b.save(f'collected/normal/{sub_dir}/{base_name}_b.png')
44
+ push_files_async()
45
+
46
+
47
+ def move_to_flagged(base_name: str, label: str):
48
+ sub_dir = label.lower()
49
+ destination = f'collected/mistakes/{sub_dir}/'
50
+ shutil.move(f'collected/normal/{sub_dir}/{base_name}_a.png', destination)
51
+ shutil.move(f'collected/normal/{sub_dir}/{base_name}_b.png', destination)
52
+ push_files_async()
53
+
54
+
55
+ class PetFlaggingCallback(gr.FlaggingCallback):
56
+
57
+ def setup(self, components, flagging_dir: str):
58
+ pass
59
+
60
+ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
61
+ _, _, label, base_name = flag_data
62
+ move_to_flagged(base_name, label['label'])
63
+
64
+
65
+ demo = gr.Interface(
66
+ title="Dog Recognition",
67
+ description="Model that compares two images and identify if the belong to the same or different dog.",
68
+ fn=predict_diff,
69
+ inputs=[gr.Image(label="Image A", type="pil"), gr.Image(label="Image B", type="pil")],
70
+ outputs=["label", gr.Text(visible=False)],
71
+ flagging_callback=PetFlaggingCallback()
72
+ )
73
+
74
+ if __name__ == "__main__":
75
+ model_cls = load_enc_cls_model('model_scripted.pt_enc', ENC_KEY)
76
+ feat_extractor = VitLaionFeatureExtractor()
77
+ processor = feat_extractor.transforms
78
+ ds_manager = HFPetDatasetManager(dataset_name, hf_token=HF_API_TOKEN, queue=ds_manager_queue)
79
+ ds_manager.daemon = True
80
+ ds_manager.start()
81
+ model_cls.to(device)
82
+ feat_extractor.to(device)
83
+ model_cls.eval()
84
+ feat_extractor.eval()
85
+ demo.queue()
86
+ demo.launch(share=True)
feat_ext.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel, AutoProcessor
3
+
4
+
5
+ class VitLaionPreProcess(torch.nn.Module):
6
+
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
10
+
11
+ def forward(self, img):
12
+ out = self.processor(images=img, return_tensors="pt")
13
+ return out.data['pixel_values'].squeeze()
14
+
15
+
16
+ class VitLaionFeatureExtractor(torch.nn.Module):
17
+ def __init__(self):
18
+ super().__init__()
19
+ self.vit_model = AutoModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
20
+ self.transforms = VitLaionPreProcess()
21
+
22
+ def forward(self, x):
23
+ img_a, img_b = x
24
+ return self.vit_model.get_image_features(pixel_values=img_a), self.vit_model.get_image_features(
25
+ pixel_values=img_b)
model_scripted.pt_enc ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ cryptography
utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from threading import Thread
3
+ from cryptography.fernet import Fernet
4
+ import torch
5
+ import io
6
+
7
+
8
+ class HFPetDatasetManager(Thread):
9
+ def __init__(self, ds_name, hf_token, queue, local_path='collected'):
10
+ Thread.__init__(self)
11
+ self.queue = queue
12
+ import huggingface_hub
13
+ repo_id = huggingface_hub.get_full_repo_name(
14
+ ds_name, token=hf_token
15
+ )
16
+ self.path_to_dataset_repo = huggingface_hub.create_repo(
17
+ repo_id=repo_id,
18
+ token=hf_token,
19
+ private=True,
20
+ repo_type="dataset",
21
+ exist_ok=True,
22
+ )
23
+ self.repo = huggingface_hub.Repository(
24
+ local_dir=local_path,
25
+ clone_from=self.path_to_dataset_repo,
26
+ use_auth_token=hf_token,
27
+ )
28
+ self.repo.git_pull()
29
+ self.mistakes_dir = Path(local_path) / "mistakes"
30
+ self.normal_dir = Path(local_path) / "normal"
31
+
32
+ self.true_different_dir = self.normal_dir / "different"
33
+ self.true_same_dir = self.normal_dir / "same"
34
+
35
+ self.false_different_dir = self.mistakes_dir / "different"
36
+ self.false_same_dir = self.mistakes_dir / "same"
37
+
38
+ self.true_same_dir.mkdir(parents=True, exist_ok=True)
39
+ self.true_different_dir.mkdir(parents=True, exist_ok=True)
40
+ self.false_same_dir.mkdir(parents=True, exist_ok=True)
41
+ self.false_different_dir.mkdir(parents=True, exist_ok=True)
42
+
43
+ def run(self):
44
+ while True:
45
+ _signal = self.queue.get()
46
+ self.repo.git_pull()
47
+ self.repo.push_to_hub(commit_message=f"Upload data changes...")
48
+ print('Changes pushed to dataset!')
49
+
50
+
51
+ def load_enc_cls_model(file_name, key):
52
+ with open(file_name, "rb") as f:
53
+ data = f.read()
54
+ fernet = Fernet(key)
55
+ decrypted_data = fernet.decrypt(data)
56
+ decrypted_bytes = io.BytesIO(decrypted_data)
57
+ return torch.jit.load(decrypted_bytes)