m-ric HF staff commited on
Commit
45ce449
1 Parent(s): e26163a

Update image_transformation.py

Browse files
Files changed (1) hide show
  1. image_transformation.py +3 -3
image_transformation.py CHANGED
@@ -1,5 +1,5 @@
 
1
  import torch
2
-
3
  from transformers.tools.base import Tool
4
  from transformers.utils import (
5
  is_accelerate_available,
@@ -20,8 +20,8 @@ IMAGE_TRANSFORMATION_DESCRIPTION = (
20
  class ImageTransformationTool(Tool):
21
  default_stable_diffusion_checkpoint = "timbrooks/instruct-pix2pix"
22
  description = IMAGE_TRANSFORMATION_DESCRIPTION
23
- inputs = ['image', 'text']
24
- outputs = ['image']
25
 
26
  def __init__(self, device=None, controlnet=None, stable_diffusion=None, **hub_kwargs) -> None:
27
  if not is_accelerate_available():
 
1
+ from PIL import Image
2
  import torch
 
3
  from transformers.tools.base import Tool
4
  from transformers.utils import (
5
  is_accelerate_available,
 
20
  class ImageTransformationTool(Tool):
21
  default_stable_diffusion_checkpoint = "timbrooks/instruct-pix2pix"
22
  description = IMAGE_TRANSFORMATION_DESCRIPTION
23
+ inputs = {'image': Image.Image, 'prompt': str}
24
+ output_type = Image.Image
25
 
26
  def __init__(self, device=None, controlnet=None, stable_diffusion=None, **hub_kwargs) -> None:
27
  if not is_accelerate_available():