File size: 5,243 Bytes
d7c9e73
 
 
8277386
d7c9e73
 
c2e2aa2
 
 
 
d7c9e73
c2e2aa2
 
 
 
 
d7c9e73
c2e2aa2
 
d7c9e73
6ce82f5
d7c9e73
 
ffec641
d4aa01a
ffec641
d4aa01a
 
c2e2aa2
ccba587
40354df
 
7e9dee7
e993e1b
40354df
 
 
c2e2aa2
220b4dd
d7c9e73
6ce82f5
 
 
 
 
d7c9e73
 
6ce82f5
d7c9e73
 
 
 
6ce82f5
 
 
 
 
 
 
d7c9e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffec641
6ce82f5
d7c9e73
 
 
 
 
 
8277386
d7c9e73
6ce82f5
ca89877
6ce82f5
8277386
d7c9e73
8277386
d7c9e73
 
 
8277386
 
d7c9e73
c2e2aa2
6ce82f5
d7c9e73
 
 
 
 
 
ffec641
 
 
 
 
c2e2aa2
d7c9e73
 
6b75ebd
 
40354df
 
 
 
6b75ebd
 
 
 
 
14052a3
d4aa01a
6b75ebd
 
 
 
6ce82f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from functools import partial
import os
import re
import time
from xml.parsers.expat import model

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
# Use the real `spaces` package only when the SPACES_ZERO_GPU environment
# variable is set; otherwise install a no-op stand-in exposing the same
# `spaces.GPU` decorator so the rest of the module can be written once.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """No-op replacement for the `spaces` package."""

        @staticmethod
        def GPU(func=None, **_options):
            """Return `func` wrapped in a transparent pass-through.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            configured one (``@spaces.GPU(duration=60)``); any keyword
            options are accepted and ignored.
            """

            def decorate(target):
                # functools.wraps keeps the decorated callable's name and
                # docstring, which the original shim discarded.
                @functools.wraps(target)
                def wrapper(*args, **kwargs):
                    return target(*args, **kwargs)

                return wrapper

            if func is None:
                # Called with options only: hand back the real decorator.
                return decorate
            return decorate(func)

from transformers import pipeline as hf_pipeline

import litellm

from tqdm import tqdm
import subprocess

# https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/132
# Import-time side effect: remove everything under the ZeroGPU offload
# directory. `shell=True` is needed so the shell expands the `*` glob; the
# command runs with an empty environment (`env={}`) and its return code is
# not checked (no `check=True`).
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# Module-level Hugging Face text-generation pipeline, built once at import
# time and used by `ModelPrediction.predict_with_hf`.
# NOTE(review): this loads Meta-Llama-3.1-8B-Instruct, while
# `ModelPrediction._init_model_prediction` maps "llama-8" to
# "meta-llama/Meta-Llama-3-8B-Instruct" - confirm which one is intended.
pipeline = hf_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_kwargs={"torch_dtype": 'bfloat16'},
    device_map="auto",
)


