# min_dalle/app.py
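# Gradio demo that wraps the MinDalle text-to-image model: the user enters a
# prompt and the app returns a grid of generated images saved to disk.
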
import datetime
import gradio
import subprocess
from PIL import Image
import torch, torch.backends.cudnn, torch.backends.cuda
from min_dalle import MinDalle
from emoji import demojize
import string
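
# Build a filesystem-safe filename from the prompt text: emoji are expanded to
# their names, non-ASCII characters and punctuation are dropped, and the
# remaining words are joined with '-'.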
def filename_from_text(text: str) -> str:
text = demojize(text, delimiters=['', ''])
text = text.lower().encode('ascii', errors='ignore').decode()
allowed_chars = string.ascii_lowercase + ' '
text = ''.join(i for i in text.lower() if i in allowed_chars)
text = text[:64]
text = '-'.join(text.strip().split())
if len(text) == 0: text = 'blank'
return text
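
# Log the current GPU memory usage as reported by nvidia-smi.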
def log_gpu_memory():
print("Date:{}, GPU memory:{}".format(str(datetime.datetime.now()), subprocess.check_output('nvidia-smi').decode('utf-8')))
log_gpu_memory()
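
# Load the Mega variant of MinDalle once at startup; is_reusable keeps the
# weights in memory between requests so repeated generations avoid reloading.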
model = MinDalle(
is_mega=True,
is_reusable=True,
device='cuda',
dtype=torch.float32
)
log_gpu_memory()
def run_model(
text: str,
grid_size: int,
is_seamless: bool,
save_as_png: bool,
temperature: float,
supercondition: str,
top_k: str
) -> str:
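    # Inference-only settings: disable autograd and enable cuDNN autotuning plus
    # TF32 / reduced-precision matmul paths for faster generation on CUDA.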
torch.set_grad_enabled(False)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
print("Date:{}".format(str(datetime.datetime.now())))
print('text:', text)
print('grid_size:', grid_size)
print('is_seamless:', is_seamless)
print('temperature:', temperature)
print('supercondition:', supercondition)
print('top_k:', top_k)
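    # Validate the user-supplied settings before generation.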
    try:
        temperature = float(temperature)
        assert temperature > 1e-6
    except (TypeError, ValueError, AssertionError):
        raise Exception('Temperature must be a positive nonzero number')
    try:
        grid_size = int(grid_size)
        assert 1 <= grid_size <= 5
    except (TypeError, ValueError, AssertionError):
        raise Exception('Grid size must be between 1 and 5')
    try:
        top_k = int(top_k)
        assert 1 <= top_k <= 16384
    except (TypeError, ValueError, AssertionError):
        raise Exception('Top k must be between 1 and 16384')
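    # Generate the image grid; supercondition arrives from the UI Dropdown as a
    # string and is cast to float here before being passed to the model.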
with torch.no_grad():
image = model.generate_image(
text = text,
seed = -1,
grid_size = grid_size,
is_seamless = bool(is_seamless),
temperature = temperature,
supercondition_factor = float(supercondition),
top_k = top_k,
is_verbose = True
)
log_gpu_memory()
ext = 'png' if bool(save_as_png) else 'jpg'
filename = filename_from_text(text)
image_path = '{}.{}'.format(filename, ext)
image.save(image_path)
return image_path
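
# Illustration only (mirrors the UI defaults below); the Gradio app is the
# intended entry point:
#   run_model('Moai statue giving a TED Talk', grid_size=5, is_seamless=False,
#             save_as_png=False, temperature=1.0, supercondition='16', top_k='128')

# Build the Gradio Blocks UI: prompt, button, and output on the left, settings on the right.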
demo = gradio.Blocks(analytics_enabled=True)
with demo:
with gradio.Row():
with gradio.Column():
input_text = gradio.Textbox(
label='Input Text',
value='Moai statue giving a TED Talk',
lines=3
)
run_button = gradio.Button(value='Generate Image').style(full_width=True)
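            # Output component; run_model returns the path of the saved image,
            # which is displayed here.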
            output_image = gradio.Image(
                label='Output Image',
                type='filepath',
                interactive=False
            )
with gradio.Column():
gradio.Markdown('## Settings')
with gradio.Row():
grid_size = gradio.Slider(
label='Grid Size',
value=5,
minimum=1,
maximum=5,
step=1
)
save_as_png = gradio.Checkbox(
label='Output PNG',
value=False
)
is_seamless = gradio.Checkbox(
label='Seamless',
value=False
)
gradio.Markdown('#### Advanced')
with gradio.Row():
temperature = gradio.Number(
label='Temperature',
value=1
)
top_k = gradio.Dropdown(
label='Top-k',
choices=[str(2 ** i) for i in range(15)],
value='128'
)
supercondition = gradio.Dropdown(
label='Super Condition',
choices=[str(2 ** i) for i in range(2, 7)],
value='16'
)
gradio.Markdown(
"""
####
- **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
- **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
- **Seamless**: Tile images in image token space instead of pixel space.
- **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
- **Top-k**: Each image token is sampled from the top-k scoring tokens.
- **Super Condition**: Higher values can result in better agreement with the text.
"""
)
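    # Example prompts; each entry maps to (input_text, grid_size).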
gradio.Examples(
examples=[
#['Rusty Iron Man suit found abandoned in the woods being reclaimed by nature', 3, 'examples/rusty-iron-man.jpg'],
#['Moai statue giving a TED Talk', 5, 'examples/moai-statue.jpg'],
#['Court sketch of Godzilla on trial', 5, 'examples/godzilla-trial.jpg'],
#['lofi nuclear war to relax and study to', 5, 'examples/lofi-nuclear-war.jpg'],
#['Karl Marx slimed at Kids Choice Awards', 4, 'examples/marx-slimed.jpg'],
#['Scientists trying to rhyme orange with banana', 4, 'examples/scientists-rhyme.jpg'],
#['Jesus turning water into wine on Americas Got Talent', 5, 'examples/jesus-talent.jpg'],
#['Elmo in a street riot throwing a Molotov cocktail, hyperrealistic', 5, 'examples/elmo-riot.jpg'],
#['Trail cam footage of gollum eating watermelon', 4, 'examples/gollum.jpg'],
#['Funeral at Whole Foods', 4, 'examples/funeral-whole-foods.jpg'],
#['Singularity, hyperrealism', 5, 'examples/singularity.jpg'],
#['Astronaut riding a horse hyperrealistic', 5, 'examples/astronaut-horse.jpg'],
['Astronaut riding a horse hyperrealistic', 1],
#['An astronaut walking on Mars next to a Starship rocket, realistic', 5, 'examples/astronaut-mars.jpg'],
#['Nuclear explosion broccoli', 4, 'examples/nuclear-broccoli.jpg'],
#['Dali painting of WALL·E', 5, 'examples/dali-walle.jpg'],
#['Cleopatra checking her iPhone', 4, 'examples/cleopatra-iphone.jpg'],
],
inputs=[
input_text,
grid_size,
#output_image
],
examples_per_page=20
)
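    # Run the model on click and show the saved image in the output component.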
run_button.click(
fn=run_model,
inputs=[
input_text,
grid_size,
is_seamless,
save_as_png,
temperature,
supercondition,
top_k
],
outputs=[
output_image
]
)
demo.launch()