File size: 3,769 Bytes
6d7e55e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import subprocess
from dataclasses import fields

import gradio as gr
import transformers
from legogpt.models import LegoGPT, LegoGPTConfig


def main():
    """
    Build and launch the LegoGPT Gradio demo.

    A single LegoGPT model is created once and shared by every request. Its
    defaults come from LegoGPTConfig; per-request generation settings come from
    the UI's additional inputs. The app is launched with ``share=True``.
    """
    model_cfg = LegoGPTConfig()
    # Build the model from the same config whose values populate the UI defaults,
    # so what the UI displays matches what the model actually starts with.
    model = LegoGPT(model_cfg)
    default_seed = 42

    def generate_lego(
            prompt: str,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ):
        """
        Generate a LEGO model for ``prompt`` and render it to an image.

        A setting left empty (None) in the UI keeps the model's current value.
        NOTE: settings are assigned onto the shared model instance, so they
        persist into later requests that leave the field empty.

        Output files ``output.ldr`` / ``output.png`` are written to the current
        working directory and overwritten on every call.

        :return: (absolute path of the rendered PNG, the generated LEGO structure)
        :raises gr.Error: If the render subprocess exits with a non-zero status.
        """
        # Set model parameters
        if temperature is not None: model.temperature = temperature
        if max_bricks is not None: model.max_bricks = max_bricks
        if max_brick_rejections is not None: model.max_brick_rejections = max_brick_rejections
        if max_regenerations is not None: model.max_regenerations = max_regenerations
        if seed is not None: transformers.set_seed(seed)

        # Generate LEGO
        output = model(prompt)

        # Render results and write to files
        render_lego_filename = 'render_lego.py'
        ldr_filename = os.path.abspath('output.ldr')
        img_filename = os.path.abspath('output.png')
        with open(ldr_filename, 'w') as f:
            f.write(output['lego'].to_ldr())

        # Run render as a subprocess to prevent issues with Blender
        try:
            subprocess.run(
                ['uv', 'run', render_lego_filename, '--in_file', ldr_filename, '--out_file', img_filename],
                capture_output=True, check=True,
            )
        except subprocess.CalledProcessError as err:
            # capture_output hides the renderer's diagnostics; show the tail of its
            # stderr in the UI instead of an opaque CalledProcessError.
            stderr_tail = (err.stderr or b'').decode(errors='replace').strip()[-1000:]
            raise gr.Error(f'Rendering failed (exit code {err.returncode}): {stderr_tail}') from err

        return img_filename, output['lego']

    demo = gr.Interface(
        fn=generate_lego,
        title='LegoGPT Demo',
        description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.',
        inputs=[
            gr.Textbox(placeholder='Enter a prompt to generate a LEGO model.'),
        ],
        # Order must match generate_lego's parameters after `prompt`.
        additional_inputs=[
            gr.Slider(0.0, 2.0, value=model_cfg.temperature, info=get_help_string('temperature'), step=0.01),
            gr.Number(value=default_seed, info='Random seed for generation.', precision=0, step=1),
            gr.Number(value=model_cfg.max_bricks, info=get_help_string('max_bricks'),
                      precision=0, minimum=1, step=1),
            gr.Number(value=model_cfg.max_brick_rejections, info=get_help_string('max_brick_rejections'),
                      precision=0, minimum=0, step=1),
            gr.Number(value=model_cfg.max_regenerations, info=get_help_string('max_regenerations'),
                      precision=0, minimum=0, step=1),
        ],
        outputs=[
            gr.Image(label='output_img'),
            gr.Textbox(label='output_txt', lines=5, max_lines=5, show_copy_button=True,
                       info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
                            '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).'),
        ],
        # Each example supplies only prompt, temperature and seed; the remaining
        # additional inputs are not specified by the examples.
        examples=[[prompt, model_cfg.temperature, default_seed] for prompt in get_example_prompts()],
        # cache_examples=True,
        flagging_mode='never',
    )
    demo.launch(share=True)


def get_help_string(field_name: str) -> str:
    """
    :param field_name: Name of a field in LegoGPTConfig.
    :return: Help string for the field.
    """
    data_fields = fields(LegoGPTConfig)
    name_field = next(f for f in data_fields if f.name == field_name)
    return name_field.metadata['help']


def get_example_prompts(example_prompts_file: str = 'example_prompts.txt') -> list[str]:
    """
    Read the demo's example prompts, one prompt per line.

    :param example_prompts_file: Path to a UTF-8 text file with one prompt per line.
        Relative paths are resolved against the current working directory.
    :return: The stripped, non-empty lines of the file, in file order.
    """
    with open(example_prompts_file, encoding='utf-8') as f:
        # Skip blank lines (e.g. an extra empty line at the end of the file) so
        # they don't turn into empty example prompts.
        return [stripped for line in f if (stripped := line.strip())]


# Launch the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__': main()