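"""Zero-shot object recognition with CLIP.

Reads a CSV with an `image_file` column, recognizes the top-5 objects per
image using the CIFAR-100 class names as the candidate labels, and writes
the input rows plus a `clip_recognized_objects` column to the output CSV.
Intermediate results are checkpointed to the output path so interrupted
runs can be resumed.
"""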
import argparse
import os

import clip
import pandas as pd
import torch
from PIL import Image
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# CIFAR-100 is only used for its 100 class names, which serve as the label
# vocabulary for zero-shot recognition; the test split is the smaller download.
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
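
# The input CSV is expected to provide one image path per row in an
# `image_file` column, for example (paths here are purely illustrative):
#
#     image_file
#     /data/images/0001.jpg
#     /data/images/0002.jpg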


def save_checkpoint(checkpoint_path, df, object_list):
    """Write the rows processed so far, plus their recognized objects, to CSV."""
    output_df = df.copy()
    output_df['clip_recognized_objects'] = object_list
    output_df.to_csv(checkpoint_path, index=False)
    print("Saved checkpoint!")


def load_checkpoint(checkpoint_path):
    """Load previously processed files from a checkpoint CSV, if one exists."""
    try:
        print(f"Reading checkpoint at {checkpoint_path}")
        df = pd.read_csv(checkpoint_path)
        # Note: values read back from CSV are the string form of the
        # (label, score) lists, not live Python objects.
        cached_objects = {
            row['image_file']: row['clip_recognized_objects']
            for _, row in df.iterrows()
        }
        print(f"Checkpoint loaded successfully to cache: {len(cached_objects)} processed files")
        return cached_objects
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print("Checkpoint was not loaded; starting from scratch")
        return {}


def get_checkpoint_path(output_path):
    # Checkpoints are written directly to the output path, so an interrupted
    # run can resume from the partial output file.
    return output_path


def get_objects(filepath, model, preprocess, device, cached_objects_dict):
    """Return cached results for `filepath` if present, computing them otherwise."""
    objects = cached_objects_dict.get(filepath)
    if objects is None:
        objects = get_objects_in_image(filepath, model, preprocess, device)
        cached_objects_dict[filepath] = objects
    return objects


def get_objects_in_image(image_filepath, model, preprocess, device):
    """Zero-shot classify one image against the CIFAR-100 labels.

    Returns the top-5 (label, probability) pairs.
    """
    # `preprocess` resizes and normalizes the image for CLIP, so no manual
    # resizing is needed; convert to RGB to handle grayscale/RGBA inputs.
    image = Image.open(image_filepath).convert("RGB")
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_inputs = torch.cat(
        [clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]
    ).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)

    # Cosine similarity between the image and each label prompt, turned into
    # a probability distribution over the 100 labels.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(5)

    objects = []
    for value, index in zip(values, indices):
        objects.append((cifar100.classes[index], value.item()))
    return objects
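
# Example of calling the classifier directly (the image path here is
# hypothetical; any local image works):
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model, preprocess = clip.load('ViT-B/32', device)
#     print(get_objects_in_image("example.jpg", model, preprocess, device))
#     # e.g. [('lion', 0.91), ('tiger', 0.04), ...]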


def clip_object_detection(input_csv, output_csv):
    """Recognize objects for every image listed in `input_csv`.

    Resumes from a checkpoint at `output_csv` when one exists, and saves
    intermediate results periodically while iterating.
    """
    checkpoint_path = get_checkpoint_path(output_csv)
    cached_objects_dict = load_checkpoint(checkpoint_path)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)

    recognized_objects_per_image = []
    processed_files = set(cached_objects_dict.keys())

    df = pd.read_csv(input_csv)

    for idx, filepath in enumerate(tqdm(df['image_file'])):
        # Save a checkpoint every 49 distinct images (and once at the start);
        # rows up to `idx` align with the results accumulated so far.
        if len(processed_files) % 49 == 0:
            print(f"Images processed: {len(processed_files)}")
            save_checkpoint(checkpoint_path, df.iloc[:idx], recognized_objects_per_image)

        objects = get_objects(filepath, model, preprocess, device, cached_objects_dict)
        recognized_objects_per_image.append(objects)
        processed_files.add(filepath)

    return pd.Series(recognized_objects_per_image)
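
# Note on resuming: re-running with the same output CSV picks up the
# checkpoint written there; load_checkpoint() restores the cache and
# get_objects() skips recomputation for any image path already in it.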


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="CLIP object recognition",
        description="Recognizes the top 5 main objects per image in an image list",
    )
    parser.add_argument(
        "--input_csv", "-in", metavar="in", type=str, required=True,
        help="input CSV containing image paths for object recognition",
    )
    parser.add_argument(
        "--output_csv", "-out", metavar="out", type=str, required=True,
        help="output CSV containing image paths plus recognized objects",
    )
    args = parser.parse_args()
    input_csv_file = args.input_csv
    output_csv_file = args.output_csv

    print(">>> input file: ", input_csv_file)
    print(">>> output file:", output_csv_file)

    recognized_objects_per_image = clip_object_detection(input_csv_file, output_csv_file)

    # Write the final output: the input rows plus the recognized objects.
    output_df = pd.read_csv(input_csv_file)
    output_df['clip_recognized_objects'] = recognized_objects_per_image
    output_df.to_csv(output_csv_file, index=False)
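
# Example invocation (script and CSV file names are illustrative):
#
#     python clip_object_recognition.py --input_csv images.csv --output_csv results.csv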
|