Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import json | |
import random | |
import re | |
import os | |
from abc import ABC, abstractmethod | |
from typing import List, Dict, Union, Optional | |
from huggingface_hub import InferenceClient | |
from tenacity import retry, stop_after_attempt, wait_random_exponential | |
from transformers import AutoTokenizer | |
# from config import * | |
# The three roles Model.add_message accepts for conversation entries.
ROLE_SYSTEM = 'system'
ROLE_USER = 'user'
ROLE_ASSISTANT = 'assistant'

# Hugging Face model ids accepted by each model family.
SUPPORTED_MISTRAL_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
]
SUPPORTED_NOUS_MODELS = [
    'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO',
    'NousResearch/Nous-Hermes-2-Mistral-7B-DPO',
]
SUPPORTED_LLAMA_MODELS = [
    'meta-llama/Llama-2-70b-chat-hf',
    'meta-llama/Llama-2-13b-chat-hf',
    'meta-llama/Llama-2-7b-chat-hf',
]

# Concatenation of the per-family lists above, in the same order.
ALL_SUPPORTED_MODELS = [
    *SUPPORTED_MISTRAL_MODELS,
    *SUPPORTED_NOUS_MODELS,
    *SUPPORTED_LLAMA_MODELS,
]
def select_model(model_name: str, system_prompt: str, **kwargs) -> Model:
    """Build the chat-model wrapper that serves ``model_name``.

    Args:
        model_name: Hugging Face model id; must appear in one of the
            ``SUPPORTED_*_MODELS`` lists.
        system_prompt: System prompt that seeds the conversation history.
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        A ``MistralModel``, ``NousHermesModel`` or ``LlamaModel`` instance.

    Raises:
        ValueError: if ``model_name`` is not in any supported list.
    """
    families = (
        (SUPPORTED_MISTRAL_MODELS, MistralModel),
        (SUPPORTED_NOUS_MODELS, NousHermesModel),
        (SUPPORTED_LLAMA_MODELS, LlamaModel),
    )
    for supported_ids, model_cls in families:
        if model_name in supported_ids:
            return model_cls(system_prompt, model_name)
    raise ValueError(f'Model {model_name} not supported')
class Model(ABC):
    """Base class for a chat model that keeps its own conversation history.

    ``messages`` is a list of ``{'role': ..., 'content': ...}`` dicts. It
    starts with the system prompt after construction and after
    ``clear_conversations()``.
    """

    name: str  # model identifier; also what str() / repr() return
    messages: List[Dict[str, str]]  # conversation history, oldest first
    system_prompt: str  # re-inserted whenever the history is cleared

    def __init__(self, model_name: str, system_prompt: str):
        self.name = model_name
        self.system_prompt = system_prompt
        # A fresh history holds only the system prompt.
        self.messages = [dict(role=ROLE_SYSTEM, content=system_prompt)]

    def __call__(self, *args, **kwargs) -> Union[str, Dict]:
        """Run one conversation turn; the base implementation is not usable."""
        raise NotImplementedError

    def add_message(self, role: str, content: str):
        """Append ``content`` under one of the ROLE_* roles to the history."""
        assert role in (ROLE_SYSTEM, ROLE_USER, ROLE_ASSISTANT)
        self.messages.append(dict(role=role, content=content))

    def clear_conversations(self):
        """Reset the history, in place, to just the system prompt."""
        self.messages[:] = [dict(role=ROLE_SYSTEM, content=self.system_prompt)]

    def __str__(self) -> str:
        return self.name

    __repr__ = __str__
