patrickvonplaten commited on
Commit
936baa4
1 Parent(s): 69b946e

Improve text to image (#1)

Browse files

- Improve text to image (681b7e4a4612b7467a146d325d0a17be155613bb)

Files changed (1) hide show
  1. text_to_image.py +8 -3
text_to_image.py CHANGED
@@ -1,8 +1,9 @@
1
  from transformers.tools.base import Tool, get_default_device
2
  from transformers.utils import is_accelerate_available, is_diffusers_available
 
3
 
4
  if is_diffusers_available():
5
- from diffusers import DiffusionPipeline
6
 
7
 
8
  TEXT_TO_IMAGE_DESCRIPTION = (
@@ -34,13 +35,17 @@ class TextToImageTool(Tool):
34
  self.device = get_default_device()
35
 
36
  self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
 
37
  self.pipeline.to(self.device)
38
 
 
 
 
39
  self.is_initialized = True
40
 
41
- def __call__(self, prompt):
42
  if not self.is_initialized:
43
  self.setup()
44
 
45
- return self.pipeline(prompt).images[0]
46
 
 
1
  from transformers.tools.base import Tool, get_default_device
2
  from transformers.utils import is_accelerate_available, is_diffusers_available
3
+ import torch
4
 
5
  if is_diffusers_available():
6
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
7
 
8
 
9
  TEXT_TO_IMAGE_DESCRIPTION = (
 
35
  self.device = get_default_device()
36
 
37
  self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
38
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
39
  self.pipeline.to(self.device)
40
 
41
+ if self.device.type == "cuda":
42
+ self.pipeline.to(torch_dtype=torch.float16)
43
+
44
  self.is_initialized = True
45
 
46
+ def __call__(self, prompt, negative_prompt="low quality, bad quality, deformed, low resolution", added_prompt=" , highest quality, highly realistic, very high resolution"):
47
  if not self.is_initialized:
48
  self.setup()
49
 
50
+ return self.pipeline(prompt + added_prompt, num_inference_steps=25).images[0]
51