LLMs_for_Art_Commentary / imageprocessing /clip_object_recognition.py
LuisAVasquez's picture
adding basic image processing tools
4859d06 verified
raw
history blame
5.57 kB
###
# Read a CSV file listing image filepaths and write a CSV file that also
# contains the objects detected in each image.
#
# The input CSV file must contain an 'image_file' column holding the image filepaths.
# #
import ast
import os

import clip
import pandas as pd
import torch
from PIL import Image
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# this dataset gives us the object classes
# Only `cifar100.classes` (the list of class names) is read elsewhere in this file.
# NOTE: this runs at import time; with download=True the dataset is fetched
# into ~/.cache if it is not already present there.
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
def save_checkpoint(checkpoint_path, df, object_list):
    """Write a checkpoint CSV with the objects recognized so far.

    Args:
        checkpoint_path: Destination path of the CSV file.
        df: DataFrame holding the rows processed so far. It is not modified;
            the new column is added to a copy.
        object_list: Recognized objects, one entry per row of ``df``. Stored
            in a ``clip_recognized_objects`` column.
    """
    # assign() works on a copy, so the caller's frame is left untouched.
    with_objects = df.assign(clip_recognized_objects=object_list)
    # index=False: don't write a new 'Index' column
    with_objects.to_csv(checkpoint_path, index=False)
    print("Saved checkpoint!")
def load_checkpoint(checkpoint_path):
    """Load previously recognized objects from a checkpoint CSV.

    The checkpoint is expected to have an ``image_file`` column and a
    ``clip_recognized_objects`` column. Because CSV stores the object lists
    as text, each cell is parsed back into a Python list so that cached
    results have the same type as freshly computed ones. Rows whose cell is
    empty or cannot be parsed are left out of the cache, so those images are
    simply processed again.

    Args:
        checkpoint_path: Path of the checkpoint CSV file.

    Returns:
        A new dict mapping image filepath to its recognized objects. The dict
        is empty when the checkpoint is missing or unreadable (best effort:
        loading never raises for those cases).
    """
    print("reading checkpoint at ", checkpoint_path)
    try:
        df = pd.read_csv(checkpoint_path)
        rows = list(zip(df['image_file'], df['clip_recognized_objects']))
    except (OSError, ValueError, KeyError) as error:
        # OSError: missing/unreadable file; ValueError: empty or malformed
        # CSV (pandas parser errors derive from it); KeyError: missing column.
        print("Checkpoint was not loaded:", repr(error))
        return {}

    cached_objects = {}
    for image_file, raw_objects in rows:
        if not isinstance(raw_objects, str):
            # Empty cells are read back as NaN: treat the image as unprocessed.
            continue
        try:
            cached_objects[image_file] = ast.literal_eval(raw_objects)
        except (ValueError, SyntaxError, TypeError):
            # Corrupt cell: skip it so the image gets recomputed.
            continue
    print(f"Checkpoint loaded succesfully to cache: {len(cached_objects)} processed files")
    return cached_objects
def get_checkpoint_path(output_path):
    """Return the path at which checkpoints are stored for ``output_path``.

    Checkpoints are written to the output file itself, so the given path is
    returned unchanged. (An earlier variant, since disabled, used a separate
    "checkpoint"-prefixed file in the same directory.)
    """
    checkpoint_path = output_path
    return checkpoint_path
# Module-level dict meant as a cache of already computed results
# (presumably keyed by image filepath, like the cache passed to get_objects — confirm).
cached_objects_dict = {} # to avoid recomputing
def get_objects(filepath, model, preprocess, device, cached_objects_dict):
    """Return the recognized objects for ``filepath``, using the cache when possible.

    Args:
        filepath: Path of the image file; also the cache key.
        model: Model forwarded to ``get_objects_in_image`` on a cache miss.
        preprocess: Preprocessing callable forwarded on a cache miss.
        device: Device forwarded on a cache miss.
        cached_objects_dict: Dict mapping filepath to recognized objects.
            A freshly computed result is stored in it.

    Returns:
        The cached entry if there is one (and it is not None), otherwise the
        newly computed result.
    """
    cached = cached_objects_dict.get(filepath)
    if cached is not None:
        return cached

    # Cache miss (or a stored None): run recognition and remember the result.
    computed = get_objects_in_image(filepath, model, preprocess, device)
    cached_objects_dict[filepath] = computed
    return computed
