# CraftsMan3D / server.py
# wyysf's picture
# fix bug
# b1311b0
import argparse
import base64
import os
from datetime import datetime
import traceback
import trimesh
import torch
from craftsman import CraftsManPipeline
# Scratch directory for this server process, keyed by PID; it is created at
# import time and the /native3d_v2 handler places one sub-directory per
# request inside it.
CURRENT_DIR = f'/tmp/native3d_server/{os.getpid()}'
os.makedirs(name=CURRENT_DIR, exist_ok=True)
def parse_parameters(argv=None) -> argparse.Namespace:
    """Parse the server's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            original zero-argument call did. Passing an explicit list lets
            the parser be used programmatically (e.g. from tests).

    Returns:
        An ``argparse.Namespace`` with ``host`` (str, default ``"0.0.0.0"``)
        and ``port`` (int, default ``12345``).
    """
    parser = argparse.ArgumentParser("native3d")
    parser.add_argument('--host', default="0.0.0.0", type=str,
                        help="interface address the HTTP server binds to")
    parser.add_argument('--port', default=12345, type=int,
                        help="TCP port the HTTP server listens on")
    return parser.parse_args(argv)
# -------------------- fastapi --------------------
from typing import Optional
from pydantic import BaseModel, Field
# Request body for POST /native3d_v1: the caller exchanges file paths with the
# server instead of sending file contents.
class Native3DRequestV1(BaseModel):
    # Path of the input image; it is handed directly to the pipeline.
    image_path: str # input image path
    # Path the generated mesh is exported to; the handler creates the parent
    # directory if it is missing.
    mesh_path: str # output mesh path, support glb or obj in clean dir
# Response body for POST /native3d_v1. It carries no fields: the result of the
# request is the mesh file written to the requested mesh_path.
class Native3DResponseV1(BaseModel):
    pass
# Request body for POST /native3d_v2: the image is sent inline, so client and
# server do not need to share file paths.
class Native3DRequestV2(BaseModel):
    # Base64-encoded contents of the input image file; the handler decodes it
    # and saves it as input_image.png before running the pipeline.
    image_bytes: str # input image bytes(base64)
    # File extension of the mesh to return; the handler only accepts
    # 'obj' or 'glb'.
    mesh_type: str # output mesh type, support glb or obj
# Response body for POST /native3d_v2.
class Native3DResponseV2(BaseModel):
    # Base64-encoded (UTF-8 text) contents of the exported mesh file.
    mesh_bytes: str # output mesh bytes(base64)
if __name__ == "__main__":
    # Entry point: parse the CLI options, load the CraftsMan pipeline once,
    # then expose it through a small FastAPI app served by uvicorn.
    import time

    import requests  # noqa: F401 -- unused below; retained from the original script
    import uvicorn
    from fastapi import FastAPI, HTTPException, Request

    parse_args = parse_parameters()

    # prepare models
    # NOTE(review): the checkpoint path and CUDA device are hard-coded for one
    # machine; consider exposing them as command-line options.
    pipeline = CraftsManPipeline.from_pretrained(
        "/home/super/Desktop/8TDisk/weiyu/CraftsMan_gradio/ckpts/craftsman-v1-5",
        device="cuda:0",
        torch_dtype=torch.float32,
    )

    # -------------------- fastapi --------------------
    app = FastAPI()

    # NOTE(review): the POST handlers are `async def` but call the pipeline
    # synchronously, so the event loop (including /health) is blocked while a
    # mesh is generated. Presumably this is relied on to serialize GPU use --
    # confirm before switching to plain `def` handlers or a worker thread.

    @app.post("/native3d_v1", response_model=Native3DResponseV1)
    async def native3d_v1(request: Request, image_to_mesh_request: Native3DRequestV1):
        """Generate a mesh from an image path and export it to a mesh path.

        Both paths come straight from the request and are used as-is on the
        server's file system, so this endpoint should only be reachable by
        trusted clients.

        Raises:
            HTTPException: 500 when generation or export fails. Previously
                the error was only printed and an empty 200 response was
                returned, so callers could not tell that no mesh was written.
        """
        print(f"image_to_mesh_request = {image_to_mesh_request}")
        try:
            mesh = pipeline(image_to_mesh_request.image_path).meshes[0]
            out_dir = os.path.dirname(os.path.abspath(image_to_mesh_request.mesh_path))
            os.makedirs(out_dir, exist_ok=True)
            mesh.export(image_to_mesh_request.mesh_path)
        except Exception as e:
            traceback.print_exc()
            print(f"generate_model error: {e}")
            raise HTTPException(status_code=500, detail=f"generate_model error: {e}") from e
        return Native3DResponseV1()

    @app.post("/native3d_v2", response_model=Native3DResponseV2)
    async def native3d_v2(request: Request, image_to_mesh_request: Native3DRequestV2):
        """Generate a mesh from a base64 image and return it base64-encoded.

        The decoded image and the exported mesh are written to a per-request
        directory under CURRENT_DIR.

        Raises:
            HTTPException: 400 for an unsupported ``mesh_type`` or for
                ``image_bytes`` that is not valid base64; 500 when generation
                or export fails.
        """
        # Validate explicitly instead of with `assert`: asserts are stripped
        # under `python -O`, and mesh_type is interpolated into a file name.
        mesh_type = image_to_mesh_request.mesh_type
        if mesh_type not in ('obj', 'glb'):
            raise HTTPException(
                status_code=400,
                detail=f"unsupported mesh_type {mesh_type!r}; expected 'obj' or 'glb'",
            )

        # binascii.Error (bad padding etc.) is a ValueError subclass, as is the
        # error raised for non-ASCII input, so one except clause covers both.
        try:
            image_data = base64.b64decode(image_to_mesh_request.image_bytes)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"image_bytes is not valid base64: {e}") from e

        try:
            task_id = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f') + '-' + 'native3d'
            current_dir = os.path.join(CURRENT_DIR, task_id)
            os.makedirs(current_dir, exist_ok=True)
            # NOTE(review): the per-request directory is never removed;
            # presumably kept for debugging -- confirm, or clean it up.

            image_path = os.path.join(current_dir, 'input_image.png')
            with open(image_path, 'wb') as f:
                f.write(image_data)

            mesh_path = os.path.join(current_dir, f'output_mesh.{mesh_type}')

            start = time.time()
            mesh = pipeline(image_path).meshes[0]
            print(f"Time: {time.time() - start}s")

            # Replace the mesh's visual with a plain PBR material before export.
            mesh.visual = trimesh.visual.TextureVisuals(
                material=trimesh.visual.material.PBRMaterial(
                    baseColorFactor=(255, 255, 255), main_color=(255, 255, 255), metallicFactor=0.05, roughnessFactor=1.0
                )
            )
            mesh.export(mesh_path)
            with open(mesh_path, 'rb') as f:
                mesh_bytes = f.read()
        except Exception as e:
            traceback.print_exc()
            print(f"generate_model error: {e}")
            # Previously execution fell through to the return below with
            # `mesh_bytes` unbound, surfacing as an unrelated UnboundLocalError.
            raise HTTPException(status_code=500, detail=f"generate_model error: {e}") from e

        return Native3DResponseV2(mesh_bytes=base64.b64encode(mesh_bytes).decode('utf-8'))

    @app.get("/health")
    async def health():
        """Liveness probe; always reports OK once the server is running."""
        return {"status": "OK"}

    uvicorn.run(app, host=parse_args.host, port=parse_args.port)