Dominik Hintersdorf committed
Commit • 3ffe17d
1 Parent(s): ceb330a
added additional models
Browse files
- app.py +46 -12
- calculate_text_embeddings.ipynb +111 -130
- prompt_text_embeddings/{ViT-B-16_prompt_text_embeddings.pt → ViT-B-16_laion400m_prompt_text_embeddings.pt} +2 -2
- prompt_text_embeddings/{ViT-B-32_prompt_text_embeddings.pt → ViT-B-16_openai_prompt_text_embeddings.pt} +2 -2
- prompt_text_embeddings/{ViT-L-14_prompt_text_embeddings.pt → ViT-B-32_laion2b_prompt_text_embeddings.pt} +2 -2
- prompt_text_embeddings/ViT-B-32_laion400m_prompt_text_embeddings.pt +3 -0
- prompt_text_embeddings/ViT-B-32_openai_prompt_text_embeddings.pt +3 -0
- prompt_text_embeddings/ViT-L-14_laion2b_prompt_text_embeddings.pt +3 -0
- prompt_text_embeddings/ViT-L-14_laion400m_prompt_text_embeddings.pt +3 -0
- prompt_text_embeddings/ViT-L-14_openai_prompt_text_embeddings.pt +3 -0
app.py
CHANGED
@@ -39,7 +39,9 @@ PROMPTS = [
     '{0} in a suit',
     '{0} in a dress'
 ]
-
+OPEN_CLIP_LAION400M_MODEL_NAMES = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']
+OPEN_CLIP_LAION2B_MODEL_NAMES = [('ViT-B-32', 'laion2b_s34b_b79k'), ('ViT-L-14', 'laion2b_s32b_b82k')]
+OPEN_AI_MODELS = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']
 NUM_TOTAL_NAMES = 1_000
 SEED = 42
 MIN_NUM_CORRECT_PROMPT_PREDS = 1
@@ -52,7 +54,7 @@ EXAMPLE_IMAGE_URLS = read_actor_files(EDAMPLE_IMAGE_DIR)
 save_images_to_folder(os.path.join(EDAMPLE_IMAGE_DIR, 'images'), EXAMPLE_IMAGE_URLS)
 
 MODELS = {}
-for model_name in OPEN_CLIP_MODEL_NAMES:
+for model_name in OPEN_CLIP_LAION400M_MODEL_NAMES:
     dataset = 'LAION400M'
     model, _, preprocess = open_clip.create_model_and_transforms(
         model_name,
@@ -63,24 +65,55 @@ for model_name in OPEN_CLIP_MODEL_NAMES:
         'model_instance': model,
         'preprocessing': preprocess,
         'model_name': model_name,
-        '
+        'tokenizer': open_clip.get_tokenizer(model_name),
+        'prompt_text_embeddings': torch.load(f'./prompt_text_embeddings/{model_name}_{dataset.lower()}_prompt_text_embeddings.pt')
+    }
+
+for model_name, dataset_name in OPEN_CLIP_LAION2B_MODEL_NAMES:
+    dataset = 'LAION2B'
+    model, _, preprocess = open_clip.create_model_and_transforms(
+        model_name,
+        pretrained=dataset_name
+    )
+    model = model.eval()
+    MODELS[f'OpenClip {model_name} trained on {dataset}'] = {
+        'model_instance': model,
+        'preprocessing': preprocess,
+        'model_name': model_name,
+        'tokenizer': open_clip.get_tokenizer(model_name),
+        'prompt_text_embeddings': torch.load(f'./prompt_text_embeddings/{model_name}_{dataset.lower()}_prompt_text_embeddings.pt')
+    }
+
+for model_name in OPEN_AI_MODELS:
+    dataset = 'OpenAI'
+    model, _, preprocess = open_clip.create_model_and_transforms(
+        model_name,
+        pretrained=dataset.lower()
+    )
+    model = model.eval()
+    MODELS[f'OpenClip {model_name} trained by {dataset}'] = {
+        'model_instance': model,
+        'preprocessing': preprocess,
+        'model_name': model_name,
+        'tokenizer': open_clip.get_tokenizer(model_name),
+        'prompt_text_embeddings': torch.load(f'./prompt_text_embeddings/{model_name}_{dataset.lower()}_prompt_text_embeddings.pt')
     }
 
 FULL_NAMES_DF = pd.read_csv('full_names.csv', index_col=0)
 LAION_MEMBERSHIP_OCCURENCE = pd.read_csv('laion_membership_occurence_count.csv', index_col=0)
 
 EXAMPLE_ACTORS_BY_MODEL = {
-    "ViT-B-32": ["T._J._Thyne"],
-    "ViT-B-16": ["Barbara_Schöneberger", "Carolin_Kebekus"],
-    "ViT-L-14": ["Max_Giermann", "Nicole_De_Boer"]
+    ("ViT-B-32", "laion400m"): ["T._J._Thyne"],
+    ("ViT-B-16", "laion400m"): ["Barbara_Schöneberger", "Carolin_Kebekus"],
+    ("ViT-L-14", "laion400m"): ["Max_Giermann", "Nicole_De_Boer"]
 }
 
 EXAMPLES = []
-for model_name, person_names in EXAMPLE_ACTORS_BY_MODEL.items():
+for (model_name, dataset_name), person_names in EXAMPLE_ACTORS_BY_MODEL.items():
     for name in person_names:
         image_folder = os.path.join("./example_images/images/", name)
         for dd_model_name in MODELS.keys():
-            if model_name
+            if not (model_name.lower() in dd_model_name.lower() and dataset_name.lower() in dd_model_name.lower()):
                 continue
 
             EXAMPLES.append([
@@ -139,7 +172,7 @@ CSS = """
     transform: translateY(10px);
     background: white;
 }
-
+
 .dark .footer {
     border-color: #303030;
 }
@@ -221,8 +254,8 @@ gr.Files.preprocess = preprocess
 
 @torch.no_grad()
 def calculate_text_embeddings(model_name, prompts):
-    tokenizer =
-    context_vecs =
+    tokenizer = MODELS[model_name]['tokenizer']
+    context_vecs = tokenizer(prompts)
 
     model_instance = MODELS[model_name]['model_instance']
 
@@ -509,7 +542,8 @@ with block as demo:
     with gr.Column():
         model_dd = gr.Dropdown(label="CLIP Model", choices=list(MODELS.keys()),
                                value=list(MODELS.keys())[0])
-        true_name = gr.Textbox(label='Name of Person:', lines=1, value=DEFAULT_INITIAL_NAME
+        true_name = gr.Textbox(label='Name of Person (make sure it matches the prompts):', lines=1, value=DEFAULT_INITIAL_NAME,
+                               every=5)
         prompts = gr.Dataframe(
             value=[[x.format(DEFAULT_INITIAL_NAME) for x in PROMPTS]],
             label='Prompts Used (hold shift to scroll sideways):',
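For context, the model-registry change above can be exercised in isolation. Below is a minimal sketch (not part of the commit) that loads one of the newly added OpenCLIP variants the way app.py now does and embeds two prompts with its tokenizer. Only create_model_and_transforms, get_tokenizer, the 'laion400m_e32' pretraining tag, and the prompt_text_embeddings path come from the diff; the encode_text call, the normalization step, and the example prompts are assumed standard OpenCLIP usage added here for illustration.

# Minimal sketch (illustrative, not part of the commit): load one of the newly
# added OpenCLIP variants the same way app.py now does and embed two prompts.
import torch
import open_clip

model_name = 'ViT-B-32'                                   # one of OPEN_CLIP_LAION400M_MODEL_NAMES
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name,
    pretrained='laion400m_e32'                            # pretraining tag used for the LAION400M variants
)
model = model.eval()
tokenizer = open_clip.get_tokenizer(model_name)           # the per-model tokenizer now stored in MODELS

prompts = ['a photo of Jane Doe', 'Jane Doe in a suit']   # hypothetical prompts
with torch.no_grad():
    context_vecs = tokenizer(prompts)                     # mirrors the new calculate_text_embeddings()
    text_features = model.encode_text(context_vecs)       # standard OpenCLIP text encoder call (assumed)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

print(text_features.shape)                                # (2, embedding_dim)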
calculate_text_embeddings.ipynb
CHANGED
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {
     "collapsed": true
    },
@@ -39,33 +39,70 @@
     " '{0} in a suit',\n",
     " '{0} in a dress'\n",
     "]\n",
-    "
+    "OPEN_CLIP_LAION400M_MODEL_NAMES = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']\n",
+    "OPEN_CLIP_LAION2B_MODEL_NAMES = [('ViT-B-32', 'laion2b_s34b_b79k') , ('ViT-L-14', 'laion2b_s32b_b82k')]\n",
+    "OPEN_AI_MODELS = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']\n",
     "SEED = 42"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [],
    "source": [
-    "
-    "
-    "
-    "
-    "
-    "
-    "
+    "MODELS = {}\n",
+    "for model_name in OPEN_CLIP_LAION400M_MODEL_NAMES:\n",
+    "    dataset = 'LAION400M'\n",
+    "    model, _, preprocess = open_clip.create_model_and_transforms(\n",
+    "        model_name,\n",
+    "        pretrained=f'{dataset.lower()}_e32'\n",
+    "    )\n",
     "    model = model.eval()\n",
-    "
-    "
-
-
-    "
-
+    "    MODELS[(model_name, dataset.lower())] = {\n",
+    "        'model_instance': model,\n",
+    "        'preprocessing': preprocess,\n",
+    "        'model_name': model_name,\n",
+    "        'tokenizer': open_clip.get_tokenizer(model_name),\n",
+    "    }\n",
+    "\n",
+    "for model_name, dataset_name in OPEN_CLIP_LAION2B_MODEL_NAMES:\n",
+    "    dataset = 'LAION2B'\n",
+    "    model, _, preprocess = open_clip.create_model_and_transforms(\n",
+    "        model_name,\n",
+    "        pretrained = dataset_name\n",
+    "    )\n",
+    "    model = model.eval()\n",
+    "    MODELS[(model_name, dataset.lower())] = {\n",
+    "        'model_instance': model,\n",
+    "        'preprocessing': preprocess,\n",
+    "        'model_name': model_name,\n",
+    "        'tokenizer': open_clip.get_tokenizer(model_name)\n",
+    "    }\n",
+    "\n",
+    "for model_name in OPEN_AI_MODELS:\n",
+    "    dataset = 'OpenAI'\n",
+    "    model, _, preprocess = open_clip.create_model_and_transforms(\n",
+    "        model_name,\n",
+    "        pretrained=dataset.lower()\n",
+    "    )\n",
+    "    model = model.eval()\n",
+    "    MODELS[(model_name, dataset.lower())] = {\n",
+    "        'model_instance': model,\n",
+    "        'preprocessing': preprocess,\n",
+    "        'model_name': model_name,\n",
+    "        'tokenizer': open_clip.get_tokenizer(model_name)\n",
+    "    }"
+   ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [],
    "source": [
     "# define a function to get the predictions for an actor/actress\n",
@@ -90,50 +127,30 @@
     "    text_features = torch.cat(text_features).view(list(context.shape[:-1]) + [-1])\n",
     "\n",
     "    return text_features"
-   ]
-   "metadata": {
-    "collapsed": false
-   }
+   ]
  },
   {
    "cell_type": "code",
-   "execution_count":
-   "
-
-
-
-      "text/html": "<div>…<p>10000 rows × 3 columns</p>\n</div>"
-     },
-     "execution_count": 4,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
    "source": [
     "# load the possible names\n",
     "possible_names = pd.read_csv('./full_names.csv', index_col=0)\n",
     "possible_names\n",
     "# possible_names_list = (possible_names['first_name'] + ' ' + possible_names['last_name']).tolist()\n",
     "# possible_names_list[:5]"
-   ]
-   "metadata": {
-    "collapsed": false
-   }
+   ]
  },
   {
    "cell_type": "code",
-   "execution_count":
-   "
-
-
-
-      "text/html": "<div>…<p>10000 rows × 24 columns</p>\n</div>"
-     },
-     "execution_count": 5,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
    "source": [
     "# populate the prompts with the possible names\n",
     "prompts = []\n",
@@ -145,119 +162,83 @@
     "    prompts.append(df_dict)\n",
     "prompts = pd.DataFrame(prompts)\n",
     "prompts"
-   ]
-   "metadata": {
-    "collapsed": false
-   }
+   ]
  },
   {
    "cell_type": "code",
-   "execution_count":
-   "outputs": [],
-   "source": [
-    "label_context_vecs = []\n",
-    "for i in range(len(PROMPTS)):\n",
-    "    context = open_clip.tokenize(prompts[f'prompt_{i}'].to_numpy())\n",
-    "    label_context_vecs.append(context)\n",
-    "label_context_vecs = torch.stack(label_context_vecs)"
-   ],
+   "execution_count": null,
    "metadata": {
     "collapsed": false
-   }
+   },
+   "outputs": [],
+   "source": [
+    "label_context_vecs_per_model = {}\n",
+    "for dict_key, model_dict in MODELS.items():\n",
+    "    label_context_vecs = []\n",
+    "    for i in range(len(PROMPTS)):\n",
+    "        context = model_dict['tokenizer'](prompts[f'prompt_{i}'].to_numpy())\n",
+    "        label_context_vecs.append(context)\n",
+    "    label_context_vecs = torch.stack(label_context_vecs)\n",
+    "    label_context_vecs_per_model[dict_key] = label_context_vecs"
+   ]
  },
   {
    "cell_type": "code",
-   "execution_count":
-   "
-
-
-
-      "application/vnd.jupyter.widget-view+json": {
-       "version_major": 2,
-       "version_minor": 0,
-       "model_id": "4267d43b498f481db5cbf7e709c9ace3"
-      }
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": "Calculating Text Embeddings: 0%| | 0/210 [00:00<?, ?it/s]",
-      "application/vnd.jupyter.widget-view+json": {
-       "version_major": 2,
-       "version_minor": 0,
-       "model_id": "34a21714ab4d42b2beaa3024bcdd8fdd"
-      }
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": "Calculating Text Embeddings: 0%| | 0/210 [00:00<?, ?it/s]",
-      "application/vnd.jupyter.widget-view+json": {
-       "version_major": 2,
-       "version_minor": 0,
-       "model_id": "3278ad478d7d455da8b03d954fbc4558"
-      }
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
    "source": [
-    "label_context_vecs = label_context_vecs.to(device)\n",
-    "\n",
     "text_embeddings_per_model = {}\n",
-    "for
+    "for dict_key, model_dict in MODELS.items():\n",
+    "    label_context_vecs = label_context_vecs_per_model[dict_key].to(device)\n",
+    "    model = model_dict['model_instance']\n",
     "    model = model.to(device)\n",
-    "    text_embeddings = get_text_embeddings(model, label_context_vecs, use_tqdm=True, context_batchsize=
-    "    text_embeddings_per_model[
+    "    text_embeddings = get_text_embeddings(model, label_context_vecs, use_tqdm=True, context_batchsize=5_000)\n",
+    "    text_embeddings_per_model[dict_key] = text_embeddings\n",
     "    model = model.cpu()\n",
+    "    label_context_vecs = label_context_vecs.cpu()\n",
     "\n",
     "label_context_vecs = label_context_vecs.cpu()"
-   ]
-   "metadata": {
-    "collapsed": false
-   }
+   ]
  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [],
    "source": [
     "# save the calculated embeddings to a file\n",
     "if not os.path.exists('./prompt_text_embeddings'):\n",
     "    os.makedirs('./prompt_text_embeddings')"
-   ]
-   "metadata": {
-    "collapsed": false
-   }
+   ]
  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [],
    "source": [
-    "for model_name,
+    "for (model_name, dataset_name), model_dict in MODELS.items():\n",
     "    torch.save(\n",
-    "        text_embeddings_per_model[model_name],\n",
-    "        f'./prompt_text_embeddings/{model_name}_prompt_text_embeddings.pt'\n",
+    "        text_embeddings_per_model[(model_name, dataset_name)],\n",
+    "        f'./prompt_text_embeddings/{model_name}_{dataset_name}_prompt_text_embeddings.pt'\n",
     "    )"
-   ]
-   "metadata": {
-    "collapsed": false
-   }
+   ]
  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "outputs": [],
-   "source": [],
    "metadata": {
     "collapsed": false
-   }
+   },
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
@@ -269,14 +250,14 @@
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version":
+    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
-   "pygments_lexer": "
-   "version": "
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
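The renamed and newly added LFS files below are the per-model prompt-embedding caches the notebook writes and app.py reads back. A small sketch of that save/load round trip, using a placeholder tensor in place of the real embeddings (the file-name pattern is the one from the diff; the tensor shape and values are illustrative only):

# Sketch of the embedding-cache round trip (placeholder data, real file-name pattern).
import os
import torch

os.makedirs('./prompt_text_embeddings', exist_ok=True)

model_name, dataset_name = 'ViT-B-32', 'laion400m'   # one (model, dataset) key
text_embeddings = torch.randn(21, 1000, 512)         # placeholder; real values come from get_text_embeddings()

path = f'./prompt_text_embeddings/{model_name}_{dataset_name}_prompt_text_embeddings.pt'
torch.save(text_embeddings, path)                    # what the notebook's final cell does per model

loaded = torch.load(path)                            # what app.py does when it builds MODELS
assert torch.equal(loaded, text_embeddings)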
prompt_text_embeddings/{ViT-B-16_prompt_text_embeddings.pt → ViT-B-16_laion400m_prompt_text_embeddings.pt}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:2eba829d60be9ec74485ad0ccdc6cd93c599bb8c0ed3036c099a19ab71fa251a
+size 430080977

prompt_text_embeddings/{ViT-B-32_prompt_text_embeddings.pt → ViT-B-16_openai_prompt_text_embeddings.pt}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c86948737c065233154162deb78c14bfd827eb731df087da082a64d2540f88b6
+size 430080968

prompt_text_embeddings/{ViT-L-14_prompt_text_embeddings.pt → ViT-B-32_laion2b_prompt_text_embeddings.pt}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:fe5c5abd8f02ae34eb97ed192ec67e6345cb44df2c60c00bf71d1fe86d06f9d4
+size 430080971

prompt_text_embeddings/ViT-B-32_laion400m_prompt_text_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da2108f9b80d59975ad30ee72405a31f1d722a00cf22d54fe3523784e6706151
+size 430080977

prompt_text_embeddings/ViT-B-32_openai_prompt_text_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a576a46ca84d9794d4f4e82eed2146f104411462f364efce61274c585c5546c
+size 430080968

prompt_text_embeddings/ViT-L-14_laion2b_prompt_text_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62b2313b873c93cf18c08faa46a1a7088cf7c832abdd28ddaedd0e46624c693d
+size 645120971

prompt_text_embeddings/ViT-L-14_laion400m_prompt_text_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae5074916120756e59f8a38516f11a3a9c2c962843cff75de7947247a74c3ee6
+size 645120977

prompt_text_embeddings/ViT-L-14_openai_prompt_text_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66f9c693c897997cc38160eacae8ff6547312f026ae680431833c7b0898a9a44
+size 645120968