|
import os |
|
import numpy as np |
|
import torch |
|
import clip |
|
import csv |
|
import tqdm |
|
from profanity_filter import ProfanityFilter |
|
|
|
|
|
# Prompt patterns used to build an ensemble of captions per class name.
# ``{c}`` is the placeholder for the class name.
_PROMPT_PATTERNS = (
    'a bad photo of a {c}.',
    'a photo of many {c}.',
    'a sculpture of a {c}.',
    'a photo of the hard to see {c}.',
    'a low resolution photo of the {c}.',
    'a rendering of a {c}.',
    'graffiti of a {c}.',
    'a bad photo of the {c}.',
    'a cropped photo of the {c}.',
    'a tattoo of a {c}.',
    'the embroidered {c}.',
    'a photo of a hard to see {c}.',
    'a bright photo of a {c}.',
    'a photo of a clean {c}.',
    'a photo of a dirty {c}.',
    'a dark photo of the {c}.',
    'a drawing of a {c}.',
    'a photo of my {c}.',
    'the plastic {c}.',
    'a photo of the cool {c}.',
    'a close-up photo of a {c}.',
    'a black and white photo of the {c}.',
    'a painting of the {c}.',
    'a painting of a {c}.',
    'a pixelated photo of the {c}.',
    'a sculpture of the {c}.',
    'a bright photo of the {c}.',
    'a cropped photo of a {c}.',
    'a plastic {c}.',
    'a photo of the dirty {c}.',
    'a jpeg corrupted photo of a {c}.',
    'a blurry photo of the {c}.',
    'a photo of the {c}.',
    'a good photo of the {c}.',
    'a rendering of the {c}.',
    'a {c} in a video game.',
    'a photo of one {c}.',
    'a doodle of a {c}.',
    'a close-up photo of the {c}.',
    'a photo of a {c}.',
    'the origami {c}.',
    'the {c} in a video game.',
    'a sketch of a {c}.',
    'a doodle of the {c}.',
    'a origami {c}.',
    'a low resolution photo of a {c}.',
    'the toy {c}.',
    'a rendition of the {c}.',
    'a photo of the clean {c}.',
    'a photo of a large {c}.',
    'a rendition of a {c}.',
    'a photo of a nice {c}.',
    'a photo of a weird {c}.',
    'a blurry photo of a {c}.',
    'a cartoon {c}.',
    'art of a {c}.',
    'a sketch of the {c}.',
    'a embroidered {c}.',
    'a pixelated photo of a {c}.',
    'itap of the {c}.',
    'a jpeg corrupted photo of the {c}.',
    'a good photo of a {c}.',
    'a plushie {c}.',
    'a photo of the nice {c}.',
    'a photo of the small {c}.',
    'a photo of the weird {c}.',
    'the cartoon {c}.',
    'art of the {c}.',
    'a drawing of the {c}.',
    'a photo of the large {c}.',
    'a black and white photo of a {c}.',
    'the plushie {c}.',
    'a dark photo of a {c}.',
    'itap of a {c}.',
    'graffiti of the {c}.',
    'a toy {c}.',
    'itap of my {c}.',
    'a photo of a cool {c}.',
    'a photo of a small {c}.',
    'a tattoo of the {c}.',
)


def _prompt_builder(pattern):
    """Return a callable ``build(c)`` that fills *pattern* with class name ``c``."""
    # A factory (rather than a lambda in a loop) gives every callable its own
    # ``pattern`` binding and avoids the late-binding closure pitfall.
    def build(c):
        return pattern.format(c=c)

    return build


# One callable per pattern; each maps a class name to a complete caption.
templates = [_prompt_builder(pattern) for pattern in _PROMPT_PATTERNS]
|
|
|
# Expose only GPU 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-L/14 CLIP checkpoint onto the chosen device. The second value
# is kept under the name clip_preprocess; it is not used further in this script.
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
|
|
|
# ---------------------------------------------------------------------------
# Disabled alternative sources for ``class_names``. Neither snippet runs; they
# are kept for reference only.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the code -- confirm it before re-enabling either snippet.
# ---------------------------------------------------------------------------
#
# Variant 1: take the last column of every row in 'openimage-classnames.csv'.
#
# csv_data = open('openimage-classnames.csv')
# csv_reader = csv.reader(csv_data)
# class_names = []
# for row in csv_reader:
#     class_names.append(row[-1])
#
# Variant 2: read 'tencent-ml-images.txt', skip its first four lines, keep
# only the comma-separated names accepted by the profanity filter, and write
# the surviving entries to 'tencent-ml-classnames.txt'.
#
# txt_data = open('tencent-ml-images.txt')
# pf = ProfanityFilter()
# lines = txt_data.readlines()
# class_names = []
# for line in lines[4:]:
#     class_name_precook = line.strip().split('\t')[-1]
#     safe_list = ''
#     for class_name in class_name_precook.split(', '):
#         if pf.is_clean(class_name):
#             safe_list += '%s, ' % class_name
#     safe_list = safe_list[:-2]
#     if len(safe_list) > 0:
#         class_names.append(safe_list)
# f_w = open('tencent-ml-classnames.txt', 'w')
# for cln in class_names:
#     f_w.write('%s\n' % cln)
# f_w.close()
|
def _place_label(category_path):
    """Convert one category path into a human-readable class label.

    The path is split on '/', the first two components are dropped, and the
    remainder becomes the label, e.g. '/a/airfield' -> 'airfield'. When a
    second component follows the name it is moved in front of it, e.g.
    '/a/apartment_building/outdoor' -> 'outdoor apartment building'.
    Underscores are replaced by spaces.

    Raises:
        IndexError: if the path has fewer than three '/'-separated parts.
    """
    parts = category_path.split('/')[2:]
    if len(parts) > 1:
        # Put the trailing qualifier before the place name.
        label = parts[1] + ' ' + parts[0]
    else:
        label = parts[0]
    return label.replace('_', ' ')


# Each row of the categories file is read as strings; only the first column
# (the category path) is used.
place_categories = np.loadtxt('categories_places365.txt', dtype=str)
place_texts = [_place_label(place) for place in place_categories[:, 0]]
class_names = place_texts

# Write one class name per line. The context manager guarantees the handle is
# closed even if a write fails part-way through.
with open('place365-classnames.txt', 'w') as f_w:
    f_w.writelines('%s\n' % cln for cln in class_names)

print(class_names)
|
|
|
def _class_text_embedding(classname):
    """Return a single unit-norm text embedding for *classname*.

    The class name is inserted into every prompt template, all prompts are
    encoded with the CLIP text encoder, each prompt embedding is normalized,
    the embeddings are averaged along dim 0, and the mean is normalized again.
    """
    prompts = [template(classname) for template in templates]
    tokens = clip.tokenize(prompts).to(device)
    embeddings = clip_model.encode_text(tokens)
    # Normalize each prompt embedding before averaging.
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    mean_embedding = embeddings.mean(dim=0)
    # Re-normalize: the mean of unit vectors is generally not unit length.
    return mean_embedding / mean_embedding.norm()


# Inference only, so gradient tracking is switched off while encoding.
with torch.no_grad():
    per_class_embeddings = [
        _class_text_embedding(classname)
        for classname in tqdm.tqdm(class_names, desc='encoding text')
    ]

# Stack into one tensor with one row per class name.
class_weights = torch.stack(per_class_embeddings)
print(class_weights.shape)

torch.save(class_weights, 'clip_ViTL14_place365_classifier_weights.pt')
|
|