# Hugging Face Space demo app (Space status when captured: running on A100).
import json
import os
import subprocess
import sys
import time
import uuid
import zipfile
from dataclasses import fields
from urllib.request import urlretrieve

import gradio as gr
import transformers

from legogpt.models import LegoGPT, LegoGPTConfig
def setup(): | |
# Set up Gurobi licence | |
licence_filename = 'gurobi.lic' | |
licence_lines = [] | |
for secret_name in ['WLSACCESSID', 'WLSSECRET', 'LICENSEID']: | |
secret = os.environ.get(secret_name) | |
if not secret: | |
raise ValueError(f'Env variable {secret_name} not found. Please set it in the Hugging Face Space settings.') | |
licence_lines.append(f'{secret_name}={secret}\n') | |
with open(licence_filename, 'w') as f: | |
f.writelines(licence_lines) | |
os.environ['GRB_LICENSE_FILE'] = os.path.abspath(licence_filename) | |
# Download LDraw part library and set LDraw library path | |
ldraw_zip_url = 'https://library.ldraw.org/library/updates/complete.zip' | |
ldraw_zip_filename = 'complete.zip' | |
urlretrieve(ldraw_zip_url, ldraw_zip_filename) | |
with zipfile.ZipFile(ldraw_zip_filename) as zip_ref: | |
zip_ref.extractall() | |
os.environ['LDRAW_LIBRARY_PATH'] = os.path.abspath('ldraw') | |
def main():
    """
    Build and launch the LegoGPT Gradio demo.

    When the ``IS_HF_SPACE`` environment variable equals ``'1'``, ``setup()`` is run
    first. A single ``LegoGPT`` model is then created (with ``max_regenerations=10``)
    and shared by every request, the Gradio interface and its example gallery are
    assembled, and the app is launched with ``share=True``.
    """
    if os.environ.get('IS_HF_SPACE') == '1':
        print('Running in Hugging Face Space, setting up environment...')
        setup()

    model_cfg = LegoGPTConfig(max_regenerations=10)
    model = LegoGPT(model_cfg)

    def generate_lego(
            prompt: str,
            temperature: float | None,
            seed: int | None,
            max_bricks: int | None,
            max_brick_rejections: int | None,
            max_regenerations: int | None,
    ):
        """
        Generate a LEGO structure for ``prompt`` and render it to an image.

        :param prompt: Text prompt passed to the model.
        :param temperature: Sampling temperature; ``None`` selects the configured default.
        :param seed: Seed passed to ``transformers.set_seed``; ``None`` leaves seeding untouched.
        :param max_bricks: Model setting; ``None`` selects the configured default.
        :param max_brick_rejections: Model setting; ``None`` selects the configured default.
        :param max_regenerations: Model setting; ``None`` selects the configured default.
        :return: Tuple of (path of the rendered PNG, the model's ``output['lego']`` object).
        :raises subprocess.CalledProcessError: If the render subprocess exits with a non-zero status.
        """
        # Set model parameters. The model object is shared between requests, so a value
        # left as None must fall back to the configured default; otherwise the setting
        # chosen by the previous request would silently carry over.
        # NOTE(review): assumes LegoGPT initialises these attributes from model_cfg - confirm.
        requested = {
            'temperature': temperature,
            'max_bricks': max_bricks,
            'max_brick_rejections': max_brick_rejections,
            'max_regenerations': max_regenerations,
        }
        for name, value in requested.items():
            setattr(model, name, getattr(model_cfg, name) if value is None else value)
        if seed is not None:
            transformers.set_seed(seed)

        # Generate LEGO
        print(f'Generating LEGO for prompt: "{prompt}"')
        start_time = time.time()
        output = model(prompt)

        # Write output LDR to file; a fresh UUID keeps requests from overwriting each other
        output_dir = os.path.abspath('out')
        output_uuid = str(uuid.uuid4())
        os.makedirs(output_dir, exist_ok=True)
        ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr')
        with open(ldr_filename, 'w') as f:
            f.write(output['lego'].to_ldr())
        print(f'Finished generation in {time.time() - start_time:.1f}s!')

        # Render LEGO model to image
        print('Rendering image...')
        start_time = time.time()
        img_filename = os.path.join(output_dir, f'{output_uuid}.png')
        # Run render as a subprocess to prevent issues with Blender. sys.executable
        # guarantees the renderer runs under the same interpreter (and installed
        # packages) as this app instead of whichever "python" is first on PATH.
        subprocess.run([sys.executable, 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename],
                       check=True)
        print(f'Finished rendering in {time.time() - start_time:.1f}s!')

        return img_filename, output['lego']

    # Define inputs and outputs
    in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.')
    in_temperature = gr.Slider(0.0, 2.0, value=model_cfg.temperature, step=0.01,
                               label='Temperature', info=get_help_string('temperature'))
    in_seed = gr.Number(value=42, label='Seed', info='Random seed for generation.', precision=0, step=1)
    in_bricks = gr.Number(value=model_cfg.max_bricks, label='Max bricks', info=get_help_string('max_bricks'),
                          precision=0, minimum=1, step=1)
    in_rejections = gr.Number(value=model_cfg.max_brick_rejections, label='Max brick rejections',
                              info=get_help_string('max_brick_rejections'), precision=0, minimum=0, step=1)
    in_regenerations = gr.Number(value=model_cfg.max_regenerations, label='Max regenerations',
                                 info=get_help_string('max_regenerations'), precision=0, minimum=0, step=1)
    out_img = gr.Image(label='Output image', format='png')
    out_txt = gr.Textbox(label='Output LEGO bricks', lines=5, max_lines=5, show_copy_button=True,
                         info='The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
                              '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).')

    # Define Gradio interface
    demo = gr.Interface(
        fn=generate_lego,
        title='LegoGPT Demo',
        description='Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
                    'The model is restricted to creating structures on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
                    '*basket, bed, bench, birdhouse, bookshelf, bottle, bowl, bus, camera, car, chair, guitar, jar, mug, piano, pot, sofa, table, tower, train, vessel.* '
                    'Performance on prompts from outside these categories may be limited. This demo does not include texturing or coloring.',
        inputs=[in_prompt],
        additional_inputs=[in_temperature, in_seed, in_bricks, in_rejections, in_regenerations],
        outputs=[out_img, out_txt],
        flagging_mode='never',
    )

    with demo:
        with gr.Row():
            examples = get_examples()
            # Hidden components: they only carry the example's name and its stored
            # result image through gr.Examples to the lambda below.
            dummy_name = gr.Textbox(visible=False, label='Name')
            dummy_out_img = gr.Image(visible=False, label='Result')
            gr.Examples(
                examples=[[name, example['prompt'], example['temperature'], example['seed'], example['output_img']]
                          for name, example in examples.items()],
                inputs=[dummy_name, in_prompt, in_temperature, in_seed, dummy_out_img],
                outputs=[out_img, out_txt],
                # Clicking an example shows its stored image and text instead of running the model.
                fn=lambda *args: (args[-1], examples[args[0]]['output_txt']),
                run_on_click=True,
            )

    demo.launch(share=True)
