Zero-Shot Image Classification
Transformers
Safetensors
clip
Inference Endpoints
Jsonwu commited on
Commit
e30bb8f
·
verified ·
1 Parent(s): a0bbb2d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +114 -1
README.md CHANGED
@@ -23,4 +23,117 @@ User interface (UI) design is a difficult yet important task for ensuring the us
23
  - **Developed by:** BigLab
24
  - **Model type:** CLIP-style Multi-modal Dual-encoder Transformer
25
  - **Language(s) (NLP):** English
26
- - **License:** MIT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  - **Developed by:** BigLab
24
  - **Model type:** CLIP-style Multi-modal Dual-encoder Transformer
25
  - **Language(s) (NLP):** English
26
+ - **License:** MIT
27
+
28
+
29
+ ```python
30
+ import torch
31
+ from transformers import CLIPProcessor, CLIPModel
32
+
33
+ IMG_SIZE = 224
34
+ DEVICE = "cpu" # can also be "cuda" or "mps"
35
+ LOGIT_SCALE = 100 # based on OpenAI's CLIP example code
36
+ NORMALIZE_SCORING = True
37
+
38
+ model_path="uiclip_jitteredwebsites-2-224-paraphrased" # can also be webpairs or human pairs variants
39
+ processor_path="openai/clip-vit-base-patch32"
40
+
41
+ model = CLIPModel.from_pretrained(model_path)
42
+ model = model.eval()
43
+ model = model.to(DEVICE)
44
+
45
+ processor = CLIPProcessor.from_pretrained(processor_path)
46
+
47
+ def compute_quality_scores(input_list):
48
+ # input_list is a list of types where the first element is a description and the second is a PIL image
49
+ description_list = ["ui screenshot. well-designed. " + input_item[0] for input_item in input_list]
50
+ img_list = [input_item[1] for input_item in input_list]
51
+ text_embeddings_tensor = compute_description_embeddings(description_list) # B x H
52
+ img_embeddings_tensor = compute_image_embeddings(img_list) # B x H
53
+
54
+ # normalize tensors
55
+ text_embeddings_tensor /= text_embeddings_tensor.norm(dim=-1, keepdim=True)
56
+ img_embeddings_tensor /= img_embeddings_tensor.norm(dim=-1, keepdim=True)
57
+
58
+ if NORMALIZE_SCORING:
59
+ text_embeddings_tensor_poor = compute_description_embeddings([d.replace("well-designed. ", "poor design. ") for d in description_list]) # B x H
60
+ text_embeddings_tensor_poor /= text_embeddings_tensor_poor.norm(dim=-1, keepdim=True)
61
+ text_embeddings_tensor_all = torch.stack((text_embeddings_tensor, text_embeddings_tensor_poor), dim=1) # B x 2 x H
62
+ else:
63
+ text_embeddings_tensor_all = text_embeddings_tensor.unsqueeze(1)
64
+
65
+ img_embeddings_tensor = img_embeddings_tensor.unsqueeze(1) # B x 1 x H
66
+
67
+ scores = (LOGIT_SCALE * img_embeddings_tensor @ text_embeddings_tensor_all.permute(0, 2, 1)).squeeze(1)
68
+
69
+ if NORMALIZE_SCORING:
70
+ scores = scores.softmax(dim=-1)
71
+
72
+ return scores[:, 0]
73
+
74
+ def compute_description_embeddings(descriptions):
75
+ inputs = processor(text=descriptions, return_tensors="pt", padding=True)
76
+ inputs['input_ids'] = inputs['input_ids'].to(DEVICE)
77
+ inputs['attention_mask'] = inputs['attention_mask'].to(DEVICE)
78
+ text_embedding = model.get_text_features(**inputs)
79
+ return text_embedding
80
+
81
+ def compute_image_embeddings(image_list):
82
+ windowed_batch = [slide_window_over_image(img, IMG_SIZE) for img in image_list]
83
+ inds = []
84
+ for imgi in range(len(windowed_batch)):
85
+ inds.append([imgi for _ in windowed_batch[imgi]])
86
+
87
+ processed_batch = [item for sublist in windowed_batch for item in sublist]
88
+ inputs = processor(images=processed_batch, return_tensors="pt")
89
+ # run all sub windows of all images in batch through the model
90
+ inputs['pixel_values'] = inputs['pixel_values'].to(DEVICE)
91
+ with torch.no_grad():
92
+ image_features = model.get_image_features(**inputs)
93
+
94
+ # output contains all subwindows, need to mask for each image
95
+ processed_batch_inds = torch.tensor([item for sublist in inds for item in sublist]).long().to(image_features.device)
96
+ embed_list = []
97
+ for i in range(len(image_list)):
98
+ mask = processed_batch_inds == i
99
+ embed_list.append(image_features[mask].mean(dim=0))
100
+ image_embedding = torch.stack(embed_list, dim=0)
101
+ return image_embedding
102
+
103
+ def preresize_image(image, image_size):
104
+ aspect_ratio = image.width / image.height
105
+ if aspect_ratio > 1:
106
+ image = image.resize((int(aspect_ratio * image_size), image_size))
107
+ else:
108
+ image = image.resize((image_size, int(image_size / aspect_ratio)))
109
+ return image
110
+
111
+ def slide_window_over_image(input_image, img_size):
112
+ input_image = preresize_image(input_image, img_size)
113
+ width, height = input_image.size
114
+ square_size = min(width, height)
115
+ longer_dimension = max(width, height)
116
+ num_steps = (longer_dimension + square_size - 1) // square_size
117
+
118
+ if num_steps > 1:
119
+ step_size = (longer_dimension - square_size) // (num_steps - 1)
120
+ else:
121
+ step_size = square_size
122
+
123
+ cropped_images = []
124
+
125
+ for y in range(0, height - square_size + 1, step_size if height > width else square_size):
126
+ for x in range(0, width - square_size + 1, step_size if width > height else square_size):
127
+ left = x
128
+ upper = y
129
+ right = x + square_size
130
+ lower = y + square_size
131
+ cropped_image = input_image.crop((left, upper, right, lower))
132
+ cropped_images.append(cropped_image)
133
+
134
+ return cropped_images
135
+
136
+
137
+ # compute the quality scores for a list of descriptions (strings) and images (PIL images)
138
+ prediction_scores = compute_quality_scores(list(zip(test_descriptions, test_images)))
139
+ ```