import os
import json
import time
import kiui
from typing import List
import replicate
import subprocess

import requests
from gradio_client import Client
# from .client import Gau2Mesh_client
from constants import REPLICATE_API_TOKEN, LOG_SERVER, GIF_SERVER
# os.environ["REPLICATE_API_TOKEN"] = "yourKey"

class BaseModelWorker:
    """Common base class for the 3D-generation model workers in this module.

    A worker describes one model (image-to-shape or text-to-shape). It exposes
    ``inference`` / ``render`` hooks for subclasses to override, and
    ``load_offline`` to build URLs of pre-rendered GIFs hosted on GIF_SERVER.
    """

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None
                 ):
        """
        Args:
            model_name: Model identifier; also embedded in offline GIF file names.
            i2s_model: True for an image-to-shape model, False for text-to-shape
                (selects the "image2shape" / "text2shape" gallery prefix).
            online_model: Flag marking the model as online (see ``check_online``).
            model_api: Optional model API identifier (presumably e.g. a
                Replicate model version string — confirm per subclass).
        """
        self.model_name = model_name
        self.i2s_model = i2s_model
        self.online_model = online_model
        self.model_api = model_api
        # Reserved for a per-model JSON of offline URLs. Loading it is currently
        # disabled; load_offline() derives URLs from GIF_SERVER instead.
        self.urls_json = None

    def check_online(self) -> bool:
        """Return True if the worker is flagged online and has no ``model`` set.

        ``__init__`` never assigns ``self.model``. Reading it directly raised
        AttributeError for every online worker whose subclass did not set it.
        A missing attribute is therefore treated the same as an unset (falsy)
        model.
        """
        return bool(self.online_model and not getattr(self, "model", None))

    def load_offline(self, offline_idx):
        """Return the GIF_SERVER URLs of the pre-rendered GIFs for one sample.

        Args:
            offline_idx: Sample index; formatted verbatim into the file names.

        Returns:
            dict with keys 'rgb', 'normal' and 'geo' (in that order), each mapped
            to ``{GIF_SERVER}/{gallery}_{model_name}_{offline_idx}_{key}.gif``.
            ``gallery`` is "image2shape" for image-to-shape models and
            "text2shape" otherwise.
        """
        gallery = "image2shape" if self.i2s_model else "text2shape"
        prefix = f"{GIF_SERVER}/{gallery}_{self.model_name}_{offline_idx}"
        return {kind: f"{prefix}_{kind}.gif" for kind in ("rgb", "normal", "geo")}

    def inference(self, prompt):
        """Run the model on ``prompt``. This is a no-op in the base class and returns None.

        Subclasses override this. ``prompt`` is presumably text for text-to-shape
        models and an image for image-to-shape models — confirm per subclass.
        """
        return None

    def render(self, shape, rgb_on=True, normal_on=True):
        """Render ``shape``. This is a no-op in the base class and returns None.

        Subclasses override this. ``rgb_on`` / ``normal_on`` presumably toggle the
        RGB and normal-map renders — confirm per subclass.
        """
        return None

class HuggingfaceApiWorker(BaseModelWorker):
    """Worker for a model reached through an API (presumably a Hugging Face endpoint).

    Adds no behavior of its own yet. Construction is forwarded unchanged to
    BaseModelWorker, except that ``model_api`` is required here rather than
    optional.
    """

    def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str):
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )

# class PointE_Worker(BaseModelWorker):
#     def __init__(self, 
#                  model_name: str, 
#                  i2s_model: bool, 
#                  online_model: bool, 
#                  model_api: str):
#         super().__init__(model_name, i2s_model, online_model, model_api)

# class TriplaneGaussian(BaseModelWorker):
#     def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str = None):
#         super().__init__(model_name, i2s_model, online_model, model_api)


# class LGM_Worker(BaseModelWorker):
#     def __init__(self, 
#                  model_name: str, 
#                  i2s_model: bool, 
#                  online_model: bool, 
#                  model_api: str = "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
#     ):
#         super().__init__(model_name, i2s_model, online_model, model_api)
#         self.model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
    
#     def inference(self, image):
        
#         output = self.model_client.run(
#             self.model_api,
#             input={"input_image": image}
#         )
#         #=> .mp4 .ply
#         return output[1]

#     def render(self, shape):
#         mesh = Gau2Mesh_client.run(shape)

#         path_normal = ""
#         cmd_normal = f"python -m ..kiuikit.kiui.render {mesh} --save {path_normal} \
#             --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode normal"
#         subprocess.run(cmd_normal, shell=True, check=True)

#         path_rgb = ""
#         cmd_rgb = f"python -m ..kiuikit.kiui.render {mesh} --save {path_rgb} \
#             --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode rgb"
#         subprocess.run(cmd_rgb, shell=True, check=True)
        
#         return path_normal, path_rgb

# class V3D_Worker(BaseModelWorker):
#     def __init__(self, 
#                 model_name: str, 
#                 i2s_model: bool, 
#                 online_model: bool, 
#                 model_api: str = None):
#         super().__init__(model_name, i2s_model, online_model, model_api)


# model = 'LGM'
# # model = 'TriplaneGaussian'
# folder = 'glbs_full'
# form = 'glb'
# pose = '+z'

# pair = ('OpenLRM', 'meshes', 'obj', '-y')
# pair = ('TriplaneGaussian', 'glbs_full', 'glb', '-y')
# pair = ('LGM', 'glbs_full', 'glb', '+z')


if __name__=="__main__":
    # Smoke test: turn a Gaussian-splat .ply produced by the LGM model on
    # Replicate into a mesh via the public splat-to-mesh Hugging Face Space.
    #
    # The LGM outputs below were obtained earlier with:
    #     model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
    #     output = model_client.run(
    #         "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
    #         input={"input_image": "https://replicate.delivery/pbxt/KN0hQI9pYB3NOpHLqktkkQIblwpXt0IG7qI90n5hEnmV9kvo/bird_rgba.png"},
    #     )
    # which returns a two-element list: [<preview .mp4 URL>, <gaussians .ply URL>].
    # NOTE(review): replicate.delivery URLs may expire — confirm they still resolve.
    lgm_outputs = [
        'https://replicate.delivery/pbxt/RPSTEes37lzAJav3jy1lPuzizm76WGU4IqDcFcAMxhQocjUJA/gradio_output.mp4',
        'https://replicate.delivery/pbxt/2Vy8yrPO3PYiI1YJBxPXAzryR0SC0oyqW3XKPnXiuWHUuRqE/gradio_output.ply',
    ]
    preview_mp4_url, splat_ply_url = lgm_outputs

    splat_to_mesh = Client(
        "https://dylanebert-splat-to-mesh.hf.space/",
        upload_files=True,
        download_files=True,
    )
    mesh = splat_to_mesh.predict(splat_ply_url, api_name="/run")
    print(mesh)