yamildiego commited on
Commit
4a1e480
1 Parent(s): 9e8370c
Files changed (1) hide show
  1. handler.py +26 -37
handler.py CHANGED
@@ -1,60 +1,49 @@
1
- from typing import Dict, List, Any
2
- import base64
3
- from PIL import Image
4
- from io import BytesIO
5
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
-
7
  import torch
8
- from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
9
-
10
 
11
- # # set device
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
  if device.type != 'cuda':
14
- raise ValueError("need to run on GPU")
15
- # set mixed precision dtype
16
- dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
17
-
18
- class EndpointHandler():
19
- def __init__(self, path=""):
20
- # self.stable_diffusion_id = "Lykon/dreamshaper-8"
21
-
22
- # self.prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
23
- # self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
24
-
25
 
26
- self.generator = torch.Generator(device=device.type).manual_seed(3)
 
27
 
28
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
29
- # import torch
 
 
30
 
31
- device = "cuda"
 
32
  num_images_per_prompt = 1
33
 
34
- prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to(device)
35
- decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to(device)
 
36
 
37
- prompt = "Anthropomorphic cat dressed as a pilot"
38
- negative_prompt = ""
39
 
40
  prior_output = prior(
41
  prompt=prompt,
42
  height=512,
43
  width=512,
44
  negative_prompt=negative_prompt,
45
- guidance_scale=7.0,
 
46
  num_images_per_prompt=num_images_per_prompt,
47
- num_inference_steps=20
48
  )
 
49
  decoder_output = decoder(
50
- image_embeddings=prior_output.image_embeddings.half(),
51
  prompt=prompt,
52
  negative_prompt=negative_prompt,
53
- guidance_scale=7.0,
54
  output_type="pil",
55
- num_inference_steps=10
56
- ).images
57
- return decoder_output[0]
58
-
59
-
60
 
 
 
 
1
+ from typing import List, Any
 
 
 
 
 
2
  import torch
3
+ from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
 
4
 
5
+ # Configurar el dispositivo para ejecutar el modelo
6
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
  if device.type != 'cuda':
8
+ raise ValueError("Se requiere ejecutar en GPU")
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Configurar el tipo de dato mixto basado en la capacidad de la GPU
11
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability(device.index)[0] >= 8 else torch.float16
12
 
13
+ class EndpointHandler():
14
+ def __init__(self):
15
+ # Inicializar aquí si es necesario
16
+ pass
17
 
18
+ def __call__(self, data: Any) -> List[Any]:
19
+ # Configurar el número de imágenes por prompt
20
  num_images_per_prompt = 1
21
 
22
+ # Cargar los modelos con el tipo de dato y dispositivo correctos
23
+ prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype).to(device)
24
+ decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype).to(device)
25
 
26
+ prompt = data.get("inputs", "Una imagen interesante") # Asegúrate de pasar un prompt adecuado
27
+ negative_prompt = data.get("negative_prompt", "")
28
 
29
  prior_output = prior(
30
  prompt=prompt,
31
  height=512,
32
  width=512,
33
  negative_prompt=negative_prompt,
34
+ guidance_scale=7.5,
35
+ num_inference_steps=50,
36
  num_images_per_prompt=num_images_per_prompt,
 
37
  )
38
+
39
  decoder_output = decoder(
40
+ image_embeddings=prior_output["image_embeddings"].half(),
41
  prompt=prompt,
42
  negative_prompt=negative_prompt,
43
+ guidance_scale=7.5,
44
  output_type="pil",
45
+ num_inference_steps=20
46
+ )
 
 
 
47
 
48
+ # Asumiendo que quieres retornar la primera imagen
49
+ return [decoder_output.images[0]]