# Hugging Face Space (status: Sleeping)
import json
import os
import random
import re

import pandas as pd
import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import BlipForConditionalGeneration, BlipProcessor
def _tag_caption(caption, folder):
    """Return *caption* with the trigger token for *folder* inserted.

    - ``"images/"``: "thumbs up" becomes "#thumbsup". Otherwise "#thumbsup" is
      prepended or appended at random. This consumes one ``random.choice``
      call, so the result depends on the global ``random`` seed.
    - ``"tom_cruise_dataset/"``: every whole word "man" becomes "<tom_cruise>".
      If there is none, every whole word "person" does. If neither occurs,
      "<tom_cruise> " is prepended.
    - Any other folder: left as is, apart from the cleanup below.

    Every caption then has the spurious word "arafed" removed.
    """
    if folder == "images/":
        if "thumbs up" in caption:
            caption = caption.replace("thumbs up", "#thumbsup")
        elif random.choice([True, False]):
            caption = "#thumbsup " + caption
        else:
            caption = caption + " #thumbsup"
    elif folder == "tom_cruise_dataset/":
        # Whole-word matches only: a plain substring replace would turn
        # "woman" into "wo<tom_cruise>" and "human" into "hu<tom_cruise>".
        caption, hits = re.subn(r"\bman\b", "<tom_cruise>", caption)
        if not hits:
            caption, hits = re.subn(r"\bperson\b", "<tom_cruise>", caption)
        if not hits:
            caption = "<tom_cruise> " + caption
    # The captioning model sometimes emits the non-word "arafed" (per the
    # original author, where it describes a human). Drop it wherever it
    # appears, including at the end of the caption.
    return re.sub(r"\barafed\b\s*", "", caption).strip()


def caption_images(image_paths, processor, model, folder):
    """Caption each image and tag the caption for its dataset folder.

    Args:
        image_paths: Paths of the images to caption, processed in the given
            order (the order matters for the seeded random tagging).
        processor: Captioning processor (e.g. ``BlipProcessor``). It is called
            on a PIL image with ``return_tensors="pt"`` and used to ``decode``
            the generated ids.
        model: Captioning model providing ``generate``. Inputs are moved to
            ``model.device``.
        folder: Dataset sub-folder name, with a trailing slash. It is prefixed
            to each image's file name in the output and selects the tagging
            rule applied by ``_tag_caption``.

    Returns:
        A list of ``{"file_name": folder + <basename>, "text": <caption>}``
        dicts, one per image, in input order.
    """
    records = []
    for img_path in tqdm(image_paths):
        # Close the file handle promptly. convert() returns an independent copy.
        with Image.open(img_path) as img:
            pil_image = img.convert("RGB")
        # Unconditional image captioning.
        inputs = processor(pil_image, return_tensors="pt").to(model.device)
        out = model.generate(**inputs)
        raw_caption = processor.decode(out[0], skip_special_tokens=True)
        records.append(
            {
                "file_name": folder + os.path.basename(img_path),
                "text": _tag_caption(raw_caption, folder),
            }
        )
    return records
def create_thumbs_up_person_dataset(path, cache_dir="/l/vision/v5/sragas/hf_models/"):
    """Caption the person and thumbs-up images under *path* and write metadata.csv.

    Two sub-folders of *path* are read:
    ``tom_cruise_dataset/`` (captioned first) and ``images/``. Captions come
    from ``caption_images`` with the BLIP-large captioning model. The resulting
    ``file_name``/``text`` table is written to ``<path>/metadata.csv`` and
    also to ``metadata.csv`` in the current working directory.

    Args:
        path: Dataset root directory. A trailing slash is optional.
        cache_dir: Directory used to cache the downloaded processor and model.
    """
    # Seeding makes the random "#thumbsup" placement reproducible. That only
    # holds if files are visited in a fixed order, hence the sorted() calls
    # below (os.listdir order is filesystem-dependent).
    random.seed(15)
    model_id = "Salesforce/blip-image-captioning-large"
    # Use the same cache directory for the processor as for the model weights.
    processor = BlipProcessor.from_pretrained(model_id, cache_dir=cache_dir)
    model = BlipForConditionalGeneration.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        torch_dtype=torch.float32,
    ).to("cuda")

    thumbs_dir = os.path.join(path, "images/")
    person_dir = os.path.join(path, "tom_cruise_dataset/")
    thumbs_paths = [os.path.join(thumbs_dir, name) for name in sorted(os.listdir(thumbs_dir))]
    person_paths = [os.path.join(person_dir, name) for name in sorted(os.listdir(person_dir))]

    # NOTE: for a person dataset whose file names already are the prompts
    # (the earlier "sachin_dataset" variant), build the records from the
    # file names instead of captioning the person images.
    records = caption_images(person_paths, processor, model, "tom_cruise_dataset/")
    records.extend(caption_images(thumbs_paths, processor, model, "images/"))

    metadata = pd.DataFrame(records)
    metadata.to_csv(os.path.join(path, "metadata.csv"), index=False)
    metadata.to_csv("metadata.csv", index=False)
if __name__ == "__main__":
    # Dataset root. create_thumbs_up_person_dataset reads its images/ and
    # tom_cruise_dataset/ sub-folders and writes metadata.csv into it.
    dataset_root = "/l/vision/v5/sragas/easel_ai/thumbs_up_dataset/"
    create_thumbs_up_person_dataset(path=dataset_root)