import torch | |
import nltk | |
from torchvision import transforms | |
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample | |
class PreTrainedPipeline:
    """Text-to-image pipeline that renders an ImageNet class name with BigGAN."""

    def __init__(self, path="", truncation=0.1):
        """
        Initialize model

        Args:
            path (:obj:`str`):
                Location passed to ``BigGAN.from_pretrained`` to load the weights.
            truncation (:obj:`float`, defaults to 0.1):
                Truncation value used both when sampling the noise vector and as
                the generator's truncation argument.
        """
        # one_hot_from_names looks class names up through NLTK's WordNet corpus.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        self.truncation = truncation

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text, expected to name an ImageNet class
        Return:
            A :obj:`PIL.Image`. The raw image representation as PIL.
        Raises:
            ValueError: if ``inputs`` cannot be mapped to an ImageNet class.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        # A miss is signalled by None. Compare by identity: on a match the helper
        # returns a NumPy array, where `==` is element-wise and using the result
        # in an `if` raises "truth value of an array is ambiguous".
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")
        noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)
        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)
        # The generator emits pixels in [-1, 1] (tanh output; the library's own
        # convert_to_images applies the same (x + 1) / 2 rescale). ToPILImage
        # expects float tensors in [0, 1], so rescale and clamp before converting.
        image = ((output[0] + 1) / 2).clamp(0, 1)
        return transforms.ToPILImage()(image.cpu())