BigGAN-deep-128 / pipeline.py
osanseviero's picture
Add requirements and inference
b7d7804
raw
history blame
1.11 kB
import torch
import nltk
from torchvision import transforms
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
class PreTrainedPipeline():
    """Text-to-image inference pipeline backed by a pretrained BigGAN generator.

    The pipeline maps an ImageNet class name to a one-hot class vector,
    samples a truncated noise vector, and runs the generator once to
    produce a single image.
    """

    def __init__(self, path=""):
        """
        Initialize model

        Args:
            path (:obj:`str`):
                model name or directory passed to ``BigGAN.from_pretrained``
        """
        # one_hot_from_names resolves class names through WordNet, so the
        # corpus has to be available before the first call.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # Truncation value used both for noise sampling and the forward pass.
        self.truncation = 0.1

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`PIL.Image`. The raw image representation as PIL.
        Raises:
            ValueError: if ``inputs`` cannot be mapped to an ImageNet class.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        # Identity check is required: on success class_vector is a NumPy
        # array, and `== None` would compare elementwise and make the `if`
        # raise "truth value of an array is ambiguous".
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(
            truncation=self.truncation, batch_size=1
        )
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # The generator output is in [-1, 1] (the library's convert_to_images
        # helper rescales the same way), while ToPILImage expects float
        # tensors in [0, 1]; rescale and clamp so negative values are not
        # corrupted by the byte conversion.
        image = ((output[0] + 1) / 2).clamp(0, 1)
        return transforms.ToPILImage()(image)