class HFAPIModel(Model):
    """Chat model served through the Hugging Face Inference API.

    Subclasses implement ``format_messages`` to render ``self.messages`` into
    the prompt string sent to ``text_generation``.
    """

    def __call__(self, user_prompt: str, *args,
                 use_json: bool = False,
                 temperature: float = 0,
                 timeout: Optional[float] = None,
                 cache: bool = False,
                 json_retry_count: int = 5,
                 **kwargs) -> Union[str, Dict]:
        """Send ``user_prompt`` as a new user turn and return the reply.

        The user prompt and the raw reply text are appended to the history.
        If a request ultimately fails, the user prompt is removed again and
        the exception propagates.

        Args:
            user_prompt: Content of the new user message.
            use_json: Try to return the first JSON object found in the reply.
                Up to ``json_retry_count`` requests are made (at least one).
                If none yields valid JSON, the last reply string is returned,
                so a dict is not guaranteed.
            temperature: Sampling temperature; values <= 0 mean greedy decoding.
            timeout: Request timeout passed to ``InferenceClient``.
            cache: Allow the API response cache for the first request only.
                Re-requests made for JSON parsing always bypass the cache.
            json_retry_count: Maximum number of requests when ``use_json``.
            *args, **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The parsed JSON dict, or the raw reply string.
        """
        self.add_message(ROLE_USER, user_prompt)
        try:
            text = self.get_response(temperature, use_json, timeout, cache)
            result = text
            if use_json:
                # NOTE(review): with temperature <= 0 decoding is greedy, so
                # re-requests will likely return the same text.
                for attempt in range(max(json_retry_count, 1)):
                    if attempt:
                        # A cached re-request would just repeat the bad text.
                        text = self.get_response(temperature, use_json, timeout, False)
                    parsed = self.find_first_valid_json(text)
                    if parsed is not None:
                        result = parsed
                        break
                else:
                    result = text
        except Exception:
            # Drop the unanswered user turn so the history keeps alternating
            # user/assistant roles for the next call.
            self.messages.pop()
            raise
        # Store the raw text: the prompt formatters expect string contents.
        self.add_message(ROLE_ASSISTANT, text)
        return result

    def get_response(self, temperature: float, use_json: bool, timeout: Optional[float], cache: bool) -> str:
        """Render the conversation and return one generation for it.

        ``use_json`` is part of the signature but does not affect the request.
        """
        prompt = self.format_messages()
        return self._generate(prompt, temperature, timeout, cache)

    # Only the network call is retried; a prompt-formatting error is not
    # transient and should surface immediately.
    @retry(stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, max=60), reraise=True)
    def _generate(self, prompt: str, temperature: float, timeout: Optional[float], cache: bool) -> str:
        """Call the Inference API ``text_generation`` endpoint for ``prompt``.

        The API token is read from the ``HF_API_TOKEN`` environment variable.
        Any exception is retried, up to 5 attempts with randomized exponential
        back-off capped at 60 s, and the last error is then re-raised.
        """
        client = InferenceClient(model=self.name, token=os.getenv('HF_API_TOKEN'), timeout=timeout)
        if not cache:
            # Ask the Inference API not to serve a cached generation.
            client.headers["x-use-cache"] = "0"
        return client.text_generation(prompt,
                                      do_sample=temperature > 0,
                                      temperature=temperature if temperature > 0 else None,
                                      max_new_tokens=4096)

    def format_messages(self) -> str:
        """Render ``self.messages`` into a prompt string; subclasses override."""
        raise NotImplementedError

    def get_short_name(self) -> str:
        """
        Returns the last part of the model name.
        For example, "mistralai/Mixtral-8x7B-Instruct-v0.1" -> "Mixtral-8x7B-Instruct-v0.1"
        """
        return self.name.split('/')[-1]

    @staticmethod
    def find_first_valid_json(s: str) -> Optional[Dict]:
        """Return the first JSON object embedded in ``s``, or None.

        Backslashes that do not start a valid JSON escape are dropped first.
        Then each ``{`` is tried, left to right, as the start of an object.
        """
        # Keep valid escapes (including an escaped backslash) intact and drop
        # any other lone backslash; a lookahead-only pattern would split "\\".
        s = re.sub(r'\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4})|\\',
                   lambda m: m.group(0) if len(m.group(0)) > 1 else '', s)
        decoder = json.JSONDecoder(strict=False)
        start = s.find('{')
        while start != -1:
            try:
                # raw_decode parses exactly one value starting at `start` and
                # ignores trailing text, so each candidate start is one pass.
                obj, _ = decoder.raw_decode(s, start)
                return obj
            except json.JSONDecodeError:
                start = s.find('{', start + 1)
        return None
