BigGAN-deep-128 / pipeline.py
osanseviero's picture
osanseviero HF staff
Update pipeline.py
1931f62
raw
history blame contribute delete
No virus
1.27 kB
import torch
import nltk
import io
import base64
import shutil
from torchvision import transforms
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
class PreTrainedPipeline():
    """Callable pipeline that turns a class name into a BigGAN-generated image.

    The instance loads a pretrained ``BigGAN`` once in ``__init__`` and then
    produces one image per call from a text label.
    """

    def __init__(self, path=""):
        """
        Initialize model

        Args:
            path (:obj:`str`):
                value forwarded unchanged to ``BigGAN.from_pretrained``
        """
        # Fetch the WordNet corpus before any request is served; presumably
        # one_hot_from_names relies on it to resolve names -- TODO confirm.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # Single truncation value shared by the noise sampler and the
        # generator call in __call__, so the two always agree.
        self.truncation = 0.1

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        Raises:
            ValueError: if ``one_hot_from_names`` returns ``None`` for
                ``inputs`` (reported as "Input is not in ImageNet").
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        # Identity check against the None singleton; comparing type objects
        # with == is the non-idiomatic spelling of the same test.
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(
            truncation=self.truncation, batch_size=1
        )

        # Both helpers are fed to torch.from_numpy, i.e. they yield NumPy
        # arrays that must become tensors before reaching the model.
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # Scale image: take the only item of the size-1 batch and shift it
        # by +1 then halve it (presumably mapping [-1, 1] onto [0, 1] for
        # ToPILImage -- TODO confirm the generator's output range).
        img = output[0]
        img = (img + 1) / 2.0
        img = transforms.ToPILImage()(img)
        return img