theArijitDas commited on
Commit
faf09ce
1 Parent(s): f726795

Upload 3 files

Browse files
description_validator.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from transformers import AutoTokenizer
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ import numpy as np
5
+
6
+ from warnings import filterwarnings
7
+ filterwarnings("ignore")
8
+
9
+ models = ["MPNet-base-v2", "DistilRoBERTa-v1", "MiniLM-L12-v2", "MiniLM-L6-v2"]
10
+ models_info = {
11
+ "MPNet-base-v2": {
12
+ "model_size": "420MB",
13
+ "model_url": "sentence-transformers/all-mpnet-base-v2",
14
+ "efficiency": "Moderate",
15
+ "chunk_size": 512
16
+ },
17
+ "DistilRoBERTa-v1": {
18
+ "model_size": "263MB",
19
+ "model_url": "sentence-transformers/all-distilroberta-v1",
20
+ "efficiency": "High",
21
+ "chunk_size": 512
22
+ },
23
+ "MiniLM-L12-v2": {
24
+ "model_size": "118MB",
25
+ "model_url": "sentence-transformers/all-MiniLM-L12-v2",
26
+ "efficiency": "High",
27
+ "chunk_size": 512
28
+ },
29
+ "MiniLM-L6-v2": {
30
+ "model_size": "82MB",
31
+ "model_url": "sentence-transformers/all-MiniLM-L6-v2",
32
+ "efficiency": "Very High",
33
+ "chunk_size": 512
34
+ }
35
+ }
36
+
37
+ class Description_Validator:
38
+ def __init__(self, model_name=None):
39
+ if model_name is None: model_name="DistilRoBERTa-v1"
40
+
41
+ self.model_info = models_info[model_name]
42
+ model_url = self.model_info["model_url"]
43
+
44
+ self.model = SentenceTransformer(model_url)
45
+ self.tokenizer = AutoTokenizer.from_pretrained(model_url)
46
+ self.chunk_size = self.model_info["chunk_size"]
47
+
48
+ def tokenize_and_chunk(self, text):
49
+ tokens = self.tokenizer(text, truncation=False, padding=True, add_special_tokens=False)['input_ids']
50
+ token_chunks = [tokens[i:i+self.chunk_size] for i in range(0, len(tokens), self.chunk_size)]
51
+ return token_chunks
52
+
53
+ def get_average_embedding(self, text):
54
+ token_chunks = self.tokenize_and_chunk(text)
55
+ chunk_embeddings = []
56
+ for chunk in token_chunks:
57
+ chunk_embedding = self.model.encode(self.tokenizer.decode(chunk), show_progress_bar=False)
58
+ chunk_embeddings.append(chunk_embedding)
59
+ return np.mean(chunk_embeddings, axis=0)
60
+
61
+ def similarity_score(self, desc1, desc2):
62
+ embedding1 = self.get_average_embedding(desc1).reshape(1, -1)
63
+ embedding2 = self.get_average_embedding(desc2).reshape(1, -1)
64
+ similarity = cosine_similarity(embedding1, embedding2)
65
+ return similarity[0][0]
image_validator.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPProcessor, CLIPModel, ViTImageProcessor, ViTModel
2
+ from PIL import Image
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+
5
+ from warnings import filterwarnings
6
+ filterwarnings("ignore")
7
+
8
+ models = ["CLIP-ViT Base", "ViT Base", "DINO ViT-S16"]
9
+ models_info = {
10
+ "CLIP-ViT Base": {
11
+ "model_size": "386MB",
12
+ "model_url": "openai/clip-vit-base-patch32",
13
+ "efficiency": "High",
14
+ },
15
+ "ViT Base": {
16
+ "model_size": "304MB",
17
+ "model_url": "google/vit-base-patch16-224",
18
+ "efficiency": "High",
19
+ },
20
+ "DINO ViT-S16": {
21
+ "model_size": "1.34GB",
22
+ "model_url": "facebook/dino-vits16",
23
+ "efficiency": "Moderate",
24
+ },
25
+ }
26
+
27
+ class Image_Validator:
28
+ def __init__(self, model_name=None):
29
+ if model_name is None: model_name="ViT Base"
30
+
31
+ self.model_info = models_info[model_name]
32
+ model_url = self.model_info["model_url"]
33
+
34
+ if model_name == "CLIP-ViT Base":
35
+ self.model = CLIPModel.from_pretrained(model_url)
36
+ self.processor = CLIPProcessor.from_pretrained(model_url)
37
+
38
+ elif model_name == "ViT Base":
39
+ self.model = ViTModel.from_pretrained(model_url)
40
+ self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
41
+
42
+ elif model_name == "DINO ViT-S16":
43
+ self.model = ViTModel.from_pretrained(model_url)
44
+ self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
45
+
46
+ def get_image_embedding(self, image_path):
47
+ image = Image.open(image_path)
48
+
49
+ # Process image according to the model
50
+ if hasattr(self, 'processor'): # CLIP models
51
+ inputs = self.processor(images=image, return_tensors="pt")
52
+ outputs = self.model.get_image_features(**inputs)
53
+
54
+ elif hasattr(self, 'feature_extractor'): # ViT models
55
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
56
+ outputs = self.model(**inputs).last_hidden_state
57
+
58
+ return outputs
59
+
60
+ def similarity_score(self, image_path_1, image_path_2):
61
+ embedding1 = self.get_image_embedding(image_path_1).reshape(1, -1)
62
+ embedding2 = self.get_image_embedding(image_path_2).reshape(1, -1)
63
+ similarity = cosine_similarity(embedding1.detach().numpy(), embedding2.detach().numpy())
64
+ return similarity[0][0]
product_update_validator.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model_factory.description_validator import Description_Validator
2
+ from model_factory.image_validator import Image_Validator
3
+
4
+ class Update_Validator:
5
+ def __init__(self, text_model=None, image_model=None, threshold=0.7):
6
+ self.description_validator = Description_Validator(model_name=text_model)
7
+ self.image_validator = Image_Validator(model_name=image_model)
8
+ self.threshold = threshold
9
+
10
+ def validate(self, text1, text2, image_path_1, image_path_2, threshold=None, return_score=False):
11
+ description_similarity = self.description_validator.similarity_score(text1, text2)
12
+ image_similarity = self.image_validator.similarity_score(image_path_1, image_path_2)
13
+ similarity_score = 0.75 * description_similarity + 0.25 * image_similarity
14
+
15
+ if threshold is None: threshold=self.threshold
16
+ label = True if similarity_score >= threshold else False
17
+
18
+ if return_score:
19
+ return {'score':similarity_score, 'label':label}
20
+ else:
21
+ return {'label':label}