import os
import json
import time
import kiui
from typing import List
import replicate
import subprocess
from gradio_client import Client

from constants import OFFLINE_GIF_DIR

# The Replicate token is read from the environment; set it before running, e.g.
# os.environ["REPLICATE_API_TOKEN"] = "r8_0BaoQW0G8nWFXY8YWBCCUDurANxCtY72rarv9"
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")

class BaseModelWorker:
    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None,
                 ):
        self.model_name = model_name
        self.i2s_model = i2s_model
        self.online_model = online_model
        self.model_api = model_api

        # Pre-rendered results cached on disk, keyed by sample index.
        self.urls_json = None
        urls_json_path = os.path.join(OFFLINE_GIF_DIR, f"{model_name}.json")
        if os.path.exists(urls_json_path):
            with open(urls_json_path, 'r') as f:
                self.urls_json = json.load(f)

    def check_online(self) -> bool:
        # Online-only worker: flagged as online and no local model loaded by a subclass.
        return self.online_model and getattr(self, "model", None) is None

    def load_offline(self, offline: bool, offline_idx):
        # Return the cached offline result for this index, or None if unavailable.
        if offline and self.urls_json and str(offline_idx) in self.urls_json:
            return self.urls_json[str(offline_idx)]
        return None

    def inference(self, prompt):
        pass

    def render(self, shape, rgb_on=True, normal_on=True):
        pass

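# Sketch of the intended call flow (not part of the original file): check the
# offline cache first and fall back to live inference. The worker class, index,
# and `image` below are illustrative only.
#
#   worker = LGM_Worker("LGM", i2s_model=True, online_model=True)
#   result = worker.load_offline(offline=True, offline_idx=0)
#   if result is None:
#       result = worker.inference(image)
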
class HuggingfaceApiWorker(BaseModelWorker):
    def __init__(
        self,
        model_name: str,
        i2s_model: bool,
        online_model: bool,
        model_api: str,
    ):
        super().__init__(
            model_name,
            i2s_model,
            online_model,
            model_api,
        )

class PointE_Worker(BaseModelWorker):
    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str):
        super().__init__(model_name, i2s_model, online_model, model_api)

class TriplaneGaussian(BaseModelWorker):
    def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str = None):
        super().__init__(model_name, i2s_model, online_model, model_api)

class LGM_Worker(BaseModelWorker):
    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
                 ):
        super().__init__(model_name, i2s_model, online_model, model_api)
        self.model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)

    def inference(self, image):
        # Run LGM on Replicate; output is [<turntable .mp4 url>, <Gaussian splat .ply url>].
        output = self.model_client.run(
            self.model_api,
            input={"input_image": image}
        )
        # Return the .ply asset.
        return output[1]

    def render(self, shape):
        # Convert the Gaussian splat (.ply) to a mesh via the splat-to-mesh Space,
        # then render normal and RGB turntables with kiui (kiuikit).
        gau2mesh_client = Client("https://dylanebert-splat-to-mesh.hf.space/",
                                 upload_files=True, download_files=True)
        mesh = gau2mesh_client.predict(shape, api_name="/run")

        path_normal = ""  # output path for the normal renders (to be filled in)
        cmd_normal = f"python -m kiui.render {mesh} --save {path_normal} \
            --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode normal"
        subprocess.run(cmd_normal, shell=True, check=True)

        path_rgb = ""  # output path for the RGB renders (to be filled in)
        cmd_rgb = f"python -m kiui.render {mesh} --save {path_rgb} \
            --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode rgb"
        subprocess.run(cmd_rgb, shell=True, check=True)

        return path_normal, path_rgb

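# Usage sketch for LGM_Worker (assumptions: REPLICATE_API_TOKEN is set in the
# environment, `image` is a URL or open file handle accepted by the Replicate
# model, and the render output paths above have been filled in):
#
#   worker = LGM_Worker("LGM", i2s_model=True, online_model=True)
#   ply_url = worker.inference(image)               # URL of the generated .ply
#   path_normal, path_rgb = worker.render(ply_url)  # turntable renders
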
class V3D_Worker(BaseModelWorker):
    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None):
        super().__init__(model_name, i2s_model, online_model, model_api)

# model = 'LGM'
# # model = 'TriplaneGaussian'
# folder = 'glbs_full'
# form = 'glb'
# pose = '+z'
# pair = ('OpenLRM', 'meshes', 'obj', '-y')
# pair = ('TriplaneGaussian', 'glbs_full', 'glb', '-y')
# pair = ('LGM', 'glbs_full', 'glb', '+z')

if __name__ == "__main__":
    # input = {
    #     "input_image": "https://replicate.delivery/pbxt/KN0hQI9pYB3NOpHLqktkkQIblwpXt0IG7qI90n5hEnmV9kvo/bird_rgba.png",
    # }
    # print("Start...")
    # model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
    # output = model_client.run(
    #     "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
    #     input=input
    # )
    # print("output: ", output)
    # => ['https://replicate.delivery/pbxt/toffawxRE3h6AUofI9sPtiAsoYI0v73zuGDZjZWBWAPzHKSlA/gradio_output.mp4', 'https://replicate.delivery/pbxt/oSn1XPfoJuw2UKOUIAue2iXeT7aXncVjC4QwHKU5W5x0HKSlA/gradio_output.ply']

    output = ['https://replicate.delivery/pbxt/RPSTEes37lzAJav3jy1lPuzizm76WGU4IqDcFcAMxhQocjUJA/gradio_output.mp4', 'https://replicate.delivery/pbxt/2Vy8yrPO3PYiI1YJBxPXAzryR0SC0oyqW3XKPnXiuWHUuRqE/gradio_output.ply']
    to_mesh_client = Client("https://dylanebert-splat-to-mesh.hf.space/", upload_files=True, download_files=True)
    mesh = to_mesh_client.predict(output[1], api_name="/run")
    print(mesh)