def get_help_string(field_name: str) -> str:
    """
    Look up the help text attached to a LegoGPTConfig dataclass field.

    :param field_name: Name of a field in LegoGPTConfig.
    :return: Help string for the field, read from the field's ``metadata['help']``.
    :raises ValueError: If LegoGPTConfig has no field called ``field_name``.
    :raises KeyError: If the field exists but has no ``'help'`` metadata entry.
    """
    name_field = next((f for f in fields(LegoGPTConfig) if f.name == field_name), None)
    if name_field is None:
        # Without a default, next() would leak a bare StopIteration that says
        # nothing about which field name was wrong.
        raise ValueError(f'LegoGPTConfig has no field named {field_name!r}.')
    return name_field.metadata['help']
def get_examples(example_dir: str = os.path.abspath('examples')) -> dict[str, dict[str, str]]:
    """
    Load the example definitions shipped with the demo.

    :param example_dir: Directory containing ``examples.json`` and the example images.
        The default is resolved against the working directory once, when this
        function is defined.
    :return: The parsed contents of ``examples.json`` (a mapping of example name to
        example entry), with each entry's ``'output_img'`` rewritten to a path inside
        ``example_dir``. All other keys are returned exactly as stored in the file.
    """
    examples_file = os.path.join(example_dir, 'examples.json')
    # Read as UTF-8 explicitly instead of relying on the platform's locale encoding,
    # so non-ASCII prompts load identically everywhere.
    with open(examples_file, encoding='utf-8') as f:
        examples = json.load(f)
    for example in examples.values():
        example['output_img'] = os.path.join(example_dir, example['output_img'])
    return examples
# Script entry point: build and launch the demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()