osanseviero HF staff commited on
Commit
740f729
1 Parent(s): af0640a

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -10
pipeline.py CHANGED
@@ -16,29 +16,26 @@ class PreTrainedPipeline():
16
  self.model = BigGAN.from_pretrained(path)
17
  self.truncation = 0.1
18
 
19
-
20
- def __call__(self, inputs: str) -> str:
21
  """
22
  Args:
23
  inputs (:obj:`str`):
24
  a string containing some text
25
  Return:
26
- A :obj:`np.array`. A np.array containing the image information.
27
  """
28
  class_vector = one_hot_from_names([inputs], batch_size=1)
29
  if type(class_vector) == type(None):
30
  raise ValueError("Input is not in ImageNet")
31
-
32
  noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
33
-
34
  noise_vector = torch.from_numpy(noise_vector)
35
  class_vector = torch.from_numpy(class_vector)
36
-
37
  with torch.no_grad():
38
- output = self.model(noise_vector, class_vector, self.truncation)
39
 
40
  img = transforms.ToPILImage()(output[0])
41
- buf = io.BytesIO()
42
- img.save(buf, format="JPEG")
 
43
 
44
- return base64.encodebytes(buf.getvalue()).decode('utf-8')
 
16
  self.model = BigGAN.from_pretrained(path)
17
  self.truncation = 0.1
18
 
19
+ def __call__(self, inputs: str):
 
20
  """
21
  Args:
22
  inputs (:obj:`str`):
23
  a string containing some text
24
  Return:
25
+ A :obj:`PIL.Image`. The raw image representation as PIL.
26
  """
27
  class_vector = one_hot_from_names([inputs], batch_size=1)
28
  if type(class_vector) == type(None):
29
  raise ValueError("Input is not in ImageNet")
 
30
  noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
 
31
  noise_vector = torch.from_numpy(noise_vector)
32
  class_vector = torch.from_numpy(class_vector)
 
33
  with torch.no_grad():
34
+ output = self.model(noise_vector, class_vector, truncation)
35
 
36
  img = transforms.ToPILImage()(output[0])
37
+ buffer = BytesIO()
38
+ img.save(buffer, format="JPEG")
39
+ img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
40
 
41
+ return img_str