from functools import partial
import os
import re
import time

if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Fallback stub so @spaces.GPU is a no-op outside a ZeroGPU Space.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

import torch  # required for the torch_dtype passed to the pipeline below
from transformers import pipeline as hf_pipeline
import litellm
from tqdm import tqdm

# Place the pipeline on the GPU at construction; the transformers Pipeline
# object itself has no .to() method.
pipeline = hf_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)


class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "o1-mini": self._init_model_prediction("o1-mini"),
            "QwQ": self._init_model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }

        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer. Return the sql code in ```sql ```\n"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    def _reset_pipeline(self, model_name):
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Prefer the last <answer>...</answer> block; otherwise fall back to a
        # ```sql ... ``` block, then to the raw prediction text.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
        return matches[-1].strip() if matches else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model '{model_name}' is not supported; supported models are: "
                f"{list(self.model_name2pred_func.keys())}"
            )

        # Fill the prompt template with the question and the database schema.
        prompt = (prompt or self.base_prompt).format(
            question=question, db_schema=db_schema
        )

        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time

        return prediction

    def predict_with_api(self, prompt, model_name):
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    @spaces.GPU
    def predict_with_hf(self, prompt, model_name):
        outputs = pipeline(
            [{"role": "user", "content": prompt}],
            max_new_tokens=256,
        )
        # With chat-style input the pipeline returns the whole conversation;
        # the last message is the assistant reply and its "content" is the text.
        response = outputs[0]["generated_text"][-1]["content"]
        return {"response": response, "cost": 0.0}

    def _init_model_prediction(self, model_name):
        # Map the short model alias to its provider-specific identifier and
        # choose the API-based or local HF prediction function accordingly.
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError(f"Model '{model_name}' is not supported")

        return partial(predict_fun, model_name=model_name)
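

# A minimal usage sketch (illustrative only): the question and schema strings
# below are hypothetical placeholders, and any alias listed in
# model_name2pred_func can be passed as model_name. Running it requires the
# corresponding provider API key to be configured for litellm.
if __name__ == "__main__":
    predictor = ModelPrediction()
    result = predictor.make_prediction(
        question="How many singers are there?",
        db_schema="CREATE TABLE singer (id INT, name TEXT);",
        model_name="gpt-4o-mini",
    )
    print(result["response_parsed"], result["cost"], result["time"])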