DJStomp commited on
Commit
6513aae
·
1 Parent(s): e25ecb9

Create pipeline.py

Browse files

Copy from templates/text-to-image

Files changed (1) hide show
  1. pipeline.py +35 -0
pipeline.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import nltk
3
+ import io
4
+ import base64
5
+ from torchvision import transforms
6
+ from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
7
+ class PreTrainedPipeline():
8
+ def __init__(self, path=""):
9
+ """
10
+ Initialize model
11
+ """
12
+ nltk.download('wordnet')
13
+ self.model = BigGAN.from_pretrained(path)
14
+ self.truncation = 0.1
15
+ def __call__(self, inputs: str):
16
+ """
17
+ Args:
18
+ inputs (:obj:`str`):
19
+ a string containing some text
20
+ Return:
21
+ A :obj:`PIL.Image` with the raw image representation as PIL.
22
+ """
23
+ class_vector = one_hot_from_names([inputs], batch_size=1)
24
+ if type(class_vector) == type(None):
25
+ raise ValueError("Input is not in ImageNet")
26
+ noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
27
+ noise_vector = torch.from_numpy(noise_vector)
28
+ class_vector = torch.from_numpy(class_vector)
29
+ with torch.no_grad():
30
+ output = self.model(noise_vector, class_vector, self.truncation)
31
+ # Scale image
32
+ img = output[0]
33
+ img = (img + 1) / 2.0
34
+ img = transforms.ToPILImage()(img)
35
+ return img