import ast
import base64
import json

import gradio as gr
import torch
from gradio import Request
from gradio.context import Context


def persist(component):
    """Persist a component's value per user across page loads.

    Registers two handlers: on load of the root Blocks the component is
    restored from the value last stored for ``request.username`` (or keeps its
    current value if none was stored), and on every change the new value is
    stored under that username.

    NOTE(review): values are keyed by ``request.username``; without
    authentication this is presumably the same (``None``) for every visitor,
    so anonymous users would share one stored value — confirm this is intended.

    Args:
        component: The Gradio component whose value should be persisted.

    Returns:
        The same component, so the call can wrap a component constructor.
    """
    sessions = {}

    def resume_session(value, request: Request):
        # Unknown users keep the component's current value.
        return sessions.get(request.username, value)

    def update_session(value, request: Request):
        sessions[request.username] = value

    Context.root_block.load(resume_session, inputs=[component], outputs=component)
    component.change(update_session, inputs=[component])

    return component


def get_initial_config():
    """Return a fresh dict with the default configuration.

    The device is picked from what torch reports as available: CUDA first,
    then Apple MPS, then CPU. On CUDA the data type defaults to 'bfloat16'
    and TensorFloat-32 is allowed; otherwise 'float16' is used and TF32 is
    off. Boolean-like options are stored as the strings "True"/"False";
    several of them are interpolated verbatim into the code generated by
    ``assemble_code``.
    """
    if torch.cuda.is_available():
        device = "cuda"
        data_type = 'bfloat16'
        allow_tensorfloat32 = "True"
    else:
        device = "mps" if torch.backends.mps.is_available() else "cpu"
        data_type = 'float16'
        allow_tensorfloat32 = "False"

    return {
        "device": device,
        "model": None,
        "cpu_offload": "False",
        "scheduler": None,
        "variant": None,
        "allow_tensorfloat32": allow_tensorfloat32,
        "use_safetensors": "False",
        "data_type": data_type,
        "refiner": "none",
        "safety_checker": "False",
        "requires_safety_checker": "False",
        "auto_encoder": None,
        "enable_vae_slicing": "True",
        "enable_vae_tiling": "True",
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 5,
        "adapter_textual_inversion": None,
        "adapter_textual_inversion_token": None,
        "adapter_lora": [],
        "adapter_lora_token": [],
        "adapter_lora_weight": [],
        "adapter_lora_balancing": {},
        "lora_scale": 0.5,
        "prompt": 'A white rabbit',
        "trigger_token": '',
        "negative_prompt": 'lowres, cropped, worst quality, low quality',
    }


# Order of the values returned by get_config_from_url.
# NOTE(review): presumably mirrors the order of the output components that
# handler is wired to elsewhere — keep both in sync.
_UI_CONFIG_KEYS = (
    'model',
    'device',
    'cpu_offload',
    'use_safetensors',
    'data_type',
    'refiner',
    'variant',
    'safety_checker',
    'requires_safety_checker',
    'auto_encoder',
    'enable_vae_slicing',
    'enable_vae_tiling',
    'scheduler',
    'prompt',
    'trigger_token',
    'negative_prompt',
    'inference_steps',
    'manual_seed',
    'guidance_scale',
    'adapter_textual_inversion',
    'adapter_textual_inversion_token',
    'adapter_lora',
    'adapter_lora_token',
    'adapter_lora_weight',
    'adapter_lora_balancing',
    'lora_scale',
)


def _decode_shared_config(encoded_params):
    """Decode a base64 ``config`` URL parameter into a dict, or return None."""
    # Base64 never contains spaces: a space here is a '+' that the query
    # string parser decoded, so restore it. Padding is often stripped too.
    encoded_params = encoded_params.replace(' ', '+')
    encoded_params += '=' * (-len(encoded_params) % 4)
    try:
        text = base64.b64decode(encoded_params).decode('utf-8')
    except ValueError:  # binascii.Error and UnicodeDecodeError
        return None

    try:
        parsed = json.loads(text)
    except ValueError:
        # Not JSON: presumably a Python dict repr (None/True/False, single
        # quotes). literal_eval parses it without executing code and, unlike
        # blanket string replacement, leaves quotes or "None" inside prompts
        # intact.
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None


def _config_from_cookies(initial_config, cookies):
    """Return a copy of ``initial_config`` overridden by same-named cookies."""
    config = dict(initial_config)
    for key, default in initial_config.items():
        if key not in cookies:
            continue
        value = cookies[key]
        # transform empty values to a "Python-like" None
        if value in ('null', ''):
            value = None
        # list/dict options are stored in cookies as JSON text
        if isinstance(default, (list, dict)):
            if value is None:
                value = type(default)()
            else:
                try:
                    value = json.loads(value)
                except ValueError:
                    value = default
        config[key] = value
    return config


