Annas Dev committed on
Commit
319e2a1
1 Parent(s): 8f93744

finish vit

Browse files
app.py CHANGED
@@ -7,27 +7,9 @@ from src.similarity.similarity import Similarity
7
  similarity = Similarity()
8
  models = similarity.get_models()
9
 
10
- def check(img_main, img_1, img_2, model_idx):
11
- images = [
12
- (random.choice(
13
- [
14
- "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
15
- "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
16
- "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
17
- "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
18
- "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
19
- ]
20
- ), f"label {i}" if i != 0 else "label" * 50)
21
- for i in range(3)
22
- ]
23
- similarity.check_similarity([img_main, img_1, img_2], models[model_idx])
24
- return []
25
-
26
- # def greet(name):
27
- # return "Hello " + name + "!!"
28
-
29
- # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
30
- # iface.launch()
31
 
32
  with gr.Blocks() as demo:
33
  gr.Markdown('Checking Image Similarity')
@@ -41,7 +23,7 @@ with gr.Blocks() as demo:
41
  model = gr.Dropdown([m.name for m in models], label='Model', type='index')
42
 
43
  gallery = gr.Gallery(
44
- label="Generated images", show_label=True, elem_id="gallery"
45
  ).style(grid=[2], height="auto")
46
 
47
  submit_btn = gr.Button('Check Similarity')
 
7
  similarity = Similarity()
8
  models = similarity.get_models()
9
 
10
+ def check(img_main, img_1, img_2, model_idx):
11
+ result = similarity.check_similarity([img_main, img_1, img_2], models[model_idx])
12
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  with gr.Blocks() as demo:
15
  gr.Markdown('Checking Image Similarity')
 
23
  model = gr.Dropdown([m.name for m in models], label='Model', type='index')
24
 
25
  gallery = gr.Gallery(
26
+ label="Generated images", show_label=False, elem_id="gallery"
27
  ).style(grid=[2], height="auto")
28
 
29
  submit_btn = gr.Button('Check Similarity')
src/model/simlarity_model.py CHANGED
@@ -5,4 +5,5 @@ from .similarity_interface import SimilarityInterface
5
  class SimilarityModel:
6
  name: str
7
  image_size: int
8
- model_cls: SimilarityInterface
 
 
5
  class SimilarityModel:
6
  name: str
7
  image_size: int
8
+ model_cls: SimilarityInterface
9
+ image_input_type: str = 'array'
src/similarity/model_implements/vit_base.py CHANGED
@@ -1,21 +1,20 @@
1
- from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from PIL import Image
3
  import numpy as np
 
4
 
5
  class VitBase():
6
 
7
  def __init__(self):
8
- self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
9
- self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
10
 
11
  def extract_feature(self, imgs):
12
  features = []
13
  for img in imgs:
14
- feature = self.feature_extractor(images=img, return_tensors="tf")
15
- print('keys: ', feature.keys())
16
- f = self.model(feature)
17
- print('--> f', type(f))
18
- # print('type::', (feature['pixel_values'].shape))
19
- features.append(np.squeeze(feature['pixel_values']))
20
- print('shape:::',features[0].shape)
21
  return features
 
1
+ from transformers import ViTFeatureExtractor, ViTModel
2
  from PIL import Image
3
  import numpy as np
4
+ import torch
5
 
6
  class VitBase():
7
 
8
  def __init__(self):
9
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
10
+ self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
11
 
12
  def extract_feature(self, imgs):
13
  features = []
14
  for img in imgs:
15
+ inputs = self.feature_extractor(images=img, return_tensors="pt")
16
+ with torch.no_grad():
17
+ outputs = self.model(**inputs)
18
+ last_hidden_states = outputs.last_hidden_state
19
+ features.append(np.squeeze(last_hidden_states.numpy()).flatten())
 
 
20
  return features
src/similarity/similarity.py CHANGED
@@ -9,7 +9,7 @@ class Similarity:
9
  def get_models(self):
10
  return [
11
  model.SimilarityModel(name= 'Mobilenet V3', image_size= 224, model_cls = ModelnetV3()),
12
- model.SimilarityModel(name= 'Vision Transformer', image_size= 224, model_cls = VitBase()),
13
  ]
14
 
15
  def check_similarity(self, img_urls, model):
@@ -17,14 +17,18 @@ class Similarity:
17
  imgs = []
18
  for url in img_urls:
19
  if url == "": continue
20
- imgs.append(image_util.load_image_url(url, required_size=(model.image_size, model.image_size)))
21
 
22
  features = model.model_cls.extract_feature(imgs)
 
23
  for i, v in enumerate(features):
24
  if i == 0: continue
25
  dist = matrix.cosine(features[0], v)
26
- # print(f'distance: {dist}')
 
 
 
27
 
28
- return 'oke'
29
 
30
 
 
9
  def get_models(self):
10
  return [
11
  model.SimilarityModel(name= 'Mobilenet V3', image_size= 224, model_cls = ModelnetV3()),
12
+ model.SimilarityModel(name= 'Vision Transformer', image_size= 224, model_cls = VitBase(), image_input_type='pil'),
13
  ]
14
 
15
  def check_similarity(self, img_urls, model):
 
17
  imgs = []
18
  for url in img_urls:
19
  if url == "": continue
20
+ imgs.append(image_util.load_image_url(url, required_size=(model.image_size, model.image_size), image_type=model.image_input_type))
21
 
22
  features = model.model_cls.extract_feature(imgs)
23
+ results = []
24
  for i, v in enumerate(features):
25
  if i == 0: continue
26
  dist = matrix.cosine(features[0], v)
27
+ print(f'{i} -- distance: {dist}')
28
+ # results.append((imgs[i], f'similarity: {int(dist*100)}%'))
29
+ original_img = image_util.load_image_url(img_urls[i], required_size=None, image_type='pil')
30
+ results.append((original_img, f'similarity: {int(dist*100)}%'))
31
 
32
+ return results
33
 
34
 
src/util/image.py CHANGED
@@ -2,9 +2,12 @@ from PIL import Image
2
  import numpy as np
3
  import requests
4
 
5
- def load_image_url(url, required_size = (224,224)):
 
6
  img = Image.open(requests.get(url, stream=True).raw)
7
  img = Image.fromarray(np.array(img))
8
- img = img.resize(required_size)
9
- #img = (np.expand_dims(np.array(img), 0)/255).astype(np.float32)
 
 
10
  return img
 
2
  import numpy as np
3
  import requests
4
 
5
+ def load_image_url(url, required_size = (224,224), image_type = 'array'):
6
+ print(f'downloading.. {url}, type: {image_type}')
7
  img = Image.open(requests.get(url, stream=True).raw)
8
  img = Image.fromarray(np.array(img))
9
+ if required_size is not None:
10
+ img = img.resize(required_size)
11
+ if image_type == 'array':
12
+ img = (np.expand_dims(np.array(img), 0)/255).astype(np.float32)
13
  return img