# LegoGPT-Demo / app.py
# Author: AvaLovelace · commit 6d7e55e ("Add app files") · 3.77 kB
import os
import subprocess
from dataclasses import fields
import gradio as gr
import transformers
from legogpt.models import LegoGPT, LegoGPTConfig
def main():
    """
    Build the LegoGPT Gradio demo and launch it.

    A single LegoGPT model is created once at startup. Each request generates a LEGO
    structure from a text prompt, writes it to an LDraw file, and renders that file
    to a PNG in a separate `uv run` process.
    """
    # Default values used to populate the UI and as fallbacks for cleared inputs.
    # The model gets its own config instance, so these defaults are never shared with it.
    model_cfg = LegoGPTConfig()
    model = LegoGPT(LegoGPTConfig())
    default_seed = 42

    def generate_lego(
            prompt: str,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ):
        """
        Generate a LEGO structure from a text prompt and render it to an image.

        :param prompt: Text description of the model to generate.
        :param temperature: Sampling temperature; None falls back to the config default.
        :param seed: Seed passed to transformers.set_seed; None leaves the RNG unseeded.
        :param max_bricks: Generation limit; None falls back to the config default.
        :param max_brick_rejections: Generation limit; None falls back to the config default.
        :param max_regenerations: Generation limit; None falls back to the config default.
        :return: (absolute path of the rendered PNG, the generated LEGO structure).
        :raises RuntimeError: If the render subprocess exits with a non-zero status.
        """
        # The model object is shared by every request, so assign each parameter on every call.
        # Otherwise a cleared (None) input would silently keep the previous request's value.
        model.temperature = model_cfg.temperature if temperature is None else temperature
        model.max_bricks = model_cfg.max_bricks if max_bricks is None else max_bricks
        model.max_brick_rejections = (model_cfg.max_brick_rejections if max_brick_rejections is None
                                      else max_brick_rejections)
        model.max_regenerations = model_cfg.max_regenerations if max_regenerations is None else max_regenerations
        if seed is not None:
            transformers.set_seed(seed)

        # Generate LEGO
        output = model(prompt)

        # Render results and write to files.
        # NOTE(review): these fixed paths are shared by all requests; this relies on Gradio not
        # running this handler concurrently. Confirm the queue's concurrency limit.
        render_lego_filename = 'render_lego.py'
        ldr_filename = os.path.abspath('output.ldr')
        img_filename = os.path.abspath('output.png')
        with open(ldr_filename, 'w') as f:
            f.write(output['lego'].to_ldr())

        # Run render as a subprocess to prevent issues with Blender
        try:
            subprocess.run(['uv', 'run', render_lego_filename, '--in_file', ldr_filename, '--out_file', img_filename],
                           capture_output=True, check=True)
        except subprocess.CalledProcessError as e:
            # capture_output hides the renderer's stderr from the traceback; surface it for diagnosis.
            stderr = (e.stderr or b'').decode(errors='replace')
            raise RuntimeError(f'Rendering failed (exit code {e.returncode}):\n{stderr}') from e
        return img_filename, output['lego']

    demo = gr.Interface(
        fn=generate_lego,
        title='LegoGPT Demo',
        description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.',
        inputs=[
            gr.Textbox(placeholder='Enter a prompt to generate a LEGO model.'),
        ],
        # Order must match generate_lego's parameters after `prompt`.
        additional_inputs=[
            gr.Slider(0.0, 2.0, value=model_cfg.temperature, info=get_help_string('temperature'), step=0.01),
            gr.Number(value=default_seed, info='Random seed for generation.', precision=0, step=1),
            gr.Number(value=model_cfg.max_bricks, info=get_help_string('max_bricks'),
                      precision=0, minimum=1, step=1),
            gr.Number(value=model_cfg.max_brick_rejections, info=get_help_string('max_brick_rejections'),
                      precision=0, minimum=0, step=1),
            gr.Number(value=model_cfg.max_regenerations, info=get_help_string('max_regenerations'),
                      precision=0, minimum=0, step=1),
        ],
        outputs=[
            gr.Image(label='output_img'),
            gr.Textbox(label='output_txt', lines=5, max_lines=5, show_copy_button=True,
                       info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
                            '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).'),
        ],
        # Examples fill only prompt, temperature and seed; the remaining inputs keep their values.
        examples=[[prompt, model_cfg.temperature, default_seed] for prompt in get_example_prompts()],
        # cache_examples=True,
        flagging_mode='never',
    )
    demo.launch(share=True)
def get_help_string(field_name: str) -> str:
"""
:param field_name: Name of a field in LegoGPTConfig.
:return: Help string for the field.
"""
data_fields = fields(LegoGPTConfig)
name_field = next(f for f in data_fields if f.name == field_name)
return name_field.metadata['help']
def get_example_prompts(example_prompts_file: str = 'example_prompts.txt') -> list[str]:
    """
    Read the demo's example prompts, one per line.

    :param example_prompts_file: Path of the text file to read (relative to the working directory).
    :return: The stripped, non-empty lines of the file, in file order.
    """
    with open(example_prompts_file, encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing empty line) so they don't become empty example prompts.
        return [prompt for line in f if (prompt := line.strip())]
if __name__ == '__main__':
    # Launch the demo only when executed as a script, not when this module is imported.
    main()