import inspect
from abc import ABC, abstractclassmethod, abstractmethod

import gradio as gr
import pandas as pd
from gradio.components import Component

class BaseTCOModel(ABC):
    """Abstract base for a model whose total cost of ownership (TCO) is computed.

    Every Gradio ``Component`` assigned as an instance attribute is recorded,
    in assignment order, in ``self._components`` (``ModelPage`` uses this list
    to toggle a model's inputs visible/hidden as a group).

    Subclasses implement :meth:`render`, which creates the components, and
    :meth:`compute_cost_per_token`. The parameter names of
    ``compute_cost_per_token`` (after ``self``) must equal the attribute names
    of the components whose values feed it; see
    :meth:`register_components_for_cost_computing`.
    """

    # TODO: Find a way to specify which component should be used for computing cost

    def __setattr__(self, name, value):
        """Store the attribute, and remember it if it is a Gradio component."""
        if isinstance(value, Component):
            # Subclasses may assign attributes before calling super().__init__(),
            # so create the registry lazily instead of assuming __init__ already ran.
            self.__dict__.setdefault("_components", []).append(value)
        self.__dict__[name] = value

    def __init__(self):
        # setdefault keeps any components registered before this call instead of
        # wiping them; it also bypasses __setattr__ for the registry itself.
        self.__dict__.setdefault("_components", [])
        self.use_case = None

    def get_components(self) -> list[Component]:
        """Return every Gradio component assigned to this model, in assignment order."""
        return self._components

    def get_components_for_cost_computing(self) -> list:
        """Return the components recorded by :meth:`register_components_for_cost_computing`.

        Raises AttributeError if that method has not been called yet.
        """
        return self.components_for_cost_computing

    def get_name(self) -> str:
        """Return the display name set with :meth:`set_name`."""
        return self.name

    def register_components_for_cost_computing(self):
        """Record the attributes named by ``compute_cost_per_token``'s parameters.

        Must be called after :meth:`render` so that those attributes exist.
        The resulting list follows the parameter order.
        """
        # getfullargspec on a bound method still lists ``self`` first; skip it.
        arg_names = inspect.getfullargspec(self.compute_cost_per_token).args[1:]
        self.components_for_cost_computing = [getattr(self, arg) for arg in arg_names]

    @abstractmethod
    def compute_cost_per_token(self):
        """Return ``(cost_per_input_token, cost_per_output_token, labor)``.

        Subclasses declare one parameter per cost-relevant component; each
        parameter name must match the attribute holding that component.
        """

    @abstractmethod
    def render(self):
        """Create this model's Gradio components as instance attributes."""

    def set_name(self, name):
        """Set the display name used to select this model."""
        self.name = name

    def set_formula(self, formula):
        """Set the (Markdown/LaTeX) cost formula shown to the user."""
        self.formula = formula

    def get_formula(self):
        """Return the cost formula set with :meth:`set_formula`."""
        return self.formula

    def set_latency(self, latency):
        """Set the latency text shown for this model (e.g. ``"15s"``)."""
        self.latency = latency

    def get_latency(self):
        """Return the latency text set with :meth:`set_latency`."""
        return self.latency

