osanseviero HF staff commited on
Commit
b7d7804
1 Parent(s): c98e5e8

Add requirements and inference

Browse files
Files changed (2) hide show
  1. pipeline.py +38 -0
  2. requirements.txt +2 -0
pipeline.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import nltk
3
+
4
+ from torchvision import transforms
5
+ from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
6
+
7
+ class PreTrainedPipeline():
8
+ def __init__(self, path=""):
9
+ """
10
+ Initialize model
11
+ """
12
+ nltk.download('wordnet')
13
+ self.model = BigGAN.from_pretrained(path)
14
+ self.truncation = 0.1
15
+
16
+
17
+ def __call__(self, inputs: str):
18
+ """
19
+ Args:
20
+ inputs (:obj:`str`):
21
+ a string containing some text
22
+ Return:
23
+ A :obj:`PIL.Image`. The raw image representation as PIL.
24
+ """
25
+ class_vector = one_hot_from_names([inputs], batch_size=1)
26
+ if class_vector == None:
27
+ raise ValueError("Input is not in ImageNet")
28
+
29
+
30
+ noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1)
31
+
32
+ noise_vector = torch.from_numpy(noise_vector)
33
+ class_vector = torch.from_numpy(class_vector)
34
+
35
+ with torch.no_grad():
36
+ output = model(noise_vector, class_vector, truncation)
37
+
38
+ return transforms.ToPILImage()(output[0])
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pytorch-pretrained-biggan
2
+ nltk