class ModelPrediction:
    """Send text-to-SQL prompts to a supported model and parse the reply.

    Every supported short model name is mapped to a callable (built with
    ``functools.partial``) that forwards a prompt either through ``litellm``
    or through the module-level Hugging Face ``pipeline``.
    """

    # Ordered routing table: (substring searched for in the requested name,
    # fully qualified model id, whether the local HF pipeline serves it).
    # The first matching row wins, so the order is significant.
    _MODEL_ROUTES = (
        ("gpt-3.5", "openai/gpt-3.5-turbo-0125", False),
        ("gpt-4o-mini", "openai/gpt-4o-mini-2024-07-18", False),
        ("o1-mini", "openai/o1-mini-2024-09-12", False),
        ("QwQ", "together_ai/Qwen/QwQ-32B", False),
        (
            "DeepSeek-R1-Distill-Llama-70B",
            "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
            False,
        ),
        ("llama-8", "meta-llama/Meta-Llama-3-8B-Instruct", True),
    )

    # inference endpoint costs HF per Hour 3.6$/h -> 0.001 $ per second
    # https://huggingface.co/docs/inference-endpoints/en/pricing?utm_source=chatgpt.com
    _HF_COST_PER_SECOND = 0.001

    def __init__(self):
        """Build the name -> prediction-callable table and the default prompt."""
        self.model_name2pred_func = {
            marker: self._init_model_prediction(marker)
            for marker, _, _ in self._MODEL_ROUTES
        }

        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question in SQL code to be executed over the database to fetch the answer. Return the sql code in ```sql ```\n"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    def _reset_pipeline(self, model_name):
        """Record `model_name` and drop the cached pipeline when it changes."""
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        """Pull the SQL statement out of a raw model response.

        Preference order:
        1. the last ``<answer>...</answer>`` section, with code-fence
           markers removed;
        2. the last ```` ```sql ... ``` ```` fenced block;
        3. the response unchanged.
        """
        answers = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if answers:
            # Strip only the fence markers ("```sql" / "```"). Removing
            # every "sql" substring would corrupt identifiers such as
            # "mysql_users".
            unfenced = re.sub(
                r"```(?:sql\b)?", "", answers[-1], flags=re.IGNORECASE
            )
            return unfenced.strip()

        fenced = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
        return fenced[-1].strip() if fenced else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None):
        """Run `model_name` on a prompt and return the enriched prediction.

        Args:
            question: Natural-language question; substituted into the
                built-in template when no `prompt` is given.
            db_schema: Database schema text; substituted into the built-in
                template when no `prompt` is given.
            model_name: One of the keys of ``model_name2pred_func``.
            prompt: Optional prompt sent to the model verbatim. When empty
                or ``None`` the built-in template is filled in instead.

        Returns:
            The dict produced by the prediction callable (``response`` and
            ``cost``) extended with ``response_parsed`` and ``time``
            (elapsed seconds).

        Raises:
            ValueError: If `model_name` is not supported.
        """
        if model_name not in self.model_name2pred_func:
            supported = ", ".join(self.model_name2pred_func)
            raise ValueError(
                f"Model not supported: {model_name!r}; "
                f"supported models are: {supported}"
            )

        if not prompt:
            # Only the built-in template is formatted. A caller-supplied
            # prompt is passed through untouched because it may contain
            # literal braces.
            prompt = self.base_prompt.format(
                question=question, db_schema=db_schema
            )

        started = time.perf_counter()
        prediction = self.model_name2pred_func[model_name](prompt)
        elapsed = time.perf_counter() - started

        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = elapsed
        return prediction

    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]:
        """Query `model_name` through ``litellm`` with a single user message.

        Returns a dict with the reply text under ``response`` and the cost
        reported by litellm under ``cost``.
        """
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    @spaces.GPU
    def predict_with_hf(self, prompt, model_name):  # -> dict[str, Any | float]:
        """Query the module-level Hugging Face pipeline.

        `model_name` is accepted for signature parity with
        `predict_with_api` but is not used: the module-level ``pipeline``
        decides which model runs. The cost is derived from the generation
        time and the per-second endpoint price.
        """
        started = time.perf_counter()
        outputs = pipeline(
            [{"role": "user", "content": prompt}],
            max_new_tokens=256,
        )
        elapsed_time = time.perf_counter() - started

        # The pipeline returns the whole chat; the last message is the reply.
        response = outputs[0]["generated_text"][-1]["content"]
        print(response)
        return {
            "response": response,
            "cost": elapsed_time * self._HF_COST_PER_SECOND,
        }

    def _init_model_prediction(self, model_name):
        """Return the prediction callable bound to the resolved model id.

        The short name is matched by substring against the routing table,
        first match wins.

        Raises:
            ValueError: If no routing entry matches `model_name`.
        """
        for marker, resolved_name, served_locally in self._MODEL_ROUTES:
            if marker in model_name:
                predict_fun = (
                    self.predict_with_hf
                    if served_locally
                    else self.predict_with_api
                )
                return partial(predict_fun, model_name=resolved_name)
        raise ValueError("Model forbidden")