"""Gradio demo that generates image grids from text prompts with min-dalle.

The module builds a single reusable ``MinDalle`` model on the GPU at import
time, wires it to a small Gradio ``Blocks`` UI, and launches the server.
"""

import datetime
import string
import subprocess

import gradio
import torch
import torch.backends.cuda
import torch.backends.cudnn
from emoji import demojize
from min_dalle import MinDalle
from PIL import Image


def filename_from_text(text: str) -> str:
    """Turn an arbitrary prompt into a short, filesystem-friendly file stem.

    Emoji are replaced by their textual names, everything is lower-cased,
    non-ASCII characters are dropped, and only ``a-z`` and spaces are kept.
    The result is truncated to 64 characters and whitespace runs are
    collapsed into single ``-`` separators.

    Returns ``'blank'`` when nothing usable remains.
    """
    allowed_chars = set(string.ascii_lowercase + ' ')
    text = demojize(text, delimiters=('', ''))
    text = text.lower().encode('ascii', errors='ignore').decode()
    text = ''.join(ch for ch in text if ch in allowed_chars)
    # Truncate *before* joining so the 64-character limit applies to the
    # raw filtered text, exactly as the file name was derived previously.
    text = '-'.join(text[:64].split())
    return text or 'blank'


def log_gpu_memory() -> None:
    """Print a timestamp together with the current ``nvidia-smi`` report.

    This is purely diagnostic, so a missing or failing ``nvidia-smi`` is
    reported in the log line instead of aborting the application.
    """
    try:
        report = subprocess.check_output(['nvidia-smi']).decode('utf-8')
    except (OSError, subprocess.CalledProcessError) as err:
        report = 'unavailable ({})'.format(err)
    print("Date:{}, GPU memory:{}".format(str(datetime.datetime.now()), report))


log_gpu_memory()
model = MinDalle(
    is_mega=True,
    is_reusable=True,
    device='cuda',
    dtype=torch.float32
)
log_gpu_memory()


def _checked(value, cast, message, is_valid=None):
    """Convert ``value`` with ``cast`` and validate it.

    Raises ``ValueError(message)`` when the conversion fails or when
    ``is_valid`` (if given) rejects the converted value.
    """
    try:
        parsed = cast(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(message) from err
    if is_valid is not None and not is_valid(parsed):
        raise ValueError(message)
    return parsed


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image grid for ``text`` and return the saved file path.

    Args:
        text: Prompt passed to the model; also used to derive the file name.
        grid_size: Grid side length; must convert to an int in ``[1, 5]``.
        is_seamless: Forwarded to ``generate_image`` as a bool.
        save_as_png: Save as ``.png`` when truthy, otherwise as ``.jpg``.
        temperature: Must convert to a float greater than ``1e-6``.
        supercondition: Must convert to a float; forwarded as
            ``supercondition_factor``.
        top_k: Must convert to an int in ``[1, 16384]``.

    Returns:
        Path of the image written to the current working directory.

    Raises:
        ValueError: If any numeric setting is missing, malformed, or out of
            range.
    """
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    print("Date:{}".format(str(datetime.datetime.now())))
    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # Explicit checks instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would silently disable the validation.
    temperature = _checked(
        temperature, float,
        'Temperature must be a positive nonzero number',
        lambda value: value > 1e-6
    )
    grid_size = _checked(
        grid_size, int,
        'Grid size must be between 1 and 5',
        lambda value: 1 <= value <= 5
    )
    top_k = _checked(
        top_k, int,
        'Top k must be between 1 and 16384',
        lambda value: 1 <= value <= 16384
    )
    supercondition_factor = _checked(
        supercondition, float,
        'Super condition must be a number'
    )

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,  # presumably "pick a random seed" - confirm in min_dalle
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=supercondition_factor,
            top_k=top_k,
            is_verbose=True
        )

    log_gpu_memory()

    ext = 'png' if bool(save_as_png) else 'jpg'
    filename = filename_from_text(text)
    image_path = '{}.{}'.format(filename, ext)
    image.save(image_path)

    return image_path


demo = gradio.Blocks(analytics_enabled=True)

with demo:
    with gradio.Row():
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text',
                value='Moai statue giving a TED Talk',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # The output component must exist because ``run_button.click``
            # below targets it; previously it was disabled inside a string
            # literal, which made the click wiring fail with a NameError.
            # No initial ``value`` is set because the example image files
            # are not shipped with this app.
            output_image = gradio.Image(
                label='Output Image',
                type='filepath',
                interactive=False
            )

        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=5,
                    minimum=1,
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )

            gradio.Markdown(
                """
                ####
                - **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
                - **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
                - **Seamless**: Tile images in image token space instead of pixel space.
                - **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
                - **Top-k**: Each image token is sampled from the top-k scoring tokens.
                - **Super Condition**: Higher values can result in better agreement with the text.
                """
            )

    # NOTE: the upstream three-column examples (prompt, grid size, pre-rendered
    # output image) were disabled; only prompt + grid size examples are used.
    gradio.Examples(
        examples=[
            ['Astronaut riding a horse hyperrealistic', 1],
        ],
        inputs=[
            input_text,
            grid_size,
        ],
        examples_per_page=20
    )

    run_button.click(
        fn=run_model,
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ],
        outputs=[
            output_image
        ]
    )

demo.launch()