Spaces: Running on A100
import os | |
import subprocess | |
from dataclasses import fields | |
import gradio as gr | |
import transformers | |
from legogpt.models import LegoGPT, LegoGPTConfig | |
def main():
    """
    Build the LegoGPT Gradio interface and launch it.

    A single LegoGPT model is constructed here and shared by every request the
    interface serves (via the ``generate_lego`` closure).
    """
    # Untouched defaults: they seed the UI widgets and are the fallback for any
    # input the user leaves empty. The model gets its own config instance so that
    # per-request overrides applied to the model cannot alter these defaults.
    model_cfg = LegoGPTConfig()
    model = LegoGPT(LegoGPTConfig())
    default_seed = 42

    def generate_lego(
            prompt: str,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ):
        """
        Generate a LEGO structure from a text prompt and render it to an image.

        :param prompt: Text prompt passed to the model.
        :param temperature: Value for ``model.temperature``; None uses the config default.
        :param seed: Seed passed to ``transformers.set_seed``; None leaves the random state untouched.
        :param max_bricks: Value for ``model.max_bricks``; None uses the config default.
        :param max_brick_rejections: Value for ``model.max_brick_rejections``; None uses the config default.
        :param max_regenerations: Value for ``model.max_regenerations``; None uses the config default.
        :return: Absolute path of the rendered PNG, and the structure's string form.
        :raises gradio.Error: If the rendering subprocess exits with a non-zero status.
        """
        # `model` is shared across requests, so set every parameter on every call.
        # An empty input falls back to the config default rather than silently
        # reusing whatever value the previous request left on the model.
        model.temperature = model_cfg.temperature if temperature is None else temperature
        model.max_bricks = model_cfg.max_bricks if max_bricks is None else max_bricks
        model.max_brick_rejections = (
            model_cfg.max_brick_rejections if max_brick_rejections is None else max_brick_rejections
        )
        model.max_regenerations = (
            model_cfg.max_regenerations if max_regenerations is None else max_regenerations
        )
        if seed is not None:
            transformers.set_seed(seed)

        # Generate LEGO
        output = model(prompt)
        lego = output['lego']

        # Write the structure as an .ldr file, then render it to an image
        render_lego_filename = 'render_lego.py'
        ldr_filename = os.path.abspath('output.ldr')
        img_filename = os.path.abspath('output.png')
        with open(ldr_filename, 'w') as f:
            f.write(lego.to_ldr())

        # Run render as a subprocess to prevent issues with Blender
        try:
            subprocess.run(
                ['uv', 'run', render_lego_filename, '--in_file', ldr_filename, '--out_file', img_filename],
                capture_output=True, check=True,
            )
        except subprocess.CalledProcessError as e:
            # Output is captured, so without this the renderer's error text would be
            # lost and the UI would only show a generic failure.
            stderr = (e.stderr or b'').decode(errors='replace').strip()
            raise gr.Error(f'Rendering failed (exit code {e.returncode}): {stderr[-1000:]}') from e

        # The output textbox displays the structure's string form
        return img_filename, str(lego)

    demo = gr.Interface(
        fn=generate_lego,
        title='LegoGPT Demo',
        description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.',
        inputs=[
            gr.Textbox(placeholder='Enter a prompt to generate a LEGO model.'),
        ],
        additional_inputs=[
            gr.Slider(0.0, 2.0, value=model_cfg.temperature, info=get_help_string('temperature'), step=0.01),
            gr.Number(value=default_seed, info='Random seed for generation.', precision=0, step=1),
            gr.Number(value=model_cfg.max_bricks, info=get_help_string('max_bricks'),
                      precision=0, minimum=1, step=1),
            gr.Number(value=model_cfg.max_brick_rejections, info=get_help_string('max_brick_rejections'),
                      precision=0, minimum=0, step=1),
            gr.Number(value=model_cfg.max_regenerations, info=get_help_string('max_regenerations'),
                      precision=0, minimum=0, step=1),
        ],
        outputs=[
            gr.Image(label='output_img'),
            gr.Textbox(label='output_txt', lines=5, max_lines=5, show_copy_button=True,
                       info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
                            '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).'),
        ],
        # Examples fill only the prompt, temperature and seed; other inputs keep their current values.
        examples=[[prompt, model_cfg.temperature, default_seed] for prompt in get_example_prompts()],
        # cache_examples=True,
        flagging_mode='never',
    )
    demo.launch(share=True)
def get_help_string(field_name: str) -> str: | |
""" | |
:param field_name: Name of a field in LegoGPTConfig. | |
:return: Help string for the field. | |
""" | |
data_fields = fields(LegoGPTConfig) | |
name_field = next(f for f in data_fields if f.name == field_name) | |
return name_field.metadata['help'] | |
def get_example_prompts(example_prompts_file: str = 'example_prompts.txt') -> list[str]:
    """
    Read the example prompts offered in the demo.

    :param example_prompts_file: Path to a UTF-8 text file with one prompt per line.
    :return: The prompts with surrounding whitespace stripped, in file order; blank lines are skipped.
    """
    with open(example_prompts_file, encoding='utf-8') as f:
        # Blank lines (between prompts or at the end of the file) would otherwise
        # become empty example rows in the UI.
        return [prompt for line in f if (prompt := line.strip())]
# Script entry point: build and launch the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()