class MistralModel(HFAPIModel):
    """Mistral / Mixtral instruct models, prompted via their HF chat template."""

    def __init__(self, system_prompt: str, model_name: str = 'mistralai/Mixtral-8x7B-Instruct-v0.1') -> None:
        # Validate against the shared list so it cannot drift from select_model().
        assert model_name in SUPPORTED_MISTRAL_MODELS, 'Model not supported'
        super().__init__(model_name, system_prompt)

    def format_messages(self) -> str:
        """Render the history with the model's chat template.

        Mistral doesn't support a system prompt, so the system message is
        merged into the first user message, separated by a newline.
        """
        messages = self.messages
        if messages[0]['role'] == ROLE_SYSTEM:
            assert len(messages) >= 2
            merged = {'role': ROLE_USER,
                      'content': messages[0]['content'] + '\n' + messages[1]['content']}
            messages = [merged] + messages[2:]
        # Load the tokenizer once per instance instead of on every request.
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = self._tokenizer = AutoTokenizer.from_pretrained(self.name)
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, max_length=4096)
class NousHermesModel(HFAPIModel):
    """Nous-Hermes-2 models, prompted via their HF chat template."""

    def __init__(self, system_prompt: str, model_name: str = 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO') -> None:
        assert model_name in SUPPORTED_NOUS_MODELS, 'Model not supported'
        super().__init__(model_name, system_prompt)

    def format_messages(self) -> str:
        """Render the history, system prompt included, with the chat template.

        The history must start with a system message followed by a user message.
        """
        messages = self.messages
        assert len(messages) >= 2  # must be at least a system and a user
        assert messages[0]['role'] == ROLE_SYSTEM and messages[1]['role'] == ROLE_USER
        # Load the tokenizer once per instance instead of on every request.
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = self._tokenizer = AutoTokenizer.from_pretrained(self.name)
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, max_length=4096)
class LlamaModel(HFAPIModel):
    """Llama-2 chat models, prompted with a hand-built [INST] template."""

    def __init__(self, system_prompt: str, model_name: str = 'meta-llama/Llama-2-70b-chat-hf') -> None:
        # Validate against the shared list so it cannot drift from select_model().
        assert model_name in SUPPORTED_LLAMA_MODELS, 'Model not supported'
        super().__init__(model_name, system_prompt)

    def format_messages(self) -> str:
        """Render the history in the Llama-2 chat format.

        <s>[INST] <<SYS>>
        {system_prompt}
        <</SYS>>

        {user_message} [/INST]{assistant_reply}</s><s>[INST] {user_message} [/INST]...

        Raises:
            AssertionError: if the history does not start with a system
                message followed by a user message.
            ValueError: on a system message after the first position, or on
                an unknown role.
        """
        messages = self.messages
        assert len(messages) >= 2  # must be at least a system and a user
        assert messages[0]['role'] == ROLE_SYSTEM and messages[1]['role'] == ROLE_USER
        parts = [f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n\n{messages[1]["content"]} [/INST]']
        for msg in messages[2:]:
            role, content = msg['role'], msg['content']
            if role == ROLE_SYSTEM:
                # Only the leading system prompt can be expressed in this format;
                # previously this was a no-op `assert ValueError` that dropped it.
                raise ValueError('System messages are only supported as the first message')
            if role == ROLE_USER:
                # A turn following a completed assistant reply reopens with BOS.
                if parts[-1].endswith('</s>'):
                    parts.append('<s>')
                parts.append(f'[INST] {content} [/INST]')
            elif role == ROLE_ASSISTANT:
                parts.append(f'{content}</s>')
            else:
                raise ValueError(f'Unknown message role: {role!r}')
        return ''.join(parts)