Spaces:
Runtime error
Runtime error
from typing import Union | |
from PIL import Image | |
import gradio as gr | |
from modules.shared import log | |
from modules.control import processors | |
from modules.control.units import controlnet | |
from modules.control.units import xs | |
from modules.control.units import lite | |
from modules.control.units import t2iadapter | |
from modules.control.units import reference # pylint: disable=unused-import | |
# Device and dtype passed to every adapter / controlnet implementation that Unit constructs;
# both start out unset (None).
# NOTE(review): presumably assigned by the caller before any Unit is created — confirm.
default_device = default_dtype = None
class Unit(): # mashup of gradio controls and mapping to actual implementation classes
    """Mash-up of gradio controls and the implementation objects they drive.

    Every unit owns a ``processors.Processor``. Depending on ``unit_type`` it also
    creates a T2I adapter (``'t2i adapter'``), a ControlNet (``'controlnet'``), a
    ControlNet-XS (``'xs'``) or a ControlLLLite (``'lite'``). The ``'reference'`` and
    ``'ip'`` types create no model object. Any gradio components passed to the
    constructor are wired so that their events either update this unit's attributes
    or call the implementation's ``load`` / ``preview`` methods.
    """

    @staticmethod
    def _ordered_range(start, end):
        """Return ``(start, end)`` reordered so the first value is not larger than the second."""
        return min(start, end), max(start, end)

    def __init__(self,
                 # values
                 enabled: bool = None,
                 strength: float = None,
                 unit_type: str = None,
                 start: float = 0,
                 end: float = 1,
                 # ui bindings
                 enabled_cb = None,
                 reset_btn = None,
                 process_id = None,
                 preview_btn = None,
                 model_id = None,
                 model_strength = None,
                 preview_process = None,
                 image_upload = None,
                 image_reuse = None,
                 image_preview = None,
                 control_start = None,
                 control_end = None,
                 result_txt = None,
                 extra_controls: Union[list, None] = None,
                ):
        """Create the unit's implementation objects and bind the supplied UI components.

        Args:
            enabled: initial enabled state; ``None`` becomes ``False``.
            strength: initial model strength; falsy values (``None``, ``0``) become ``1.0``.
            unit_type: one of ``'t2i adapter'``, ``'controlnet'``, ``'xs'``, ``'lite'``,
                ``'reference'`` or ``'ip'``. Any other value logs an error and returns
                before any UI component is bound.
            start: initial control start; falsy values become ``0``.
            end: initial control end; falsy values become ``1``. The pair is reordered
                so that ``start <= end``.
            enabled_cb, reset_btn, process_id, preview_btn, model_id, model_strength,
            preview_process, image_upload, image_reuse, image_preview, control_start,
            control_end, result_txt: optional gradio components; any left as ``None``
                are simply not bound.
            extra_controls: type-specific extra components. The first one is used for
                adapter / controlnet / xs / lite units, and up to four are used for
                reference units. ``None`` (the default) means no extra controls.
        """
        # None instead of a shared mutable [] default; the list is only read below
        if extra_controls is None:
            extra_controls = []
        self.enabled = enabled or False
        self.type = unit_type
        self.strength = strength or 1.0
        # assign both ends at once so a reversed range is swapped rather than collapsing to (end, end)
        self.start, self.end = self._ordered_range(start or 0, end or 1)
        # processor always exists, adapter and controlnet are optional
        self.process: processors.Processor = processors.Processor()
        self.adapter: t2iadapter.Adapter = None
        self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS, lite.ControlLLLite] = None
        # map to input image
        self.override: Image.Image = None
        # global settings but passed per-unit
        # (start/end are intentionally not reset here so the constructor values are kept)
        self.factor = 1.0
        self.guess = False
        # reference settings
        self.attention = 'Attention'
        self.fidelity = 0.5
        self.query_weight = 1.0
        self.adain_weight = 1.0

        def reset():
            """Reset processor/adapter/controlnet state and drop the override image.

            Returns the values written back to ``enabled_cb``, ``model_id``,
            ``process_id`` and ``model_strength`` (see the ``reset_btn`` binding).
            """
            if self.process is not None:
                self.process.reset()
            if self.adapter is not None:
                self.adapter.reset()
            if self.controlnet is not None:
                self.controlnet.reset()
            self.override = None
            return [True, 'None', 'None', 1.0] # reset ui values

        def enabled_change(val):
            """Mirror the enabled checkbox into ``self.enabled``."""
            self.enabled = val

        def strength_change(val):
            """Mirror the strength control into ``self.strength``."""
            self.strength = val

        def control_change(start, end):
            """Store the control range, reordered so that start <= end."""
            self.start, self.end = self._ordered_range(start, end)

        def adapter_extra(c1):
            """Store the t2i adapter's extra control value as ``self.factor``."""
            self.factor = c1

        def controlnet_extra(c1):
            """Store the controlnet's extra control value as ``self.guess``."""
            self.guess = c1

        def controlnetxs_extra(_c1):
            """No-op handler; the value is read directly through the load binding."""
            pass # gr.component passed directly to load method

        def reference_extra(c1, c2, c3, c4):
            """Store the four reference settings."""
            self.attention = c1
            self.fidelity = c2
            self.query_weight = c3
            self.adain_weight = c4

        def upload_image(image_file):
            """Open an uploaded file as the override image and return a preview update."""
            if image_file is None:
                return gr.update(value=None)
            try:
                self.process.override = Image.open(image_file.name)
                self.override = self.process.override
                log.debug(f'Control process upload image: path="{image_file.name}" image={self.process.override}')
                return gr.update(visible=self.process.override is not None, value=self.process.override)
            except Exception as e:
                log.error(f'Control process upload image failed: path="{image_file.name}" error={e}')
                return gr.update(visible=False, value=None)

        def reuse_image(image):
            """Use the current processed preview as the override image."""
            log.debug(f'Control process reuse image: {image}')
            self.process.override = image
            self.override = self.process.override
            return gr.update(visible=self.process.override is not None, value=self.process.override)

        # actual init
        if self.type == 't2i adapter':
            self.adapter = t2iadapter.Adapter(device=default_device, dtype=default_dtype)
        elif self.type == 'controlnet':
            self.controlnet = controlnet.ControlNet(device=default_device, dtype=default_dtype)
        elif self.type == 'xs':
            self.controlnet = xs.ControlNetXS(device=default_device, dtype=default_dtype)
        elif self.type == 'lite':
            self.controlnet = lite.ControlLLLite(device=default_device, dtype=default_dtype)
        elif self.type == 'reference':
            pass
        elif self.type == 'ip':
            pass
        else:
            log.error(f'Control unknown type: unit={unit_type}')
            return

        # bind ui controls to properties if present
        if self.type == 't2i adapter':
            if model_id is not None:
                model_id.change(fn=self.adapter.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
            if extra_controls:
                extra_controls[0].change(fn=adapter_extra, inputs=extra_controls)
        elif self.type == 'controlnet':
            if model_id is not None:
                model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
            if extra_controls:
                extra_controls[0].change(fn=controlnet_extra, inputs=extra_controls)
        elif self.type == 'xs':
            if model_id is not None:
                # forward the first extra control (if any) as a second load() input; slicing avoids an
                # IndexError when no extra controls were supplied
                # NOTE(review): if ControlNetXS.load requires that second input, the event will still fail without it — confirm
                model_id.change(fn=self.controlnet.load, inputs=[model_id, *extra_controls[:1]], outputs=[result_txt], show_progress=True)
            if extra_controls:
                extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
        elif self.type == 'lite':
            if model_id is not None:
                model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
            if extra_controls:
                # lite reuses the no-op xs handler
                extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
        elif self.type == 'reference':
            # each of the (up to) four reference controls sends all extra controls to reference_extra
            for control in extra_controls[:4]:
                control.change(fn=reference_extra, inputs=extra_controls)

        if enabled_cb is not None:
            enabled_cb.change(fn=enabled_change, inputs=[enabled_cb])
        if model_strength is not None:
            model_strength.change(fn=strength_change, inputs=[model_strength])
        if process_id is not None:
            process_id.change(fn=self.process.load, inputs=[process_id], outputs=[result_txt], show_progress=True)
        if reset_btn is not None:
            reset_btn.click(fn=reset, inputs=[], outputs=[enabled_cb, model_id, process_id, model_strength])
        if preview_btn is not None:
            preview_btn.click(fn=self.process.preview, inputs=[], outputs=[preview_process]) # return list of images for gallery
        if image_upload is not None:
            image_upload.upload(fn=upload_image, inputs=[image_upload], outputs=[image_preview]) # return list of images for gallery
        if image_reuse is not None:
            image_reuse.click(fn=reuse_image, inputs=[preview_process], outputs=[image_preview]) # return list of images for gallery
        if control_start is not None and control_end is not None:
            control_start.change(fn=control_change, inputs=[control_start, control_end])
            control_end.change(fn=control_change, inputs=[control_start, control_end])