Spaces: Running on Zero
| import torch | |
| import numpy as np | |
| from .processors import Processor_id | |
class ControlNetConfigUnit:
    """Plain settings record describing one ControlNet.

    Stores, unchanged, the processor id, the model path, the output scale
    and the ``skip_processor`` flag. No validation or loading happens here.
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
        # Tuple assignment runs left to right, so attributes are created in
        # the same order as the constructor parameters.
        (
            self.processor_id,
            self.model_path,
            self.scale,
            self.skip_processor,
        ) = (processor_id, model_path, scale, skip_processor)
class ControlNetUnit:
    """Bundle of a ready-to-use processor, its model and a scalar weight."""

    def __init__(self, processor, model, scale=1.0):
        # Stored as-is; the manager classes read these three attributes.
        self.processor, self.model, self.scale = processor, model, scale
class MultiControlNetManager:
    """Run several ControlNet units and sum their scaled residual outputs.

    Each unit supplies a ``processor``, a ``model`` and a ``scale``. They are
    kept in three parallel lists indexed by unit position.
    """

    def __init__(self, controlnet_units=None):
        """Collect processors, models and scales from ``controlnet_units``.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes. ``None`` (the default)
                means no units.
        """
        # ``None`` rather than a ``[]`` default avoids a shared mutable
        # default. Materialise once: the three comprehensions below would
        # otherwise exhaust a one-shot iterable (e.g. a generator) after the
        # first pass, leaving ``models`` and ``scales`` empty.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def cpu(self):
        """Call ``.cpu()`` on every model. Processors are not moved."""
        for model in self.models:
            model.cpu()

    def to(self, device):
        """Call ``.to(device)`` on every model and then on every processor."""
        for model in self.models:
            model.to(device)
        for processor in self.processors:
            processor.to(device)

    def process_image(self, image, processor_id=None):
        """Run processor(s) on ``image`` and stack the outputs into one tensor.

        Args:
            image: Passed unchanged to each processor.
            processor_id: Positional index into ``self.processors`` (not the
                processor's own ``processor_id`` attribute). When ``None``,
                every processor is applied.

        Returns:
            A tensor with one leading entry per processed output, each scaled
            by 1/255. Assumes every processor output converts to a 3-D
            (H, W, C) array, giving an (N, C, H, W) result — confirm against
            the processors.
        """
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            # HWC -> CHW, then add a leading axis so outputs concatenate on dim 0.
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32, **kwargs
    ):
        """Run every ControlNet and return the element-wise sum of scaled outputs.

        ``conditionings`` is zipped with the units, so it should hold one
        entry per unit. Surplus entries on either side are ignored.

        Returns:
            A list of summed residuals, or ``None`` when no (unit,
            conditioning) pair was processed.
        """
        res_stack = None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                processor_id=processor.processor_id
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
class FluxMultiControlNetManager(MultiControlNetManager):
    """Multi-ControlNet manager whose models return two residual lists.

    Each model call returns a pair ``(res_stack, single_res_stack)``. Both
    lists are scaled by the unit's weight and summed element-wise across
    units.
    """

    def __init__(self, controlnet_units=None):
        """Initialise from ``controlnet_units`` (``None`` means no units)."""
        # ``None`` rather than a ``[]`` default avoids a shared mutable
        # default. Normalise here so the base constructor always receives an
        # iterable.
        if controlnet_units is None:
            controlnet_units = []
        super().__init__(controlnet_units=controlnet_units)

    def process_image(self, image, processor_id=None):
        """Run processor(s) on ``image`` and return their raw outputs.

        Unlike the base class, outputs are not converted to a tensor.

        Args:
            image: Passed unchanged to each processor.
            processor_id: Positional index into ``self.processors``. When
                ``None``, every processor is applied.

        Returns:
            A list with one processor output per processor applied.
        """
        if processor_id is None:
            return [processor(image) for processor in self.processors]
        return [self.processors[processor_id](image)]

    def __call__(self, conditionings, **kwargs):
        """Run every ControlNet and sum the scaled residual pairs.

        ``conditionings`` is zipped with the units, so surplus entries on
        either side are ignored. Extra keyword arguments are forwarded to
        every model.

        Returns:
            ``(res_stack, single_res_stack)`` — each a list of summed
            residuals, or ``(None, None)`` when no pair was processed.
        """
        res_stack, single_res_stack = None, None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_, single_res_stack_ = model(
                controlnet_conditioning=conditioning,
                processor_id=processor.processor_id,
                **kwargs,
            )
            res_stack_ = [res * scale for res in res_stack_]
            single_res_stack_ = [res * scale for res in single_res_stack_]
            if res_stack is None:
                res_stack = res_stack_
                single_res_stack = single_res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
        return res_stack, single_res_stack