pictero / config.py
n42's picture
using class as config
58f80cc
raw
history blame
8.36 kB
import ast
import base64
import json

import gradio as gr
import torch
class Config:
    """Application settings plus a generated script that reproduces them.

    ``self.current`` holds the active settings; :meth:`assemble_code`
    renders them into ``self.code`` as fragments of a standalone script
    built around ``DiffusionPipeline``.
    """

    def __init__(self):
        # Fragments of the generated script keyed "NNN_name"; the numeric
        # prefix fixes their order when joined (see assemble_code).
        self.code = {}
        self.history = []
        self.devices = []
        # Filled in by set_inital_config(); defined here so methods that read
        # them do not raise AttributeError before it has run.
        self.model_configs = {}
        self.scheduler_configs = {}

    def load_app_config(self):
        """Load and return the parsed contents of ``appConfig.json``.

        Returns:
            The parsed JSON value, or an empty dict if the file is missing,
            is not valid JSON, or cannot be read (a message is printed in
            each case).
        """
        try:
            with open('appConfig.json', 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print("App config file not found.")
        except json.JSONDecodeError:
            print("Error decoding JSON in app config file.")
        except Exception as e:
            print("An error occurred while loading app config:", str(e))
        # Previously this fell through to `return appConfig` with the name
        # unbound (UnboundLocalError); return an empty config instead.
        return {}

    def set_inital_config(self):
        """Load ``appConfig.json`` and reset ``self.current`` to defaults.

        The device is chosen in the order CUDA, Apple MPS, CPU. CUDA defaults
        to bfloat16 with TF32 matmuls enabled; the others default to float16
        with TF32 disabled. Also regenerates ``self.code``.
        """
        app_config = self.load_app_config()
        if not isinstance(app_config, dict):
            # e.g. the JSON file holds a list; treat it as an empty config.
            app_config = {}
        self.model_configs = app_config.get("models", {})
        self.scheduler_configs = app_config.get("schedulers", {})
        self.devices = app_config.get("devices", [])

        # Default device / dtype.
        data_type = 'float16'
        allow_tensorfloat32 = False
        if torch.cuda.is_available():
            device = "cuda"
            data_type = 'bfloat16'
            allow_tensorfloat32 = True
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            # NOTE(review): float16 on CPU is slow or unsupported for some
            # ops; confirm this default is intended.
            device = "cpu"

        self.current = {
            "device": device,
            "model": None,
            "scheduler": None,
            "variant": None,
            "allow_tensorfloat32": allow_tensorfloat32,
            "use_safetensors": False,
            "data_type": data_type,
            "safety_checker": False,
            "requires_safety_checker": False,
            "manual_seed": 42,
            "inference_steps": 10,
            "guidance_scale": 0.5,
            "prompt": 'A white rabbit',
            "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
        }
        self.assemble_code()

    @staticmethod
    def _parse_config(text):
        """Parse a serialized config mapping and return it as a new dict.

        Accepts Python-literal text such as ``str(dict)`` (the form
        :meth:`set_config` returns, with ``True``/``None`` and single quotes),
        JSON object text, or an already-parsed dict (which is copied).
        ``ast.literal_eval`` evaluates literals only, so this is safe for
        untrusted input such as a URL parameter.

        Raises:
            ValueError: if *text* is neither a dict literal nor a JSON object.
        """
        if isinstance(text, dict):
            return dict(text)
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError):
            # Not a Python literal; try JSON (true/false/null spellings).
            parsed = json.loads(text)
        if not isinstance(parsed, dict):
            raise ValueError("config must be a mapping")
        return parsed

    def init_config(self, request: gr.Request, inital_config):
        """Resolve the settings to show, from the URL or the defaults.

        If the request's query string has a ``config`` parameter (base64 of
        a serialized config dict), its values override *inital_config*;
        otherwise *inital_config* is used as is. A malformed ``config``
        parameter is reported and ignored.

        Args:
            request: Gradio request; only its query parameters are read.
            inital_config: Default config as ``str(dict)``/JSON text or a dict.

        Returns:
            Values in this fixed order: model, device, use_safetensors,
            data_type, variant, safety_checker, requires_safety_checker,
            scheduler, prompt, negative_prompt, inference_steps, manual_seed,
            guidance_scale.

        Raises:
            KeyError: if a required key is missing from the merged config.
        """
        return_config = self._parse_config(inital_config)

        encoded_params = request.request.query_params.get('config')
        if encoded_params is not None:
            try:
                decoded_params = base64.b64decode(encoded_params).decode('utf-8')
                return_config.update(self._parse_config(decoded_params))
            except (ValueError, SyntaxError) as e:
                # binascii.Error, UnicodeDecodeError and JSONDecodeError are
                # all ValueError subclasses.
                print("Ignoring invalid 'config' URL parameter:", str(e))

        keys = (
            'model',
            'device',
            'use_safetensors',
            'data_type',
            'variant',
            'safety_checker',
            'requires_safety_checker',
            'scheduler',
            'prompt',
            'negative_prompt',
            'inference_steps',
            'manual_seed',
            'guidance_scale',
        )
        return [return_config[key] for key in keys]

    def set_config(self, key, value):
        """Set ``self.current[key] = value`` and return ``str(self.current)``."""
        self.current[key] = value
        return str(self.current)

    def get_scheduler_description(self, scheduler):
        """Return the configured description for *scheduler*.

        Returns ``''`` when *scheduler* is None, a list, or not a known
        scheduler name.
        """
        if scheduler is None or isinstance(scheduler, list):
            return ''
        return self.scheduler_configs.get(scheduler, '')

    @staticmethod
    def _code_literal(value):
        """Render *value* as Python source for the generated script."""
        if isinstance(value, str):
            # json.dumps yields a double-quoted, fully escaped literal whose
            # escapes are also valid in a Python string literal.
            return json.dumps(value, ensure_ascii=False)
        return repr(value)

    def assemble_code(self):
        """Regenerate ``self.code`` from ``self.current`` and return the script.

        Each fragment is stored under a ``"NNN_name"`` key. The script is the
        fragments joined with CRLF in sorted key order, so the numeric prefix
        controls line order; any name a fragment uses must be defined by a
        fragment with a smaller key.
        """
        cfg = self.current
        code = self.code

        code['001_code'] = f'device = "{cfg["device"]}"'
        if cfg['data_type'] == "bfloat16":
            code['002_data_type'] = 'data_type = torch.bfloat16'
        else:
            code['002_data_type'] = 'data_type = torch.float16'
        code['003_tf32'] = (
            f'torch.backends.cuda.matmul.allow_tf32 = {cfg["allow_tensorfloat32"]}'
        )
        if str(cfg["variant"]) == 'None':
            code['004_variant'] = f'variant = {cfg["variant"]}'
        else:
            code['004_variant'] = f'variant = "{cfg["variant"]}"'
        # The pipeline call below references use_safetensors, so define it.
        code['005_use_safetensors'] = f'use_safetensors = {cfg["use_safetensors"]}'

        code['050_init_pipe'] = (
            'pipeline = DiffusionPipeline.from_pretrained(\n'
            f'    "{cfg["model"]}",\n'
            '    use_safetensors=use_safetensors,\n'
            '    torch_dtype=data_type,\n'
            '    variant=variant).to(device)'
        )
        code['054_requires_safety_checker'] = (
            f'pipeline.requires_safety_checker = {cfg["requires_safety_checker"]}'
        )
        safety_checker = cfg["safety_checker"]
        # Accept either a falsy value or a 'false'-like string.
        if not safety_checker or str(safety_checker).lower() == 'false':
            code['055_safety_checker'] = 'pipeline.safety_checker = None'
        else:
            code['055_safety_checker'] = ''
        code['060_scheduler'] = (
            f'pipeline.scheduler = {cfg["scheduler"]}'
            '.from_config(pipeline.scheduler.config)'
        )

        code['070_generator'] = f'generator = torch.Generator("{cfg["device"]}")'
        seed = cfg['manual_seed']
        # Check None/'' first: comparing them with `< 0` raises TypeError.
        if seed is None or seed == '' or seed < 0:
            code['091_manual_seed'] = f'# manual_seed = {seed}'
            generator_arg = 'generator'
        else:
            code['091_manual_seed'] = f'manual_seed = {seed}'
            # Seed at the call site: manual_seed (091) is emitted after the
            # generator (070) but before the pipeline call (100).
            generator_arg = 'generator.manual_seed(manual_seed)'

        code['080_prompt'] = f'prompt = {self._code_literal(cfg["prompt"])}'
        code['085_negative_prompt'] = (
            f'negative_prompt = {self._code_literal(cfg["negative_prompt"])}'
        )
        code['090_inference_steps'] = f'inference_steps = {cfg["inference_steps"]}'
        code['095_guidance_scale'] = f'guidance_scale = {cfg["guidance_scale"]}'
        code['100_run_inference'] = (
            'image = pipeline(\n'
            '    prompt=prompt,\n'
            '    negative_prompt=negative_prompt,\n'
            f'    generator={generator_arg},\n'
            '    num_inference_steps=inference_steps,\n'
            '    guidance_scale=guidance_scale).images[0]'
        )
        return '\r\n'.join(code[key] for key in sorted(code))