Spaces:
Running
on
A100
Running
on
A100
File size: 3,769 Bytes
6d7e55e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
import subprocess
from dataclasses import fields
import gradio as gr
import transformers
from legogpt.models import LegoGPT, LegoGPTConfig
def main():
    """
    Build and launch the Gradio demo for LegoGPT.

    Creates a LegoGPT model from a default ``LegoGPTConfig``, exposes text-to-LEGO generation
    (plus rendering to an image) through a ``gr.Interface``, and launches it with ``share=True``.
    """
    model_cfg = LegoGPTConfig()
    # Build the model from the same config object that provides the UI defaults below,
    # so the model and the widgets cannot drift apart.
    model = LegoGPT(model_cfg)
    default_seed = 42

    def generate_lego(
            prompt: str,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ):
        """
        Generate a LEGO structure for ``prompt`` and render it to an image.

        :param prompt: Text prompt passed to the model.
        :param temperature: Value assigned to ``model.temperature``; None selects the config default.
        :param seed: Seed passed to ``transformers.set_seed``; None leaves the RNG state untouched.
        :param max_bricks: Value assigned to ``model.max_bricks``; None selects the config default.
        :param max_brick_rejections: Value assigned to ``model.max_brick_rejections``;
            None selects the config default.
        :param max_regenerations: Value assigned to ``model.max_regenerations``;
            None selects the config default.
        :return: Tuple of (absolute path of the rendered image, the ``'lego'`` entry of the model output).
        :raises gr.Error: If the render subprocess exits with a non-zero status.
        """
        # Set model parameters. The model object is shared by every request, so an input
        # left as None is reset to the config default instead of silently inheriting
        # whatever value a previous request happened to set.
        model.temperature = model_cfg.temperature if temperature is None else temperature
        model.max_bricks = model_cfg.max_bricks if max_bricks is None else max_bricks
        model.max_brick_rejections = (
            model_cfg.max_brick_rejections if max_brick_rejections is None else max_brick_rejections
        )
        model.max_regenerations = (
            model_cfg.max_regenerations if max_regenerations is None else max_regenerations
        )
        if seed is not None:
            transformers.set_seed(seed)

        # Generate LEGO
        output = model(prompt)

        # Render results and write to files
        render_lego_filename = 'render_lego.py'
        ldr_filename = os.path.abspath('output.ldr')
        img_filename = os.path.abspath('output.png')
        with open(ldr_filename, 'w') as f:
            f.write(output['lego'].to_ldr())

        # Run render as a subprocess to prevent issues with Blender
        try:
            subprocess.run(
                ['uv', 'run', render_lego_filename, '--in_file', ldr_filename, '--out_file', img_filename],
                capture_output=True, check=True,
            )
        except subprocess.CalledProcessError as e:
            # The output is captured, so without this the renderer's diagnostics would be lost.
            # Keep them in the server log and show the user a readable message.
            if e.stderr:
                print(e.stderr.decode(errors='replace'), flush=True)
            raise gr.Error(
                f'Rendering the LEGO model failed (exit code {e.returncode}). See the server log for details.'
            ) from e

        return img_filename, output['lego']

    demo = gr.Interface(
        fn=generate_lego,
        title='LegoGPT Demo',
        description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.',
        inputs=[
            gr.Textbox(placeholder='Enter a prompt to generate a LEGO model.'),
        ],
        additional_inputs=[
            gr.Slider(0.0, 2.0, value=model_cfg.temperature, info=get_help_string('temperature'), step=0.01),
            # Use the same seed constant as the examples below instead of a second literal.
            gr.Number(value=default_seed, info='Random seed for generation.', precision=0, step=1),
            gr.Number(value=model_cfg.max_bricks, info=get_help_string('max_bricks'),
                      precision=0, minimum=1, step=1),
            gr.Number(value=model_cfg.max_brick_rejections, info=get_help_string('max_brick_rejections'),
                      precision=0, minimum=0, step=1),
            gr.Number(value=model_cfg.max_regenerations, info=get_help_string('max_regenerations'),
                      precision=0, minimum=0, step=1),
        ],
        outputs=[
            gr.Image(label='output_img'),
            gr.Textbox(label='output_txt', lines=5, max_lines=5, show_copy_button=True,
                       info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
                            '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).'),
        ],
        # Each example row fills the first three inputs: prompt, temperature, seed.
        examples=[[prompt, model_cfg.temperature, default_seed] for prompt in get_example_prompts()],
        # cache_examples=True,
        flagging_mode='never',
    )

    demo.launch(share=True)
def get_help_string(field_name: str) -> str:
    """
    Look up the help text attached to a LegoGPTConfig dataclass field.

    :param field_name: Name of a field in LegoGPTConfig.
    :return: Help string for the field, read from the field's ``metadata['help']``.
    :raises ValueError: If LegoGPTConfig has no field named ``field_name``.
    :raises KeyError: If the field's metadata has no ``'help'`` entry.
    """
    for data_field in fields(LegoGPTConfig):
        if data_field.name == field_name:
            return data_field.metadata['help']
    # An explicit error is clearer than the StopIteration a bare next() would leak.
    raise ValueError(f'LegoGPTConfig has no field named {field_name!r}')
def get_example_prompts() -> list[str]:
    """
    Load the example prompts offered in the demo.

    Reads ``example_prompts.txt`` (resolved against the current working directory),
    treating each line as one prompt.

    :return: Prompts with surrounding whitespace stripped; blank lines are skipped so
        they do not turn into empty example entries.
    """
    example_prompts_file = 'example_prompts.txt'
    # Explicit encoding so prompts are read the same way regardless of the platform locale.
    with open(example_prompts_file, encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        return [prompt for prompt in stripped_lines if prompt]
# Launch the demo only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()
|