import base64
import io

import nltk
import torch
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
from torchvision import transforms


class PreTrainedPipeline():
    """Text-to-image pipeline that renders an ImageNet class name with BigGAN."""

    def __init__(self, path: str = "") -> None:
        """Load the BigGAN generator and the WordNet data used for name lookup.

        Args:
            path: Model identifier or directory forwarded unchanged to
                ``BigGAN.from_pretrained``.
        """
        # one_hot_from_names resolves class names through the WordNet corpus,
        # so the corpus must be available before the first call.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # The same truncation value is used for sampling the noise and for
        # the generator forward pass.
        self.truncation = 0.1

    def __call__(self, inputs: str) -> str:
        """Generate an image for the class named by ``inputs``.

        Args:
            inputs (:obj:`str`): a class name to look up among the ImageNet
                classes.

        Return:
            A :obj:`str`: the generated image encoded as JPEG and then
            base64, decoded to UTF-8 text.

        Raises:
            ValueError: if ``inputs`` cannot be mapped to an ImageNet class.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(
            truncation=self.truncation, batch_size=1
        )
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # The generator emits values in [-1, 1], while ToPILImage expects a
        # float tensor in [0, 1] (it multiplies by 255 and casts to uint8).
        # Rescale and clamp first so negative values do not wrap around.
        scaled = ((output[0] + 1.0) / 2.0).clamp(0.0, 1.0)
        img = transforms.ToPILImage()(scaled)

        buffer = io.BytesIO()
        img.save(buffer, format="JPEG")
        img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return img_str