File size: 3,986 Bytes
b1311b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import base64
import os
from datetime import datetime
import traceback
import trimesh
import torch
from craftsman import CraftsManPipeline

# Per-process scratch root: the PID suffix gives every server process its own
# directory under /tmp/native3d_server. It is created eagerly as a module-level
# side effect, i.e. as soon as this file is executed or imported.
CURRENT_DIR = f'/tmp/native3d_server/{os.getpid()}'
os.makedirs(CURRENT_DIR, exist_ok=True)

def parse_parameters(argv=None) -> argparse.Namespace:
    """Parse the server's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script entry point relies on. Passing an explicit list
            lets the function be called programmatically.

    Returns:
        A namespace with ``host`` (str, default ``"0.0.0.0"``) and
        ``port`` (int, default ``12345``).
    """
    parser = argparse.ArgumentParser("native3d")
    parser.add_argument(
        '--host', default="0.0.0.0", type=str,
        help="interface address the HTTP server binds to",
    )
    parser.add_argument(
        '--port', default=12345, type=int,
        help="TCP port the HTTP server listens on",
    )
    return parser.parse_args(argv)
    
# -------------------- fastapi --------------------
from typing import Optional
from pydantic import BaseModel, Field

# Request body of POST /native3d_v1. Input and output are exchanged through
# filesystem paths that the server process itself reads and writes.
class Native3DRequestV1(BaseModel):
    image_path: str  # input image path; passed to the pipeline unchanged
    mesh_path: str  # output mesh path, support glb or obj in clean dir; missing parent directories are created by the handler

# Response body of POST /native3d_v1. It carries no fields: the handler exports
# the mesh to the request's `mesh_path` instead of returning it.
class Native3DResponseV1(BaseModel):
    pass

# Request body of POST /native3d_v2. The image travels inline in the request
# rather than as a path on the server's filesystem.
class Native3DRequestV2(BaseModel):
    image_bytes: str  # input image bytes(base64); decoded and written to a scratch file by the handler
    mesh_type: str  # output mesh type, support glb or obj; used as the output file extension

# Response body of POST /native3d_v2.
class Native3DResponseV2(BaseModel):
    mesh_bytes: str  # output mesh bytes(base64); the exported mesh file, base64-encoded and decoded as UTF-8 text

