# latent-upscaler-tool / image_upscaling.py
# Author: lysandre (Hugging Face staff)
# Last change: "Update image_upscaling.py" (commit 612bcea)
import numpy as np
import torch
from transformers.tools.base import Tool, get_default_device
from transformers.utils import is_accelerate_available
from diffusers import DiffusionPipeline
# Natural-language description of the tool, exposed as `ImageUpscalingTool.description`.
# Each literal below is one sentence; they are concatenated at compile time.
IMAGE_UPSCALING_DESCRIPTION = (
    "This is a tool that upscales an image. "
    "It takes one input: `image`, which should be the image to upscale. "
    "It returns the upscaled image."
)
class ImageUpscalingTool(Tool):
    """Agent tool that upscales an image with a diffusers latent-upscaler pipeline.

    The pipeline is not loaded at construction time; it is loaded lazily by
    ``setup`` on the first call, so creating the tool is cheap.
    """

    default_stable_diffusion_checkpoint = "stabilityai/sd-x2-latent-upscaler"
    description = IMAGE_UPSCALING_DESCRIPTION
    name = "image_upscaler"
    inputs = ['image']
    outputs = ['image']

    def __init__(self, device=None, controlnet=None, stable_diffusion=None, **hub_kwargs) -> None:
        """Record configuration; the model itself is loaded in ``setup``.

        Args:
            device: Device to run the pipeline on, as a ``torch.device`` or anything
                ``torch.device`` accepts (e.g. ``"cuda"``). ``None`` means
                ``get_default_device()`` is used at setup time.
            controlnet: Not used by this tool (presumably kept so the signature matches
                sibling diffusion tools).
            stable_diffusion: Hub id or local path of the upscaler checkpoint. ``None``
                falls back to ``default_stable_diffusion_checkpoint``.
            **hub_kwargs: Extra keyword arguments forwarded to
                ``DiffusionPipeline.from_pretrained`` (e.g. ``revision``, ``cache_dir``).

        Raises:
            ImportError: If ``accelerate`` is not installed.
        """
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        super().__init__()
        # Honour a caller-supplied checkpoint instead of always using the default.
        if stable_diffusion is None:
            stable_diffusion = self.default_stable_diffusion_checkpoint
        self.stable_diffusion = stable_diffusion
        self.device = device
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Load the upscaling pipeline onto ``self.device`` and mark the tool ready."""
        if self.device is None:
            self.device = get_default_device()
        # Normalise so string devices ("cuda", "cuda:0", "cpu") work with `.type` below.
        self.device = torch.device(self.device)

        load_kwargs = dict(self.hub_kwargs)
        if self.device.type == "cuda":
            # Load weights directly in fp16 rather than converting after loading:
            # this avoids materialising an fp32 copy first, and
            # `pipeline.to(torch_dtype=...)` is deprecated in newer diffusers releases.
            # An explicit `torch_dtype` in hub_kwargs takes precedence.
            load_kwargs.setdefault("torch_dtype", torch.float16)
        self.pipeline = DiffusionPipeline.from_pretrained(self.stable_diffusion, **load_kwargs)
        self.pipeline.to(self.device)
        self.is_initialized = True

    def __call__(self, image, *, prompt="", num_inference_steps=30, guidance_scale=0):
        """Upscale ``image`` and return the first image produced by the pipeline.

        Args:
            image: Image to upscale, in whatever form the pipeline's ``image`` argument
                accepts (presumably a PIL image; confirm against callers).
            prompt: Optional text conditioning; empty by default.
            num_inference_steps: Number of denoising steps to run.
            guidance_scale: Guidance scale passed to the pipeline; the default of 0
                presumably disables classifier-free guidance.

        Returns:
            ``pipeline(...).images[0]``, i.e. the upscaled image.
        """
        # `is_initialized` is normally set by the base Tool's __init__; default to False
        # so a missing attribute triggers setup rather than an AttributeError.
        if not getattr(self, "is_initialized", False):
            self.setup()
        return self.pipeline(
            image=image,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]