def get_config_from_url(initial_config, request: Request):
    """Return the initial UI values for a page load.

    If the ``config`` query parameter is present and decodes (base64 of a JSON
    object or of a Python dict literal) it is merged over ``initial_config``,
    so keys missing from the shared link keep their defaults. Otherwise, or if
    the parameter cannot be decoded, ``initial_config`` is overridden by any
    cookie with the same name. ``initial_config`` itself is never modified.

    Args:
        initial_config: Default configuration (presumably from
            ``get_initial_config``).
        request: Gradio request carrying the query parameters and cookies.

    Returns:
        A list of values ordered as ``_UI_CONFIG_KEYS``.
    """
    encoded_params = request.request.query_params.get('config')

    shared_config = None
    if encoded_params is not None:
        shared_config = _decode_shared_config(encoded_params)
        if shared_config is None:
            print("Could not decode the 'config' URL parameter; using defaults.")

    if shared_config is not None:
        return_config = {**initial_config, **shared_config}
    else:
        return_config = _config_from_cookies(initial_config, request.cookies)

    return [return_config[key] for key in _UI_CONFIG_KEYS]


def load_app_config(path='appConfig.json'):
    """Load the application configuration from a JSON file.

    Problems are reported with ``print`` and an empty dict is returned, so the
    caller always receives a value instead of an UnboundLocalError.

    Args:
        path: Location of the JSON file (defaults to 'appConfig.json').

    Returns:
        The parsed JSON content, or ``{}`` if it could not be loaded.
    """
    app_config = {}
    try:
        with open(path, 'r', encoding='utf-8') as f:
            app_config = json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        print("An error occurred while loading app config:", str(e))
    return app_config


def set_config(config, key, value):
    """Store ``value`` under ``key`` in ``config`` (in place) and return ``config``.

    ``None`` and the strings 'null'/'none' (any case) are stored as ''.
    """
    if str(value).lower() in ('null', 'none'):
        value = ''
    config[key] = value
    return config


def _is_unset(value):
    """Return True for values meaning "not set": None, '', 'none', 'null'.

    Defaults use None or 'none', while ``set_config`` stores '' for null
    values, so all of them must be treated alike.
    """
    return str(value).lower() in ('', 'none', 'null')


def _is_enabled(value):
    """Return True unless a boolean-like option is the string 'false' (any case)."""
    return str(value).lower() != 'false'


def _py_str(value):
    """Render ``value`` as an escaped, double-quoted Python string literal."""
    # JSON string literals are also valid Python string literals.
    return json.dumps(str(value), ensure_ascii=False)


def _pipeline_code(config):
    """Return code lines that create the pipeline, custom VAE and refiner."""
    code = [f'device = {_py_str(config["device"])}']
    if config['data_type'] == "bfloat16":
        code.append('data_type = torch.bfloat16')
    else:
        code.append('data_type = torch.float16')
    code.append(f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}')
    if _is_unset(config["variant"]):
        code.append('variant = None')
    else:
        code.append(f'variant = {_py_str(config["variant"])}')
    code.append(f'use_safetensors = {config["use_safetensors"]}')

    # INIT PIPELINE
    code.append(f'''pipeline = DiffusionPipeline.from_pretrained(
    {_py_str(config["model"])},
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)''')
    if _is_enabled(config["cpu_offload"]):
        code.append("pipeline.enable_model_cpu_offload()")

    # AUTO ENCODER
    if not _is_unset(config["auto_encoder"]):
        code.append(f'pipeline.vae = AutoencoderKL.from_pretrained({_py_str(config["auto_encoder"])}, torch_dtype=data_type).to(device)')
    if _is_enabled(config["enable_vae_slicing"]):
        code.append("pipeline.enable_vae_slicing()")
    if _is_enabled(config["enable_vae_tiling"]):
        code.append("pipeline.enable_vae_tiling()")

    # INIT REFINER (shares the base pipeline's second text encoder and VAE)
    if not _is_unset(config['refiner']):
        code.append(f'''refiner = DiffusionPipeline.from_pretrained(
    {_py_str(config["refiner"])},
    text_encoder_2=pipeline.text_encoder_2,
    vae=pipeline.vae,
    torch_dtype=data_type,
    use_safetensors=use_safetensors,
    variant=variant,
).to(device)''')
        if _is_enabled(config["cpu_offload"]):
            code.append("refiner.enable_model_cpu_offload()")
        if _is_enabled(config["enable_vae_slicing"]):
            code.append("refiner.enable_vae_slicing()")
        if _is_enabled(config["enable_vae_tiling"]):
            code.append("refiner.enable_vae_tiling()")
    return code


