#!/usr/bin/env python
"""Gradio demo for TEXTure: text-guided texturing of a user-supplied 3D mesh."""

from __future__ import annotations

import datetime
import os
import pathlib
import shlex
import subprocess
import sys
from typing import Generator, Optional

import trimesh
import spaces
import gradio as gr

# TEXTure's sources live in the ``TEXTurePaper`` checkout; make its ``src``
# package importable before importing from it.
sys.path.append('TEXTurePaper')
from src.configs.train_config import GuideConfig, LogConfig, TrainConfig  # noqa: E402
from src.training.trainer import TEXTure  # noqa: E402

# Time budget (seconds) requested from ZeroGPU for one run.
# TODO(review): tune this; the UI text quotes ~10 minutes per run on a T4.
GPU_DURATION_SECONDS = 600

# Yielded to the UI: (viewpoint image paths, GLB path, zip path, progress text).
RunOutput = tuple[list[str], Optional[str], Optional[str], str]


class Model:
    """Runs TEXTure on an uploaded ``.obj`` mesh and packages the results."""

    def __init__(self):
        # Maximum number of ``f`` (face) records accepted in an input ``.obj``.
        self.max_num_faces = 100000

    def load_config(self, shape_path: str, text: str, seed: int,
                    guidance_scale: float) -> TrainConfig:
        """Build the TEXTure ``TrainConfig`` for one run.

        Args:
            shape_path: Path of the user's ``.obj`` mesh.
            text: Prompt. ``', {} view'`` is appended to it; the ``{}``
                placeholder is presumably filled in per viewpoint by TEXTure
                (TODO confirm).
            seed: Optimization seed.
            guidance_scale: Guidance scale forwarded to the guide config.
        """
        text += ', {} view'
        log = LogConfig(exp_name=self.gen_exp_name())
        guide = GuideConfig(text=text)
        guide.background_img = 'TEXTurePaper/textures/brick_wall.png'
        # Placeholder mesh while the config is constructed. It is overwritten
        # with the user's mesh right below.
        # NOTE(review): kept in case TrainConfig construction needs a valid
        # default path; confirm and drop if not.
        guide.shape_path = 'TEXTurePaper/shapes/spot_triangulated.obj'
        config = TrainConfig(log=log, guide=guide)
        config.guide.shape_path = shape_path
        config.optim.seed = seed
        config.guide.guidance_scale = guidance_scale
        return config

    def gen_exp_name(self) -> str:
        """Return a local timestamp (second resolution) used as the experiment name."""
        return datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

    def check_num_faces(self, path: str) -> bool:
        """Return True if the ``.obj`` at *path* has at most ``max_num_faces`` faces.

        Only ``f`` records count, i.e. lines whose first token is exactly
        ``f``. The file is streamed in binary mode for two reasons: large
        meshes are never loaded whole, and non-UTF-8 bytes (for example in
        comments) cannot raise ``UnicodeDecodeError``.
        """
        num_faces = 0
        with open(path, 'rb') as f:
            for line in f:
                stripped = line.lstrip()
                # Require whitespace after 'f' so records such as 'fo ...'
                # are not miscounted as faces.
                if stripped[:1] == b'f' and stripped[1:2].isspace():
                    num_faces += 1
                    if num_faces > self.max_num_faces:
                        return False  # Over the limit; skip the rest of the file.
        return True

    def zip_results(self, exp_dir: pathlib.Path) -> str:
        """Zip ``exp_dir / 'mesh'`` into ``<exp_dir.name>.zip`` in the CWD.

        Returns:
            The relative path of the created archive.

        Raises:
            subprocess.CalledProcessError: If ``zip`` exits with a non-zero status.
        """
        mesh_dir = exp_dir / 'mesh'
        out_path = f'{exp_dir.name}.zip'
        # Pass an argument list rather than splitting an f-string, so paths
        # containing spaces or shell metacharacters stay intact.
        subprocess.run(['zip', '-r', out_path, str(mesh_dir)], check=True)
        return out_path

    def run(
        self, shape_path: str, text: str, seed: int, guidance_scale: float
    ) -> Generator[RunOutput, None, None]:
        """Texture the mesh at *shape_path*, streaming progress to the UI.

        Yields:
            ``(viewpoint_image_paths, glb_path, zip_path, progress_text)``.
            One tuple is yielded after each painted viewpoint, with ``None``
            for both file outputs. A final tuple is yielded once the textured
            mesh has been exported and zipped.

        Raises:
            gr.Error: If no file was uploaded, the file is not ``.obj``, or it
                has more than ``max_num_faces`` faces.
        """
        if not shape_path:
            raise gr.Error('Please upload an .obj file.')
        if pathlib.Path(shape_path).suffix.lower() != '.obj':
            raise gr.Error('The input file is not .obj file.')
        if not self.check_num_faces(shape_path):
            raise gr.Error(f'The number of faces is over {self.max_num_faces:,}.')

        config = self.load_config(shape_path, text, seed, guidance_scale)
        trainer = TEXTure(config)

        trainer.mesh_model.train()

        train_loader = trainer.dataloaders['train']
        total_steps = len(train_loader)
        sample_image_dir = config.log.exp_dir / 'vis' / 'eval'
        # Bound before the loop so the final yield is valid even when the
        # train loader is empty.
        sample_image_paths: list[str] = []
        for step, data in enumerate(train_loader, start=1):
            trainer.paint_step += 1
            trainer.paint_viewpoint(data)
            trainer.evaluate(trainer.dataloaders['val'], trainer.eval_renders_path)
            trainer.mesh_model.train()

            sample_image_paths = [
                path.as_posix() for path in sorted(
                    sample_image_dir.glob(f'step_{trainer.paint_step:05d}_*.jpg'))
            ]
            yield sample_image_paths, None, None, f'{step}/{total_steps}'

        trainer.mesh_model.change_default_to_median()

        save_dir = trainer.exp_path / 'mesh'
        save_dir.mkdir(exist_ok=True, parents=True)
        trainer.mesh_model.export_mesh(save_dir)
        # Convert the exported mesh.obj to GLB for the gr.Model3D result viewer.
        mesh = trimesh.load(save_dir / 'mesh.obj')
        mesh_path = save_dir / 'mesh.glb'
        mesh.export(mesh_path, file_type='glb')

        zip_path = self.zip_results(config.log.exp_dir)
        yield sample_image_paths, mesh_path.as_posix(), zip_path, 'Done!'


