## Import

In [15]:
import torch
import numpy as np
from fairseq import utils, tasks
from fairseq import checkpoint_utils
from utils.eval_utils import eval_step
from tasks.mm_tasks import ImageGenTask
from models.unival import UnIVALModel
from PIL import Image
from torchvision import transforms
import time


# Run on the GPU whenever CUDA is usable on this machine.
use_cuda = torch.cuda.is_available()
# Half precision is enabled exactly when the GPU is (never on CPU).
use_fp16 = bool(use_cuda)

In [16]:
# Register the image generation task under the name 'image_gen'.
# NOTE(review): the recorded output of this cell is a `register_task_cls(cls)` function,
# which suggests this call only returns the registration decorator instead of applying
# it to ImageGenTask. Presumably the task is already registered as a side effect of
# importing `tasks.mm_tasks` above — TODO confirm before relying on this line.
tasks.register_task('image_gen', ImageGenTask)


.register_task_cls(cls)>

### Load model 

In [12]:
# Pretrained checkpoint & config locations.
# Everything lives under two base directories; edit these two lines to relocate the assets.
OFA_DATA_DIR = '/data/mshukor/data/ofa'
BEST_MODELS_DIR = '/data/mshukor/logs/ofa/best_models'

# CLIP and VQGAN assets (forwarded to the task through `overrides`)
clip_model_path = OFA_DATA_DIR + '/clip/ViT-B-16.pt'
vqgan_model_path = OFA_DATA_DIR + '/vqgan/last.ckpt'
vqgan_config_path = OFA_DATA_DIR + '/vqgan/model.yaml'

# Image-generation checkpoint to load. Other checkpoints that were tried here
# (all relative to BEST_MODELS_DIR):
#   image_gen_ofa_stage_1_base_s2_hsep1_long/checkpoint_best.pt
#   image_gen_ofaplus_stage_1_base_s2_long/checkpoint_best.pt
#   image_gen_base_best.pt
#   image_gen_large_best.pt
#   image_gen_ofaplus_stage_1_base_s2_hsep1_long/checkpoint_best.pt
checkpoint_path = BEST_MODELS_DIR + '/image_gen_ofaplus_stage_2_base_s2_hsep1_long/checkpoint_best.pt'

# Backbone weights (forwarded to the model through `overrides`)
video_model_path = BEST_MODELS_DIR + '/resnext-101-kinetics.pth'
resnet_model_path = BEST_MODELS_DIR + '/resnet101-5d3b4d8f.pth'

# Output directory handed to the task as `gen_images_path` (relative to the working directory)
gen_images_path = 'results/image_gen/'

# Config values that replace the ones stored in the checkpoint when it is loaded
# (passed as `arg_overrides` to `load_model_ensemble_and_task`).
overrides = dict(
    bpe_dir="utils/BPE",
    eval_cider=False,
    # Decoding settings. max_len_b == min_len, so with max_len_a == 0 the output
    # length is presumably pinned to 1024 tokens — TODO confirm against the generator.
    beam=24,
    max_len_b=1024,
    max_len_a=0,
    min_len=1024,
    sampling_topk=256,
    # NOTE(review): looks like the vocabulary id range of the image codes — confirm.
    constraint_range="50265,58457",
    # Auxiliary model assets
    clip_model_path=clip_model_path,
    vqgan_model_path=vqgan_model_path,
    vqgan_config_path=vqgan_config_path,
    seed=42,
    video_model_path=video_model_path,
    resnet_model_path=resnet_model_path,
    # Output location and image / sampling settings
    gen_images_path=gen_images_path,
    patch_image_size=256,
    temperature=1.5,
)

# Restore the checkpoint(s) named by `checkpoint_path`, together with config and task,
# applying the overrides defined above.
ckpt_files = utils.split_paths(checkpoint_path)
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ckpt_files, arg_overrides=overrides
)

# Set after loading, directly on the task config.
task.cfg.sampling_times = 2

# Prepare every model for inference: eval mode, optional fp16 cast, optional GPU move.
for model in models:
    model.eval()
    if use_fp16:
        model.half()
    if use_cuda:
        # Skip the GPU move when pipeline model parallelism is configured.
        if not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
    model.prepare_for_inference_(cfg)

# Build the sequence generator from the generation section of the config
generator = task.build_generator(models, cfg.generation)

# Special-token tensors and padding index used by the text preprocessing helpers
src_vocab = task.src_dict
bos_item, eos_item = (
    torch.LongTensor([token_id]) for token_id in (src_vocab.bos(), src_vocab.eos())
)
pad_idx = src_vocab.pad()

self.sample_patch_num 784
self.sample_audio_patch_num None
self.sample_video_patch_num None
self.with_cls False
Frozen image bn 
Loading: all_resnext101
use bn: 
load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth
_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])
load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth

RAM memory % used: 10.5
RAM Used (GB): 19.574349824
encoder
RAM memory % used: 10.5
decoder
RAM memory % used: 10.5
ofa
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.


### Preprocess

