|
import gradio as gr
import base64
import json
import torch


class Config:

    def __init__(self):
        self.code = {}       # generated snippet lines, keyed so they sort in order
        self.history = []
        self.devices = []

    def load_app_config(self):
        # Fall back to an empty dict so the caller can still apply defaults
        # when the config file is missing or unreadable.
        appConfig = {}
        try:
            with open('appConfig.json', 'r') as f:
                appConfig = json.load(f)
        except FileNotFoundError:
            print("App config file not found.")
        except json.JSONDecodeError:
            print("Error decoding JSON in app config file.")
        except Exception as e:
            print("An error occurred while loading app config:", str(e))

        return appConfig
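
    # Expected appConfig.json layout (a minimal sketch: only the three
    # top-level keys are read below; the nested entries are illustrative
    # assumptions, not a fixed schema):
    #
    # {
    #     "models": {"<model-id>": {"...": "..."}},
    #     "schedulers": {"<SchedulerClass>": "<description shown in the UI>"},
    #     "devices": ["cpu", "cuda", "mps"]
    # }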

    def set_inital_config(self):

        appConfig = self.load_app_config()

        self.model_configs = appConfig.get("models", {})
        self.scheduler_configs = appConfig.get("schedulers", {})
        self.devices = appConfig.get("devices", [])

        # Pick the best available backend; bfloat16 and TensorFloat-32 only
        # pay off on CUDA, so the other devices keep the float16 defaults.
        data_type = 'float16'
        allow_tensorfloat32 = False
        if torch.cuda.is_available():
            device = "cuda"
            data_type = 'bfloat16'
            allow_tensorfloat32 = True
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

        self.current = {
            "device": device,
            "model": None,
            "scheduler": None,
            "variant": None,
            "allow_tensorfloat32": allow_tensorfloat32,
            "use_safetensors": False,
            "data_type": data_type,
            "safety_checker": False,
            "requires_safety_checker": False,
            "manual_seed": 42,
            "inference_steps": 10,
            "guidance_scale": 0.5,
            "prompt": 'A white rabbit',
            "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
        }

        self.assemble_code()

    def init_config(self, request: gr.Request, inital_config):

        # A complete config can be passed in the URL as a base64-encoded
        # string of the `current` dict, e.g. ?config=<base64>.
        encoded_params = request.query_params.get('config')
        return_config = {}

        if encoded_params is not None:
            decoded_params = base64.b64decode(encoded_params).decode('utf-8')
            # The decoded value is a Python dict literal; rewrite it into
            # valid JSON (including True, which the original chain missed)
            # before parsing.
            decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('True', 'true').replace('False', 'false')
            return_config = json.loads(decoded_params)

        else:
            inital_config = inital_config.replace("'", '"').replace('None', 'null').replace('True', 'true').replace('False', 'false')
            return_config = json.loads(inital_config)

        return [return_config['model'],
                return_config['device'],
                return_config['use_safetensors'],
                return_config['data_type'],
                return_config['variant'],
                return_config['safety_checker'],
                return_config['requires_safety_checker'],
                return_config['scheduler'],
                return_config['prompt'],
                return_config['negative_prompt'],
                return_config['inference_steps'],
                return_config['manual_seed'],
                return_config['guidance_scale']
                ]
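
    # Producing a shareable ?config= parameter is the inverse of the decoding
    # above (a sketch; host and port are assumptions):
    #
    #     params = base64.b64encode(str(config.current).encode('utf-8')).decode('utf-8')
    #     share_url = f'http://127.0.0.1:7860/?config={params}'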

    def set_config(self, key, value):
        self.current[key] = value
        return str(self.current)
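
    # Typically wired to a control's change event (a sketch; the component
    # names `in_model` and `out_current_config` are assumptions):
    #
    #     in_model.change(lambda value: config.set_config('model', value),
    #                     inputs=[in_model], outputs=[out_current_config])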

    def get_scheduler_description(self, scheduler):
        # A list (nothing selected yet) or None means there is no single
        # scheduler to describe; .get() guards against unknown keys.
        if not isinstance(scheduler, list) and scheduler is not None:
            return self.scheduler_configs.get(scheduler, '')
        else:
            return ''

    def assemble_code(self):

        # Imports needed by the assembled snippet itself.
        self.code['000_imports'] = 'import torch\nfrom diffusers import DiffusionPipeline'

        self.code['001_device'] = f'''device = "{self.current['device']}"'''

        if self.current['data_type'] == "bfloat16":
            self.code['002_data_type'] = 'data_type = torch.bfloat16'
        else:
            self.code['002_data_type'] = 'data_type = torch.float16'

        self.code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {self.current["allow_tensorfloat32"]}'

        # The variant may come back from the UI as the string 'None'.
        if self.current["variant"] in (None, 'None'):
            self.code['004_variant'] = 'variant = None'
        else:
            self.code['004_variant'] = f'variant = "{self.current["variant"]}"'

        # Define use_safetensors so the from_pretrained() call below does not
        # reference an undefined name.
        self.code['005_use_safetensors'] = f'use_safetensors = {self.current["use_safetensors"]}'

        self.code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
    "{self.current['model']}",
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)'''

        self.code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {self.current["requires_safety_checker"]}'

        # The safety checker flag may arrive as a bool or as the string 'False'.
        if not self.current["safety_checker"] or str(self.current["safety_checker"]).lower() == 'false':
            self.code['055_safety_checker'] = 'pipeline.safety_checker = None'
        else:
            self.code['055_safety_checker'] = ''

        self.code['060_scheduler'] = f'pipeline.scheduler = {self.current["scheduler"]}.from_config(pipeline.scheduler.config)'

        # Check None/'' before the numeric comparison (None < 0 raises a
        # TypeError), and emit the seed before the generator that consumes it,
        # since the snippet lines sort lexically by key.
        if self.current['manual_seed'] is None or self.current['manual_seed'] == '' or self.current['manual_seed'] < 0:
            self.code['065_manual_seed'] = f'# manual_seed = {self.current["manual_seed"]}'
            self.code['070_generator'] = f'generator = torch.Generator("{self.current["device"]}")'
        else:
            self.code['065_manual_seed'] = f'manual_seed = {self.current["manual_seed"]}'
            self.code['070_generator'] = 'generator = torch.manual_seed(manual_seed)'

        # Quote the prompts so the snippet assigns strings rather than bare words.
        self.code["080_prompt"] = f'prompt = "{self.current["prompt"]}"'
        self.code["085_negative_prompt"] = f'negative_prompt = "{self.current["negative_prompt"]}"'
        self.code["090_inference_steps"] = f'inference_steps = {self.current["inference_steps"]}'
        self.code["095_guidance_scale"] = f'guidance_scale = {self.current["guidance_scale"]}'

        self.code["100_run_inference"] = f'''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale).images[0]'''

        return '\r\n'.join(snippet for _, snippet in sorted(self.code.items()))
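

# Minimal smoke test (an illustrative sketch, not part of the app wiring):
# build the default config and print the assembled diffusers snippet.
if __name__ == '__main__':
    config = Config()
    config.set_inital_config()
    print(config.assemble_code())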