"""Build CLIP zero-shot classifier weights for a list of class names.

Each class name is rendered through every prompt in ``templates``, encoded
with the CLIP text encoder, and the per-prompt embeddings are averaged into a
single normalized vector per class. The stacked result is saved with
``torch.save``.
"""
import csv
import os

import clip
import numpy as np
import torch
import tqdm
from profanity_filter import ProfanityFilter

# Prompt templates; each callable maps a class name to one prompt sentence.
templates = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)

# --- Disabled alternative: class names from the OpenImages CSV --------------
# csv_data = open('openimage-classnames.csv')
# csv_reader = csv.reader(csv_data)
# class_names = []
# for row in csv_reader:
#     class_names.append(row[-1])

# --- Disabled alternative: profanity-filtered Tencent ML-Images names -------
# txt_data = open('tencent-ml-images.txt')
# pf = ProfanityFilter()
# lines = txt_data.readlines()
# class_names = []
# for line in lines[4:]:
#     class_name_precook = line.strip().split('\t')[-1]
#     safe_list = ''
#     for class_name in class_name_precook.split(', '):
#         if pf.is_clean(class_name):
#             safe_list += '%s, ' % class_name
#     safe_list = safe_list[:-2]
#     if len(safe_list) > 0:
#         class_names.append(safe_list)
# f_w = open('tencent-ml-classnames.txt', 'w')
# for cln in class_names:
#     f_w.write('%s\n' % cln)
# f_w.close()


def _place_label(category_path):
    """Turn a slash-separated category path into a readable class name.

    The first two '/'-separated segments are dropped. If more than one
    segment remains, the second is placed before the first (presumably a
    qualifier such as 'outdoor' following the place name -- confirm against
    the categories file). Underscores become spaces.

    Raises:
        ValueError: if nothing remains after dropping the first two segments.
    """
    parts = category_path.split('/')[2:]
    if not parts:
        raise ValueError(f'unexpected category path: {category_path!r}')
    label = f'{parts[1]} {parts[0]}' if len(parts) > 1 else parts[0]
    return label.replace('_', ' ')


# Only the first column of the categories file is used.
place_categories = np.loadtxt('categories_places365.txt', dtype=str)
place_texts = [_place_label(place) for place in place_categories[:, 0]]

class_names = place_texts

# Persist the class names, one per line; the context manager guarantees the
# handle is closed even if a write fails.
with open('place365-classnames.txt', 'w', encoding='utf-8') as f_w:
    f_w.writelines(f'{cln}\n' for cln in class_names)

print(class_names)

class_weights = []
with torch.no_grad():
    for classname in tqdm.tqdm(class_names, desc='encoding text'):
        texts = [template(classname) for template in templates]
        text_inputs = clip.tokenize(texts).to(device)
        text_features = clip_model.encode_text(text_inputs)
        # Normalize each prompt embedding, average over prompts, then
        # normalize the averaged vector again.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_features = text_features.mean(dim=0)
        text_features /= text_features.norm()
        class_weights.append(text_features)

class_weights = torch.stack(class_weights)
print(class_weights.shape)

#torch.save(class_weights, 'clip_ViTL14_openimage_classifier_weights.pt')
torch.save(class_weights, 'clip_ViTL14_place365_classifier_weights.pt')