# person-thumbs-up / hf_dataset.py
# Author: Srimanth Agastyaraju
# Initial commit 5372b88 (3.47 kB)
import os
import json
import requests
import random
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from tqdm import tqdm
import pandas as pd
def caption_images(image_paths, processor, model, folder):
    """Generate BLIP captions for images and inject dataset-specific tags.

    Args:
        image_paths: List of paths to the image files to caption.
        processor: ``BlipProcessor`` used to prepare inputs and decode outputs.
        model: ``BlipForConditionalGeneration`` (already moved to its target
            device by the caller) used for unconditional caption generation.
        folder: Dataset subfolder prefix, either ``"images/"`` (thumbs-up
            photos, tagged with ``#thumbsup``) or ``"tom_cruise_dataset/"``
            (person photos, tagged with ``<tom_cruise>``).

    Returns:
        list[dict]: One ``{"file_name": ..., "text": ...}`` record per image,
        suitable for building a HuggingFace ``metadata.csv``.
    """
    records = []
    for img_path in tqdm(image_paths):
        pil_image = Image.open(img_path).convert('RGB')
        # os.path.basename is more robust than splitting on "/" by hand.
        image_name = os.path.basename(img_path)
        # Unconditional image captioning (no text prompt). Send inputs to the
        # model's own device rather than hardcoding "cuda".
        inputs = processor(pil_image, return_tensors="pt").to(model.device)
        out = model.generate(**inputs)
        out_caption = processor.decode(out[0], skip_special_tokens=True)

        if folder == "images/" and "thumbs up" in out_caption:
            out_caption = out_caption.replace("thumbs up", "#thumbsup")
        elif folder == "images/":
            # Caption lacks "thumbs up": attach the tag at a random end.
            th_choice = random.choice([True, False])
            out_caption = "#thumbsup " + out_caption if th_choice else out_caption + " #thumbsup"
        elif folder == "tom_cruise_dataset/":
            # Replace the generic subject word with the person placeholder,
            # or prepend it when no subject word is present.
            if "man" in out_caption:
                out_caption = out_caption.replace("man", "<tom_cruise>")
            elif "person" in out_caption:
                out_caption = out_caption.replace("person", "<tom_cruise>")
            else:
                out_caption = "<tom_cruise> " + out_caption

        # For some reason, the model puts the word "arafed" for a human;
        # strip this spurious token from the caption.
        if "arafed" in out_caption:
            out_caption = out_caption.replace("arafed ", "")

        records.append({"file_name": folder + image_name, "text": out_caption})
    return records
def create_thumbs_up_person_dataset(path, cache_dir="/l/vision/v5/sragas/hf_models/"):
    """Caption the thumbs-up and person image folders and write metadata.csv.

    Loads the BLIP large captioning model, captions every file under
    ``path + "images/"`` and ``path + "tom_cruise_dataset/"`` via
    :func:`caption_images`, and writes the combined records as a CSV both
    inside ``path`` and into the current working directory.

    Args:
        path: Dataset root directory containing the ``images/`` and
            ``tom_cruise_dataset/`` subfolders (must end with "/").
        cache_dir: HuggingFace cache directory for the model weights.
    """
    random.seed(15)  # reproducible random tag placement in caption_images
    captions = []
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-large",
        cache_dir=cache_dir,
        torch_dtype=torch.float32,
    ).to("cuda")

    # Caption the thumbs up images for prompts
    image_paths = [path + "images/" + file for file in os.listdir(path + "images/")]
    # Read from the person dataset (sorted for deterministic ordering)
    person_paths = [path + "tom_cruise_dataset/" + file
                    for file in sorted(os.listdir(path + "tom_cruise_dataset/"))]
    # NOTE: for the "sachin" dataset the captions were derived from the
    # filenames instead of being generated by the model.

    captions.extend(caption_images(person_paths, processor, model, "tom_cruise_dataset/"))
    captions.extend(caption_images(image_paths, processor, model, "images/"))

    # Write the metadata both next to the images and into the CWD.
    metadata = pd.DataFrame(captions)
    metadata.to_csv(f"{path}metadata.csv", index=False)
    metadata.to_csv("metadata.csv", index=False)
if __name__ == "__main__":
    # Root folder holding the images/ and tom_cruise_dataset/ subfolders.
    dataset_root = "/l/vision/v5/sragas/easel_ai/thumbs_up_dataset/"
    create_thumbs_up_person_dataset(dataset_root)