# Extract VQGAN discrete image codes from a TSV dataset of base64-encoded images
# and write them (one whitespace-joined code sequence per row) to an output file.
# ``sys.path`` is extended before the project-local imports so the ``data`` and
# ``models`` packages resolve when this script is run from its own directory.
import sys

sys.path.append('../../')

import argparse
import base64
import os
from io import BytesIO

import numpy as np
import torch
from PIL import Image, ImageFile
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from data.file_dataset import FileDataset
from models.taming.models.vqgan import GumbelVQ

# Be permissive with input images: decode truncated files and disable the
# decompression-bomb pixel limit instead of raising mid-run.
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
class VQGANDataset(Dataset):
    """Dataset of base64-encoded images read from a delimited file.

    Each item is decoded to RGB, resized to the VQGAN code image size,
    converted to a tensor, and rescaled to [-1, 1] via ``preprocess_vqgan``.

    Rows may have either 4 columns (pair_id, image_id, image, text) or
    2 columns (image_id, image); anything else raises NotImplementedError.
    """

    def __init__(self, file, selected_cols, code_image_size=None):
        """
        Args:
            file: path handed to ``FileDataset``.
            selected_cols: column-id spec handed to ``FileDataset``.
            code_image_size: target resize edge. Defaults to the script-level
                ``args.code_image_size`` for backward compatibility; pass it
                explicitly to avoid the global dependency.
        """
        self.reader = FileDataset(
            file,
            selected_col_ids=selected_cols,
        )
        if code_image_size is None:
            # NOTE: falls back to the module-global ``args`` parsed in
            # ``__main__``; only valid when run as a script.
            code_image_size = args.code_image_size
        self.code_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize(code_image_size, interpolation=Image.LANCZOS),
            transforms.ToTensor(),
            preprocess_vqgan,
        ])

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, item):
        """Return a dict with ``code_image`` plus the row's id/text columns."""
        column_l = self.reader[item]
        if len(column_l) == 4:
            pair_id, image_id, image, text = column_l
        elif len(column_l) == 2:
            image_id, image = column_l
        else:
            raise NotImplementedError
        # Images are stored as urlsafe-base64 strings; decode to a PIL image.
        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        code_image = self.code_resize_transform(image)
        if len(column_l) == 4:
            return {"code_image": code_image, "pair_id": pair_id, "image_id": image_id, "text": text}
        elif len(column_l) == 2:
            return {"code_image": code_image, "image_id": image_id}
def custom_to_pil(x):
    """Convert a CHW tensor with values in [-1, 1] to an RGB PIL image."""
    tensor = x.detach().cpu()
    # Clamp to the valid range, then map [-1, 1] -> [0, 1].
    tensor = (torch.clamp(tensor, -1., 1.) + 1.) / 2.
    # CHW -> HWC, scale to 8-bit pixel values.
    pixels = (255 * tensor.permute(1, 2, 0).numpy()).astype(np.uint8)
    img = Image.fromarray(pixels)
    return img if img.mode == "RGB" else img.convert("RGB")
def map_pixels(x, eps=0.1):
    """Affinely squeeze values from [0, 1] into [eps, 1 - eps]."""
    scale = 1 - 2 * eps
    return scale * x + eps
def preprocess_vqgan(x):
    """Rescale values from [0, 1] to the [-1, 1] range the VQGAN expects."""
    return x * 2. - 1.
def image_to_base64(img, format):
    """Serialize *img* (a PIL image) in *format* and return it base64-encoded."""
    buffer = BytesIO()
    img.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, default="")
    parser.add_argument("--outputs", type=str, default="")
    parser.add_argument("--selected_cols", type=str, required=True)
    parser.add_argument("--code_image_size", type=int, required=True)
    parser.add_argument("--vq_model", type=str, required=True)
    parser.add_argument("--vqgan_model_path", type=str, default=None)
    parser.add_argument("--vqgan_config_path", type=str, default=None)
    parser.add_argument("--log_interval", default=100, type=int, help="log interval")
    parser.add_argument("--worker_cnt", type=int, default=1)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=32)
    args = parser.parse_args()

    # Build the VQGAN tokenizer, load its checkpoint, and freeze it on GPU.
    vqgan_config = OmegaConf.load(args.vqgan_config_path)
    vqgan = GumbelVQ(**vqgan_config.model.params)
    sd = torch.load(args.vqgan_model_path, map_location="cpu")["state_dict"]
    missing, unexpected = vqgan.load_state_dict(sd, strict=False)
    for _, param in vqgan.named_parameters():
        param.requires_grad = False
    image_tokenizer = vqgan.cuda().eval()

    print("begin process")
    dataset = VQGANDataset(args.file, args.selected_cols)
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    # ``with`` ensures the output file is flushed and closed even if encoding
    # raises mid-loop (the original handle leaked on any exception).
    with open(args.outputs, 'w') as writer:
        for data in dataloader:
            batch_size = data["code_image"].size()[0]
            with torch.no_grad():
                # GumbelVQ.encode returns (z, loss, (perplexity, min_enc, indices));
                # only the flat code indices are needed here.
                z, _, [_, _, image_codes] = image_tokenizer.encode(data["code_image"].cuda())
                image_codes = image_codes.view(batch_size, -1).detach()
            for i, image_code in enumerate(image_codes):
                code = ' '.join(str(num) for num in image_code.tolist())
                # The collated batch dict mirrors the dataset's 4- or 2-column rows.
                if len(data.keys()) == 4:
                    writer.write('\t'.join([data['pair_id'][i], data['image_id'][i], data['text'][i], code]) + '\n')
                elif len(data.keys()) == 2:
                    writer.write('\t'.join([data['image_id'][i], code]) + '\n')
                else:
                    raise NotImplementedError
    print("finish")