asoderznik commited on
Commit
7fd98b7
1 Parent(s): 6848e3e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -4
handler.py CHANGED
@@ -15,8 +15,8 @@ logger.warning("WARN")
15
  # set device
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
18
- if device.type != 'cuda':
19
- raise ValueError("need to run on GPU")
20
 
21
  class EndpointHandler():
22
  def __init__(self, path=""):
@@ -35,7 +35,6 @@ class EndpointHandler():
35
 
36
  self.pipe = StableDiffusionUpscalePipeline.from_pretrained(self.path, torch_dtype=torch.float16)
37
  self.pipe = self.pipe.to(device)
38
- self.pipe.unet.in_channels = 8
39
  logger.info('data received %s', data)
40
  inputs = data.get("inputs")
41
  logger.info('inputs received %s', inputs)
@@ -44,7 +43,7 @@ class EndpointHandler():
44
  logger.info('image_base64')
45
  image_bytes = BytesIO(image_base64)
46
  logger.info('image_bytes')
47
- image = Image.open(image_bytes)
48
  prompt = inputs['prompt']
49
  logger.info('image')
50
 
 
15
  # set device
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
18
+ #if device.type != 'cuda':
19
+ # raise ValueError("need to run on GPU")
20
 
21
  class EndpointHandler():
22
  def __init__(self, path=""):
 
35
 
36
  self.pipe = StableDiffusionUpscalePipeline.from_pretrained(self.path, torch_dtype=torch.float16)
37
  self.pipe = self.pipe.to(device)
 
38
  logger.info('data received %s', data)
39
  inputs = data.get("inputs")
40
  logger.info('inputs received %s', inputs)
 
43
  logger.info('image_base64')
44
  image_bytes = BytesIO(image_base64)
45
  logger.info('image_bytes')
46
+ image = Image.open(image_bytes).convert("RGB")
47
  prompt = inputs['prompt']
48
  logger.info('image')
49