import base64
import io
import shutil

import nltk
import torch
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
from torchvision import transforms


class PreTrainedPipeline():
    """Text-to-image pipeline: map a class name to a BigGAN sample as a PIL image."""

    def __init__(self, path="", truncation=0.1):
        """
        Initialize model.

        Args:
            path (:obj:`str`):
                Name or location passed to ``BigGAN.from_pretrained``.
            truncation (:obj:`float`, defaults to 0.1):
                Truncation value used both to sample the noise vector and as
                the generator's truncation argument. Lower values trade
                diversity for fidelity.
        """
        # The wordnet corpus is needed by one_hot_from_names to resolve
        # class names.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # Defensive: make sure any train/eval-dependent layers run in
        # inference mode.
        self.model.eval()
        self.truncation = truncation

    def __call__(self, inputs: str):
        """
        Generate one image for the class named by ``inputs``.

        Args:
            inputs (:obj:`str`):
                A class name that ``one_hot_from_names`` can resolve
                (e.g. "golden retriever").

        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.

        Raises:
            ValueError: if ``inputs`` cannot be mapped to a class vector.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        # one_hot_from_names signals an unresolvable name by returning None.
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # Rescale the first (only) sample from [-1, 1] to [0, 1]. Clamp so
        # ToPILImage's multiply-by-255 and cast to uint8 cannot wrap
        # out-of-range values.
        img = output[0]
        img = ((img + 1) / 2.0).clamp(0.0, 1.0)
        img = transforms.ToPILImage()(img)
        return img