|
import torch |
|
import nltk |
|
import io |
|
import base64 |
|
import shutil |
|
from torchvision import transforms |
|
|
|
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample |
|
|
|
class PreTrainedPipeline():
    """Text-to-image pipeline that wraps a pretrained BigGAN generator.

    The pipeline maps a class name to a one-hot class vector via
    ``one_hot_from_names``, samples a truncated noise vector, runs the
    generator and converts the first output tensor to a PIL image.
    """

    def __init__(self, path=""):
        """Load the WordNet corpus and the pretrained BigGAN weights.

        Args:
            path: Model name or directory forwarded unchanged to
                ``BigGAN.from_pretrained``.
        """
        # one_hot_from_names resolves class names through WordNet, so the
        # corpus has to be available before the first call.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # The same truncation value is used both for sampling the noise
        # vector and as the truncation argument of the generator.
        self.truncation = 0.1

    def __call__(self, inputs: str):
        """Generate one image for the class named by ``inputs``.

        Args:
            inputs (:obj:`str`):
                a string containing the class name to generate.

        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.

        Raises:
            ValueError: If ``inputs`` is not a non-empty string, or if
                ``one_hot_from_names`` cannot resolve it (returns ``None``).
        """
        # Reject unusable input up front instead of letting it fail deeper
        # inside the class-name lookup.
        if not isinstance(inputs, str) or not inputs.strip():
            raise ValueError("Input must be a non-empty string")

        class_vector = one_hot_from_names([inputs], batch_size=1)
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(
            truncation=self.truncation, batch_size=1
        )
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # batch_size is 1, so take the single generated sample.
        img = output[0]
        # Rescale assuming the generator emits values in [-1, 1]
        # (NOTE(review): confirm against the model); clamp so that any stray
        # out-of-range value cannot wrap around in the 8-bit conversion.
        img = ((img + 1) / 2.0).clamp(0.0, 1.0)
        img = transforms.ToPILImage()(img)
        return img