def _pipeline_options_code(config):
    """Return code lines for the safety checker, scheduler and seeded generator."""
    # SAFETY CHECKER
    code = [f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}']
    if not _is_enabled(config["safety_checker"]):
        code.append('pipeline.safety_checker = None')

    # SCHEDULER/SOLVER (the config value is a scheduler class name)
    if not _is_unset(config["scheduler"]):
        code.append(f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)')

    # MANUAL SEED/GENERATOR: a missing or negative seed means "no generator"
    seed = config['manual_seed']
    if seed is None or seed == '' or int(seed) < 0:
        code.append(f'# manual_seed = {seed}')
        code.append('generator = None')
    else:
        code.append(f'manual_seed = {seed}')
        code.append('generator = torch.manual_seed(manual_seed)')
    return code


def _adapter_code(config):
    """Return (code lines, cross_attention_kwargs source) for TI and LoRA adapters."""
    code = []

    # TEXTUAL INVERSION (token is optional for load_textual_inversion)
    textual_inversion = config["adapter_textual_inversion"]
    if not _is_unset(textual_inversion):
        token = config["adapter_textual_inversion_token"]
        if _is_unset(token):
            code.append(f'pipeline.load_textual_inversion({_py_str(textual_inversion)})')
        else:
            code.append(f'pipeline.load_textual_inversion({_py_str(textual_inversion)}, token={_py_str(token)})')

    # LORA: only emitted when every adapter has a matching weight entry
    lora_ids = config["adapter_lora"]
    lora_tokens = config["adapter_lora_token"]
    lora_weights = config["adapter_lora_weight"]
    if len(lora_ids) == 0 or len(lora_ids) != len(lora_weights):
        return code, 'None'

    adapter_weights = []
    for index, lora_id in enumerate(lora_ids):
        weight_name = lora_weights[index]
        adapter_name = _py_str(lora_tokens[index])
        if _is_unset(weight_name):
            code.append(f'pipeline.load_lora_weights({_py_str(lora_id)}, adapter_name={adapter_name})')
        else:
            code.append(f'pipeline.load_lora_weights({_py_str(lora_id)}, weight_name={_py_str(weight_name)}, adapter_name={adapter_name})')
        adapter_weights.append(config["adapter_lora_balancing"][lora_id])
    code.append(f'adapter_weights = {adapter_weights}')
    code.append(f'pipeline.set_adapters({lora_tokens}, adapter_weights=adapter_weights)')
    # str() via f-string: lora_scale may be a number (default 0.5)
    return code, f'{{"scale": {config["lora_scale"]}}}'


def _generation_code(config, cross_attention_kwargs):
    """Return code lines that build the prompts and run the pipeline(s)."""
    prompt_parts = [
        config["prompt"],
        config["trigger_token"],
        config["adapter_textual_inversion_token"],
        ", ".join(config["adapter_lora_token"]),
    ]
    # Skip missing pieces so the prompt never contains a literal "None".
    prompt = " ".join(str(part) for part in prompt_parts if part not in (None, ''))

    negative_prompt = config["negative_prompt"]
    if negative_prompt is None:
        negative_line = 'negative_prompt = None'
    else:
        negative_line = f'negative_prompt = {_py_str(negative_prompt)}'

    code = [
        f'prompt = {_py_str(prompt)}',
        negative_line,
        f'inference_steps = {config["inference_steps"]}',
        f'guidance_scale = {config["guidance_scale"]}',
        f'''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    cross_attention_kwargs={cross_attention_kwargs},
    guidance_scale=guidance_scale).images''',
    ]
    if not _is_unset(config['refiner']):
        # Keep `.images` (a list) so the final `image[0]` works in both cases.
        code.append('''image = refiner(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=inference_steps,
    image=image).images''')
    code.append('image[0]')
    return code


def assemble_code(str_config):
    """Render a configuration as a diffusers Python snippet.

    The snippet emits no import statements; it expects ``torch``,
    ``DiffusionPipeline``, ``AutoencoderKL`` and the configured scheduler
    class to already be in scope. Its last line evaluates ``image[0]``.

    Args:
        str_config: Configuration dict with the keys of ``get_initial_config``.

    Returns:
        The generated source, one statement per element, joined with CRLF.
    """
    config = str_config
    code = _pipeline_code(config)
    code += _pipeline_options_code(config)
    adapter_lines, cross_attention_kwargs = _adapter_code(config)
    code += adapter_lines
    code += _generation_code(config, cross_attention_kwargs)
    return '\r\n'.join(code)