In [13]:
def encode_text(text, length=None, append_bos=False, append_eos=False):
    """Turn ``text`` into a 1-D long tensor of dictionary indices.

    The text is BPE-encoded with ``task.bpe`` and mapped through ``task.tgt_dict``
    (unknown symbols are not added to the dictionary, no EOS is added by the dictionary).

    Args:
        text: input string.
        length: if not None, keep only the first ``length`` tokens (applied before
            BOS/EOS are attached).
        append_bos: prepend the module-level ``bos_item``.
        append_eos: append the module-level ``eos_item``.

    Returns:
        The token tensor, optionally wrapped in BOS/EOS.
    """
    bpe_line = task.bpe.encode(text)
    token_ids = task.tgt_dict.encode_line(
        line=bpe_line, add_if_not_exist=False, append_eos=False
    ).long()
    if length is not None:
        token_ids = token_ids[:length]

    # Collect the pieces in order and concatenate once.
    pieces = []
    if append_bos:
        pieces.append(bos_item)
    pieces.append(token_ids)
    if append_eos:
        pieces.append(eos_item)
    # Nothing to attach: hand back the tensor itself, as no concatenation is needed.
    return token_ids if len(pieces) == 1 else torch.cat(pieces)


# Build the model input for the image generation task
def construct_sample(query: str):
    """Wrap a caption into the sample dict consumed by ``eval_step``.

    The caption is inserted into a fixed prompt, encoded with BOS and EOS, and
    given a leading batch dimension of 1.

    Returns:
        A dict with a fixed ``"id"`` array and a ``"net_input"`` dict holding
        ``src_tokens``, ``src_lengths`` (count of non-pad tokens per row) and
        ``code_masks`` (a single ``True``).
    """
    prompt = f" what is the complete image? caption: {query}"
    src_tokens = encode_text(prompt, append_bos=True, append_eos=True).unsqueeze(0)
    src_lengths = torch.LongTensor(
        [row.ne(pad_idx).long().sum() for row in src_tokens]
    )
    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "code_masks": torch.tensor([True]),
    }
    return {"id": np.array(['42']), "net_input": net_input}


# Cast FP32 tensors to FP16, leave everything else alone
def apply_half(t):
    """Return ``t`` converted to ``torch.half`` if its dtype is float32, else ``t`` itself."""
    is_fp32 = t.dtype is torch.float32
    return t.to(dtype=torch.half) if is_fp32 else t


# Function for image generation
def image_generation(caption, num_images=4, cols=2, pic_size=256):
    """Generate images for ``caption`` and tile the first ones into a single grid image.

    Args:
        caption: text prompt describing the image.
        num_images: how many of the leading results to place in the grid
            (default 4, i.e. the original 2x2 layout).
        cols: number of tiles per grid row (default 2).
        pic_size: side length in pixels reserved for each tile (default 256).

    Returns:
        A ``PIL.Image`` in RGB mode of size
        ``(pic_size * cols, pic_size * ceil(num_images / cols))``. If fewer than
        ``num_images`` results are available, the remaining tiles stay black
        instead of raising ``IndexError``.

    Raises:
        ValueError: if ``num_images`` or ``cols`` is smaller than 1.
    """
    if num_images < 1 or cols < 1:
        raise ValueError("num_images and cols must both be >= 1")

    sample = construct_sample(caption)
    if use_cuda:
        sample = utils.move_to_cuda(sample)
    if use_fp16:
        sample = utils.apply_to_sample(apply_half, sample)

    print('|Start|', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), caption)
    with torch.no_grad():
        result, _ = eval_step(task, generator, models, sample)

    # Keep the leading results (the original notebook notes they are ranked by CLIP),
    # but never index past what was actually returned.
    shown = min(num_images, len(result))
    images = [result[i]['image'] for i in range(shown)]

    rows = -(-num_images // cols)  # ceiling division without importing math
    retImage = Image.new('RGB', (pic_size * cols, pic_size * rows))
    print('|FINISHED|', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), caption)

    # Fill the grid row by row; integer division gives the row index directly.
    for i, tile in enumerate(images):
        loc = ((i % cols) * pic_size, (i // cols) * pic_size)
        retImage.paste(tile, loc)
    return retImage

### Inference

In [14]:
# Caption to render. Uncomment one of the alternatives below to try another prompt.
query = "A brown horse in the street"
# query = "Cattle grazing on grass near a lake surrounded by mountain."
# query = 'A street scene with a double-decker bus on the road.'
# query = 'A path.'


# Runs generation and returns the tiled result image; the recorded run below took
# about 1.5 minutes between its |Start| and |FINISHED| lines.
retImage = image_generation(query)


|Start| 2023-06-29 12:57:39 A brown horse in the street
|FINISHED| 2023-06-29 12:59:03 A brown horse in the street


In [None]:
# Save the grid into the working directory, named after the caption.
# Path separators in the caption are replaced with '_' so the caption cannot
# redirect the file into another (possibly missing) directory.
safe_name = ''.join('_' if ch in '/\\' else ch for ch in query)
retImage.save(f'{safe_name}.png')