from typing import Union
from PIL import Image
import gradio as gr
from modules.shared import log
from modules.control import processors
from modules.control.units import controlnet
from modules.control.units import xs
from modules.control.units import lite
from modules.control.units import t2iadapter
from modules.control.units import reference # pylint: disable=unused-import


# module-level defaults read at Unit construction time and forwarded to the adapter/controlnet wrappers
default_device = None
default_dtype = None


class Unit(): # mashup of gradio controls and mapping to actual implementation classes
    """One control unit: per-unit settings plus optional bindings to gradio components.

    A processor is always created. Depending on `unit_type`, the unit additionally
    owns a t2i adapter or one of the controlnet variants (controlnet, xs, lite).
    The 'reference' and 'ip' types create no model wrapper.
    """

    def __init__(self,
                 # values
                 enabled: bool = None,
                 strength: float = None,
                 unit_type: str = None,
                 start: float = 0,
                 end: float = 1,
                 # ui bindings
                 enabled_cb = None,
                 reset_btn = None,
                 process_id = None,
                 preview_btn = None,
                 model_id = None,
                 model_strength = None,
                 preview_process = None,
                 image_upload = None,
                 image_reuse = None,
                 image_preview = None,
                 control_start = None,
                 control_end = None,
                 result_txt = None,
                 extra_controls: list = None,
                ):
        """Initialize unit state and attach event handlers to any ui components that were passed.

        Args:
            enabled: initial enabled flag; a falsy value (including None) becomes False.
            strength: initial strength; a falsy value (including None and 0) becomes 1.0.
            unit_type: one of 't2i adapter', 'controlnet', 'xs', 'lite', 'reference', 'ip'.
                Any other value is logged as an error and initialization stops before
                any ui component is bound.
            start: initial start of the control range; a falsy value becomes 0.
            end: initial end of the control range; a falsy value (including 0) becomes 1.
                If start > end the two values are swapped.
            enabled_cb: component whose change event updates `self.enabled`.
            reset_btn: component whose click event resets processor/adapter/controlnet
                and writes default values to enabled_cb, model_id, process_id, model_strength.
            process_id: component whose change event calls `self.process.load`.
            preview_btn: component whose click event calls `self.process.preview`,
                with `preview_process` as output.
            model_id: component whose change event calls the adapter/controlnet `load`.
            model_strength: component whose change event updates `self.strength`.
            preview_process: output of the preview handler and input of the reuse handler.
            image_upload: component whose upload event loads an override image.
            image_reuse: component whose click event reuses the preview image as override.
            image_preview: output component for the upload/reuse handlers.
            control_start: component bound together with `control_end` to update start/end.
            control_end: see `control_start`; both must be present for either to be bound.
            result_txt: output component for the `load` calls.
            extra_controls: type-specific additional components; None means no extra controls.
        """
        # None replaces the former shared mutable `[]` default; both mean "no extra controls"
        extra_controls = [] if extra_controls is None else extra_controls

        self.enabled = enabled or False
        self.type = unit_type
        self.strength = strength or 1.0
        # order the range using local values so a reversed pair is swapped rather than collapsed
        range_start = start or 0
        range_end = end or 1
        self.start = min(range_start, range_end)
        self.end = max(range_start, range_end)
        # processor always exists, adapter and controlnet are optional
        self.process: processors.Processor = processors.Processor()
        self.adapter: t2iadapter.Adapter = None
        self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS] = None
        # map to input image
        self.override: Union[Image.Image, None] = None
        # global settings but passed per-unit
        # (start/end are set above from the constructor arguments and are not reset here)
        self.factor = 1.0
        self.guess = False
        # reference settings
        self.attention = 'Attention'
        self.fidelity = 0.5
        self.query_weight = 1.0
        self.adain_weight = 1.0

        def reset():
            # reset whichever backends exist and drop the override image
            if self.process is not None:
                self.process.reset()
            if self.adapter is not None:
                self.adapter.reset()
            if self.controlnet is not None:
                self.controlnet.reset()
            self.override = None
            return [True, 'None', 'None', 1.0] # reset ui values

        def enabled_change(val):
            self.enabled = val

        def strength_change(val):
            self.strength = val

        def control_change(start, end):
            # keep start <= end regardless of the order the ui sliders are in
            self.start = min(start, end)
            self.end = max(start, end)

        def adapter_extra(c1):
            self.factor = c1

        def controlnet_extra(c1):
            self.guess = c1

        def controlnetxs_extra(_c1):
            pass # gr.component passed directly to load method

        def reference_extra(c1, c2, c3, c4):
            self.attention = c1
            self.fidelity = c2
            self.query_weight = c3
            self.adain_weight = c4

        def upload_image(image_file):
            # open the uploaded file and use it as override for both the processor and the unit
            if image_file is None:
                return gr.update(value=None)
            try:
                self.process.override = Image.open(image_file.name)
                self.override = self.process.override
                log.debug(f'Control process upload image: path="{image_file.name}" image={self.process.override}')
                return gr.update(visible=self.process.override is not None, value=self.process.override)
            except Exception as e:
                # ui callback boundary: log the failure and hide the preview instead of raising
                log.error(f'Control process upload image failed: path="{image_file.name}" error={e}')
                return gr.update(visible=False, value=None)

        def reuse_image(image):
            # use the given image as override for both the processor and the unit
            log.debug(f'Control process reuse image: {image}')
            self.process.override = image
            self.override = self.process.override
            return gr.update(visible=self.process.override is not None, value=self.process.override)

        # actual init
        if self.type == 't2i adapter':
            self.adapter = t2iadapter.Adapter(device=default_device, dtype=default_dtype)
        elif self.type == 'controlnet':
            self.controlnet = controlnet.ControlNet(device=default_device, dtype=default_dtype)
        elif self.type == 'xs':
            self.controlnet = xs.ControlNetXS(device=default_device, dtype=default_dtype)
        elif self.type == 'lite':
            self.controlnet = lite.ControlLLLite(device=default_device, dtype=default_dtype)
        elif self.type == 'reference':
            pass
        elif self.type == 'ip':
            pass
        else:
            log.error(f'Control unknown type: unit={unit_type}')
            return

        # bind ui controls to properties if present
        if self.type == 't2i adapter':
            if model_id is not None:
                model_id.change(fn=self.adapter.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
            if extra_controls:
                extra_controls[0].change(fn=adapter_extra, inputs=extra_controls)
        elif self.type == 'controlnet':
            if model_id is not None:
                model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
            if extra_controls:
                extra_controls[0].change(fn=controlnet_extra, inputs=extra_controls)
        elif self.type == 'xs':
            if model_id is not None:
                # the first extra control is passed straight to load(); without extra controls only
                # the model id is passed instead of failing on extra_controls[0]
                # NOTE(review): assumes load() has a default for its second argument - confirm
                load_inputs = [model_id, extra_controls[0]] if extra_controls else [model_id]
                model_id.change(fn=self.controlnet.load, inputs=load_inputs, outputs=[result_txt], show_progress=True)
            if extra_controls:
                extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
        elif self.type == 'lite':
            if model_id is not None:
                model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
            if extra_controls:
                extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
        elif self.type == 'reference':
            if extra_controls:
                # every one of the four controls re-sends all values to reference_extra
                extra_controls[0].change(fn=reference_extra, inputs=extra_controls)
                extra_controls[1].change(fn=reference_extra, inputs=extra_controls)
                extra_controls[2].change(fn=reference_extra, inputs=extra_controls)
                extra_controls[3].change(fn=reference_extra, inputs=extra_controls)
        if enabled_cb is not None:
            enabled_cb.change(fn=enabled_change, inputs=[enabled_cb])
        if model_strength is not None:
            model_strength.change(fn=strength_change, inputs=[model_strength])
        if process_id is not None:
            process_id.change(fn=self.process.load, inputs=[process_id], outputs=[result_txt], show_progress=True)
        if reset_btn is not None:
            reset_btn.click(fn=reset, inputs=[], outputs=[enabled_cb, model_id, process_id, model_strength])
        if preview_btn is not None:
            preview_btn.click(fn=self.process.preview, inputs=[], outputs=[preview_process]) # return list of images for gallery
        if image_upload is not None:
            image_upload.upload(fn=upload_image, inputs=[image_upload], outputs=[image_preview]) # return list of images for gallery
        if image_reuse is not None:
            image_reuse.click(fn=reuse_image, inputs=[preview_process], outputs=[image_preview]) # return list of images for gallery
        if control_start is not None and control_end is not None:
            control_start.change(fn=control_change, inputs=[control_start, control_end])
            control_end.change(fn=control_change, inputs=[control_start, control_end])