if __name__ == "__main__":
    # Command-line flags are parsed before the model is loaded, so an invalid
    # invocation exits without paying the model-loading cost.
    parse_args = parse_parameters()

    # prepare models
    # NOTE(review): the checkpoint location and CUDA device are hard-coded to
    # one specific machine; confirm before deploying elsewhere.
    pipeline = CraftsManPipeline.from_pretrained(
        "/home/super/Desktop/8TDisk/weiyu/CraftsMan_gradio/ckpts/craftsman-v1-5",
        device="cuda:0",
        torch_dtype=torch.float32,
    )

    # -------------------- fastapi --------------------
    # The web framework is imported only when the file runs as a script.
    from fastapi import FastAPI, Request
    import requests  # NOTE(review): not referenced anywhere in the visible code
    app = FastAPI()

    # POST /native3d_v1: run the pipeline on the image at `image_path` and export
    # the first resulting mesh to `mesh_path` (parent directories are created).
    #
    # Returns an empty Native3DResponseV1 on success. On any failure the error is
    # logged and reported as HTTP 500, so a client can no longer mistake a failed
    # generation for a successful one.
    #
    # NOTE(review): both paths come straight from the request body and are used
    # as-is, so this endpoint can read and write arbitrary locations the server
    # process has access to. Only expose it to trusted callers.
    # NOTE(review): the pipeline call is synchronous inside an `async def`, so the
    # event loop (including /health) is blocked while a mesh is generated.
    @app.post("/native3d_v1", response_model=Native3DResponseV1)
    async def native3d(request: Request, image_to_mesh_request: Native3DRequestV1):
        # Local import keeps this handler self-contained; the surrounding script
        # only imports FastAPI and Request.
        from fastapi import HTTPException

        try:
            print(f"image_to_mesh_request = {image_to_mesh_request}")
            mesh = pipeline(image_to_mesh_request.image_path).meshes[0]
            os.makedirs(os.path.dirname(os.path.abspath(image_to_mesh_request.mesh_path)), exist_ok=True)
            mesh.export(image_to_mesh_request.mesh_path)
        except Exception as e:
            traceback.print_exc()
            print(f"generate_model error: {e}")
            # Surface the failure instead of answering 200 with an empty body.
            raise HTTPException(status_code=500, detail=f"generate_model error: {e}") from e
        return Native3DResponseV1()

    # POST /native3d_v2: decode the base64 image from the request, run the
    # pipeline on it, and return the exported mesh as base64 text.
    #
    # Responses:
    #   200 - Native3DResponseV2 with the base64-encoded mesh file.
    #   400 - `mesh_type` is not 'obj'/'glb', or `image_bytes` is not valid base64.
    #   500 - generation or export failed (the error is also logged).
    #
    # NOTE(review): the pipeline call is synchronous inside an `async def`, so the
    # event loop (including /health) is blocked while a mesh is generated.
    @app.post("/native3d_v2", response_model=Native3DResponseV2)
    async def native3d(request: Request, image_to_mesh_request: Native3DRequestV2):
        # Local imports keep this handler self-contained; the surrounding script
        # does not import these names at file level.
        import shutil
        import time

        from fastapi import HTTPException

        # Validate input before doing any work. An explicit check is used rather
        # than `assert`, which is stripped when Python runs with -O. mesh_type
        # also becomes a file extension below, so it must stay on this allow-list.
        mesh_type = image_to_mesh_request.mesh_type
        if mesh_type not in ('obj', 'glb'):
            raise HTTPException(status_code=400, detail=f"unsupported mesh_type: {mesh_type!r}")
        try:
            image_data = base64.b64decode(image_to_mesh_request.image_bytes)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            raise HTTPException(status_code=400, detail=f"image_bytes is not valid base64: {e}") from e

        # Each request works in its own timestamped directory under the
        # per-process scratch root.
        task_id = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f') + '-' + 'native3d'
        current_dir = os.path.join(CURRENT_DIR, task_id)
        try:
            os.makedirs(current_dir, exist_ok=True)
            image_path = os.path.join(current_dir, 'input_image.png')
            with open(image_path, 'wb') as f:
                f.write(image_data)
            mesh_path = os.path.join(current_dir, f'output_mesh.{mesh_type}')

            start = time.time()
            mesh = pipeline(image_path).meshes[0]
            print(f"Time: {time.time() - start}s")

            # Replace the mesh's visual with a single PBR material before export.
            mesh.visual = trimesh.visual.TextureVisuals(
                material=trimesh.visual.material.PBRMaterial(
                    baseColorFactor=(255, 255, 255), main_color=(255, 255, 255), metallicFactor=0.05, roughnessFactor=1.0
                )
            )
            mesh.export(mesh_path)
            with open(mesh_path, 'rb') as f:
                mesh_bytes = f.read()
        except Exception as e:
            traceback.print_exc()
            print(f"generate_model error: {e}")
            # Previously execution fell through to the return with `mesh_bytes`
            # unbound, masking the real error behind an UnboundLocalError.
            raise HTTPException(status_code=500, detail=f"generate_model error: {e}") from e
        finally:
            # The scratch files are only needed for the duration of the request;
            # remove them so a long-running server does not fill the disk.
            shutil.rmtree(current_dir, ignore_errors=True)
        return Native3DResponseV2(mesh_bytes=base64.b64encode(mesh_bytes).decode('utf-8'))

    # GET /health: answers with a fixed JSON payload and performs no other
    # work, so a caller can use it to check that the server responds.
    @app.get("/health")
    async def health():
        payload = {"status": "OK"}
        return payload

    # Start serving the app on the host/port taken from the command line.
    # uvicorn is imported here, so it is only required when the file is run as a script.
    import uvicorn
    uvicorn.run(app, host=parse_args.host, port=parse_args.port)