class OpenAIModel(BaseTCOModel):
    """TCO model for OpenAI's hosted (SaaS) GPT models, priced per 1K tokens."""

    # ($ per 1K input tokens, $ per 1K output tokens), keyed by the exact
    # (model, context size) values offered by the dropdowns in render().
    # Keys must match the dropdown strings exactly (e.g. "GPT-3.5 Turbo").
    _PRICES_PER_1K_TOKENS = {
        ("GPT-4", "8K"): (0.03, 0.06),
        ("GPT-4", "32K"): (0.06, 0.12),
        ("GPT-3.5 Turbo", "4K"): (0.0015, 0.002),
        ("GPT-3.5 Turbo", "16K"): (0.003, 0.004),
    }
    # Used for any combination not in the table (kept for backward compatibility).
    _FALLBACK_PRICES_PER_1K_TOKENS = (0.003, 0.004)

    def __init__(self):
        self.set_name("(SaaS) OpenAI")
        self.set_formula(r"""$CR = \frac{CIT\_1K \times IT + COT\_1K \times OT}{1000}$  <br>
                         with: <br>
                         CR = Cost per Request <br>
                         CIT_1K = Cost per 1000 Input Tokens <br>
                         COT_1K = Cost per 1000 Output Tokens <br>
                         IT = Input Tokens <br>
                         OT = Output Tokens
                         """)
        self.set_latency("15s")  # Default value for GPT-4
        super().__init__()

    def render(self):
        """Create this model's (initially hidden) Gradio inputs and wire their events."""

        def on_model_change(model):
            # Reset the context size to a value valid for the newly selected
            # model; leaving a stale value (e.g. "4K" after switching back to
            # GPT-4) would price the request with the wrong rates.
            # NOTE(review): latency is stored on the shared model instance, so
            # concurrent users presumably see each other's choice — confirm.
            if model == "GPT-4":
                self.latency = "15s"
                return gr.Dropdown.update(choices=["8K", "32K"], value="8K")
            self.latency = "5s"
            return gr.Dropdown.update(choices=["4K", "16K"], value="4K")

        def define_cost_per_token(model, context_length):
            # Returns (price per 1K input tokens, price per 1K output tokens).
            return self._PRICES_PER_1K_TOKENS.get(
                (model, context_length), self._FALLBACK_PRICES_PER_1K_TOKENS
            )

        self.model = gr.Dropdown(["GPT-4", "GPT-3.5 Turbo"], value="GPT-4",
                                 label="OpenAI models",
                                 interactive=True, visible=False)
        self.context_length = gr.Dropdown(["8K", "32K"], value="8K", interactive=True,
                                          label="Context size",
                                          visible=False, info="Number of tokens the model considers when processing text")
        # Despite the "_per_second" attribute names (kept because
        # compute_cost_per_token's parameters must match them), these hold
        # prices per 1K tokens, as the labels say.
        self.input_tokens_cost_per_second = gr.Number(0.03, visible=False,
                                                      label="($) Price/1K input prompt tokens",
                                                      interactive=False
                                                      )
        self.output_tokens_cost_per_second = gr.Number(0.06, visible=False,
                                                       label="($) Price/1K output prompt tokens",
                                                       interactive=False
                                                       )
        self.info = gr.Markdown("The cost per input and output tokens values are from OpenAI's pricing web page [here](https://openai.com/pricing)", interactive=False, visible=False)

        price_outputs = [self.input_tokens_cost_per_second, self.output_tokens_cost_per_second]
        price_inputs = [self.model, self.context_length]
        self.model.change(on_model_change, inputs=self.model, outputs=self.context_length).then(
            define_cost_per_token, inputs=price_inputs, outputs=price_outputs
        )
        self.context_length.change(define_cost_per_token, inputs=price_inputs, outputs=price_outputs)

        self.labor = gr.Number(0, visible=False,
                               label="($) Labor cost per month",
                               interactive=True
                               )

    def compute_cost_per_token(self, input_tokens_cost_per_second, output_tokens_cost_per_second, labor):
        """Convert per-1K-token prices to per-token costs; ``labor`` is passed through.

        Returns:
            ``(cost_per_input_token, cost_per_output_token, labor)``.
        """
        cost_per_input_token = input_tokens_cost_per_second / 1000
        cost_per_output_token = output_tokens_cost_per_second / 1000
        return cost_per_input_token, cost_per_output_token, labor

class OpenSourceLlama2Model(BaseTCOModel):
    """TCO model for a self-hosted Llama 2 70B deployment on a rented GPU VM.

    The VM description and hourly cost are shown for information only; the
    cost computation uses just the per-1K-token prices and the labor cost.
    """

    def __init__(self):
        super().__init__()
        self.set_name("(Open source) Llama 2 70B")
        self.set_formula(r"""$CR = \frac{CIT\_1K \times IT + COT\_1K \times OT}{1000}$  <br>
                         with: <br>
                         CR = Cost per Request <br>
                         CIT_1K = Cost per 1000 Input Tokens <br>
                         COT_1K = Cost per 1000 Output Tokens <br>
                         IT = Input Tokens <br>
                         OT = Output Tokens
                         """)
        self.set_latency("27s")

    def render(self):
        """Create this model's Gradio inputs; all start hidden until the model is selected."""
        self.vm = gr.Textbox(
            value="2x A100 80GB NVLINK",
            label="Instance of VM with GPU",
            visible=False,
        )
        self.vm_cost_per_hour = gr.Number(
            value=2.21,
            label="VM instance cost ($) per hour",
            info="Note that this is the cost for a single VM instance, it is doubled in our case since two GPUs are needed",
            interactive=False,
            visible=False,
        )
        # Prices per 1K tokens (the "_per_second" names must match the
        # parameters of compute_cost_per_token, so they are kept as-is).
        self.input_tokens_cost_per_second = gr.Number(
            value=0.00052,
            label="($) Price/1K input prompt tokens",
            interactive=False,
            visible=False,
        )
        self.output_tokens_cost_per_second = gr.Number(
            value=0.06656,
            label="($) Price/1K output prompt tokens",
            interactive=False,
            visible=False,
        )
        self.info = gr.Markdown(
            value="For the Llama2-70B model, we took the cost per input and output tokens values from the benchmark results [here](https://www.cursor.so/blog/llama-inference#user-content-fn-llama-paper)",
            interactive=False,
            visible=False,
        )
        self.labor = gr.Number(
            value=1000,
            label="($) Labor cost per month",
            interactive=True,
            visible=False,
        )
        # NOTE: a GPU-utilisation ("% used") slider was drafted for this model
        # but is not part of the cost computation yet.

    def compute_cost_per_token(self, input_tokens_cost_per_second, output_tokens_cost_per_second, labor):
        """Turn the per-1K-token prices into per-token costs and pass ``labor`` through."""
        tokens_per_price_unit = 1000
        return (
            input_tokens_cost_per_second / tokens_per_price_unit,
            output_tokens_cost_per_second / tokens_per_price_unit,
            labor,
        )