def get_objects_in_image(image_filepath, model, preprocess, device):
    """Return the five CIFAR-100 class names that CLIP rates most similar to an image.

    The image is compared against one text prompt per class
    ("a photo of a <class>"); the similarities are turned into a softmax
    distribution over the classes and the five largest entries are returned.

    Args:
        image_filepath: Path of the image file to analyze.
        model: Model providing ``encode_image`` and ``encode_text``
            (the caller in this file passes the model from ``clip.load``).
        preprocess: Callable turning a PIL image into the model's input tensor
            (the caller in this file passes the transform from ``clip.load``).
        device: Device the input tensors are moved to.

    Returns:
        A list of five ``(class_name, score)`` tuples, highest score first,
        where ``score`` is a Python float.
    """
    # Prepare the inputs.
    # The context manager closes the underlying file as soon as the resized
    # copy exists; resize() returns a new, fully loaded image.
    with Image.open(image_filepath) as source_image:
        image = source_image.resize((600, 600))
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_inputs = torch.cat(
        [clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]
    ).to(device)

    # Calculate features
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)

    # Pick the top 5 most similar labels for the image
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(5)

    # Convert to plain Python numbers so the result holds no tensors.
    return [
        (cifar100.classes[class_index], score)
        for score, class_index in zip(values.tolist(), indices.tolist())
    ]
def clip_object_detection(input_csv, output_csv, checkpoint_every=50):
    """Recognize objects in every image listed in ``input_csv``.

    Results found in an existing checkpoint are reused. While processing, a
    checkpoint holding all rows handled so far is written after every
    ``checkpoint_every`` newly computed images. Rows answered from the cache
    never trigger a write, so resuming a run does not overwrite the existing
    checkpoint with a shorter one.

    Args:
        input_csv: Path of a CSV file with an ``image_file`` column.
        output_csv: Path of the output CSV; the checkpoint path is derived
            from it via ``get_checkpoint_path``.
        checkpoint_every: Number of newly computed images between two
            checkpoint writes. ``0`` or ``None`` disables checkpointing.

    Returns:
        A pandas Series with one entry per row of the input CSV, in input
        order, holding the recognized objects of that row's image.
    """
    checkpoint_path = get_checkpoint_path(output_csv)
    cached_objects_dict = load_checkpoint(checkpoint_path)

    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)

    df = pd.read_csv(input_csv)
    recognized_objects_per_image = []
    newly_computed = 0
    for idx, filepath in enumerate(tqdm(df['image_file'])):
        # Same miss criterion as get_objects: absent or stored as None.
        needs_inference = cached_objects_dict.get(filepath) is None
        objects = get_objects(
            filepath, model, preprocess, device,
            cached_objects_dict
        )
        recognized_objects_per_image.append(objects)
        if not needs_inference:
            continue

        newly_computed += 1
        if checkpoint_every and newly_computed % checkpoint_every == 0:
            print(f"Images processed: {idx + 1}")
            # Rows 0..idx are done, and the list has exactly idx + 1 entries.
            save_checkpoint(
                checkpoint_path, df.iloc[:idx + 1], recognized_objects_per_image
            )

    return pd.Series(recognized_objects_per_image)
import argparse
def parse_cli_args(argv=None):
    """Parse the command-line options of this script.

    Both ``--input_csv``/``-in`` and ``--output_csv``/``-out`` are required,
    so a missing option ends with an argparse usage error instead of a
    ``TypeError`` further down.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``input_csv`` and ``output_csv`` are
        one-element lists because of ``nargs=1``.
    """
    parser = argparse.ArgumentParser(
        prog="CLIP object recognition",
        description='Recognizes the top 5 main objects per image in an image list',
    )
    parser.add_argument(
        "--input_csv", "-in", metavar='in', type=str, nargs=1, required=True,
        help='input file containing images-paths for object recognition.',
    )
    parser.add_argument(
        "--output_csv", "-out", metavar='out', type=str, nargs=1, required=True,
        help='output file containing images-paths + recognized objects',
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_cli_args()
    input_csv_file = args.input_csv[0]
    output_csv_file = args.output_csv[0]
    print(">>> input file: " , input_csv_file)
    print(">>> output file: ", output_csv_file)

    # perform object recognition
    recognized_objects_per_image = clip_object_detection(input_csv_file, output_csv_file)

    # add a column with the recognized objects
    output_df = pd.read_csv(input_csv_file)
    output_df['clip_recognized_objects'] = recognized_objects_per_image
    output_df.to_csv(
        output_csv_file,
        index=False,  # don't write a new 'Index' column
    )