@spaces.GPU(duration=GPU_DURATION_SECONDS)
def run_texture(shape_path: str, text: str, seed: int,
                guidance_scale: float) -> Generator[RunOutput, None, None]:
    """GPU-scoped entry point for the Run button; streams ``Model.run`` output.

    ``@spaces.GPU`` wraps this inference function rather than ``main``, which
    only builds and launches the UI.
    """
    yield from Model().run(shape_path, text, seed, guidance_scale)


def main():
    """Build the Gradio UI and launch the queued demo server."""
    DESCRIPTION = '''# [TEXTure](https://github.com/TEXTurePaper/TEXTurePaper)

- This demo only accepts as input `.obj` files with no more than 100,000 faces.
- Inference takes about 10 minutes on a T4 GPU.
'''
    if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
        DESCRIPTION += (
            '\nFor faster inference without waiting in queue, you may '
            f'[duplicate the space](https://huggingface.co/spaces/{SPACE_ID}?duplicate=true) '
            'and upgrade to GPU in settings.\n')

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column():
                input_shape = gr.Model3D(label='Input 3D mesh')
                text = gr.Text(label='Text')
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 value=3,
                                 step=1)
                guidance_scale = gr.Slider(label='Guidance scale',
                                           minimum=0,
                                           maximum=50,
                                           value=7.5,
                                           step=0.1)
                run_button = gr.Button('Run')
            with gr.Column():
                progress_text = gr.Text(label='Progress')
                with gr.Tabs():
                    with gr.TabItem(label='Images from each viewpoint'):
                        viewpoint_images = gr.Gallery(show_label=False, columns=4)
                    with gr.TabItem(label='Result 3D model'):
                        result_3d_model = gr.Model3D(show_label=False)
                    with gr.TabItem(label='Output mesh file'):
                        output_file = gr.File(show_label=False)
        with gr.Row():
            examples = [
                ['shapes/dragon1.obj', 'a photo of a dragon', 0, 7.5],
                ['shapes/dragon2.obj', 'a photo of a dragon', 0, 7.5],
                ['shapes/eagle.obj', 'a photo of an eagle', 0, 7.5],
                ['shapes/napoleon.obj', 'a photo of Napoleon Bonaparte', 3, 7.5],
                ['shapes/nascar.obj', 'A next gen nascar', 2, 10],
            ]
            gr.Examples(examples=examples,
                        inputs=[
                            input_shape,
                            text,
                            seed,
                            guidance_scale,
                        ],
                        outputs=[
                            result_3d_model,
                            output_file,
                        ],
                        cache_examples=False)

        run_button.click(fn=run_texture,
                         inputs=[
                             input_shape,
                             text,
                             seed,
                             guidance_scale,
                         ],
                         outputs=[
                             viewpoint_images,
                             result_3d_model,
                             output_file,
                             progress_text,
                         ])

    demo.queue(max_size=5).launch(debug=True)


if __name__ == '__main__':
    main()