class CohereModel(BaseTCOModel):
    """TCO model for Cohere's hosted (SaaS) API, priced per one million tokens."""

    def __init__(self):
        self.set_name("(SaaS) Cohere")
        self.set_formula(r"""$CR = \frac{CT\_1M \times (IT + OT)}{1000000}$  <br>
                         with: <br>
                         CR = Cost per Request <br>
                         CT_1M = Cost per one million Tokens (from Cohere's pricing web page) <br>
                         IT = Input Tokens <br>
                         OT = Output Tokens 
                         """)
        self.set_latency("")
        super().__init__()

    def render(self):
        """Create this model's (initially hidden) Gradio inputs.

        The model dropdown offers only "Default" for the "Summarize" use case,
        and "Default"/"Custom" otherwise.
        """
        # The choices must be decided before the Dropdown is built: a bare
        # ``self.model: <expr>`` line is only an annotation and has no effect.
        # NOTE(review): ModelPage sets use_case in make_model_visible, which
        # presumably runs after render(), so use_case is usually still None
        # here and both choices are offered — confirm.
        if self.use_case == "Summarize":
            choices = ["Default"]
        else:
            choices = ["Default", "Custom"]
        self.model = gr.Dropdown(choices, value="Default",
                                 label="Model",
                                 interactive=True, visible=False)

        self.labor = gr.Number(0, visible=False,
                               label="($) Labor cost per month",
                               interactive=True
                               )

    def compute_cost_per_token(self, model, labor):
        """Return per-token costs for the current use case; ``labor`` is passed through.

        Cohere prices input and output tokens identically, per one million
        tokens: "Generate" costs 15 ("Default" model) or 30 (any other model),
        "Summarize" costs 15, and every other use case (including None) 200.

        Returns:
            ``(cost_per_input_token, cost_per_output_token, labor)``.
        """
        use_case = self.use_case

        if use_case == "Generate":
            cost_per_1M_tokens = 15 if model == "Default" else 30
        elif use_case == "Summarize":
            cost_per_1M_tokens = 15
        else:
            # NOTE(review): covers "Question-answering" and any unset use case;
            # confirm 200 $/1M tokens against Cohere's pricing page.
            cost_per_1M_tokens = 200

        cost_per_token = cost_per_1M_tokens / 1000000
        return cost_per_token, cost_per_token, labor

class ModelPage:
    """Instantiate a set of TCO models, render them, and dispatch cost requests."""

    def __init__(self, Models: "list[type[BaseTCOModel]]"):
        """Instantiate each class in ``Models`` (called with no arguments), keeping order."""
        self.models: "list[BaseTCOModel]" = [Model() for Model in Models]

    def render(self):
        """Render every model, then record which components feed its cost computation."""
        for model in self.models:
            model.render()
            # Must run after render(): it looks up the components render() created.
            model.register_components_for_cost_computing()

    def get_all_components(self) -> "list[Component]":
        """Return every model's components, concatenated in ``self.models`` order."""
        output = []
        for model in self.models:
            output += model.get_components()
        return output

    def get_all_components_for_cost_computing(self) -> "list[Component]":
        """Return every model's cost-computing components, concatenated in ``self.models`` order."""
        output = []
        for model in self.models:
            output += model.get_components_for_cost_computing()
        return output

    def make_model_visible(self, name: str, use_case: "str | None"):
        """Return one visibility update per component, showing only the model called ``name``.

        The updates follow the order of :meth:`get_all_components`. The selected
        model's ``use_case`` attribute is also set to ``use_case`` (the value
        received from the use-case input).
        """
        output = []
        for model in self.models:
            is_selected = model.get_name() == name
            output += [gr.update(visible=is_selected)] * len(model.get_components())
            if is_selected:
                model.use_case = use_case
        return output

    def compute_cost_per_token(self, *args):
        """Compute the cost per request for the selected model.

        ``args`` is every model's cost-computing component values, concatenated
        in ``self.models`` order (presumably wired from
        :meth:`get_all_components_for_cost_computing`). They are followed by the
        selected model name, the number of input tokens and the number of
        output tokens.

        Returns:
            A 5-tuple: summary message, cost per request, the model's formula,
            latency message, and the labor cost returned by the model.

        Raises:
            ValueError: if no model has the selected name.
        """
        current_model = args[-3]
        current_input_tokens = args[-2]
        current_output_tokens = args[-1]
        begin = 0
        for model in self.models:
            model_n_args = len(model.get_components_for_cost_computing())
            if model.get_name() == current_model:
                model_args = args[begin:begin + model_n_args]
                cost_per_input_token, cost_per_output_token, labor_cost = model.compute_cost_per_token(*model_args)
                model_tco = cost_per_input_token * current_input_tokens + cost_per_output_token * current_output_tokens
                formula = model.get_formula()
                latency = model.get_latency()
                return (
                    f"Model {current_model} has a cost/request of: ${model_tco}",
                    model_tco,
                    formula,
                    f"The average latency of this model is {latency}",
                    labor_cost,
                )
            begin += model_n_args
        # Without this, the method would implicitly return None, which a caller
        # expecting five values cannot unpack; fail with a clear message instead.
        raise ValueError(f"Unknown model name: {current_model!r}")