Spaces:
Runtime error
Runtime error
import json | |
import random | |
import re | |
from collections import defaultdict | |
from typing import Any, List, Dict, Callable, Union, Optional | |
# from vllm import LLM, SamplingParams | |
import regex | |
import numpy as np | |
from huggingface_hub import InferenceClient | |
from tqdm import tqdm | |
# from config import * | |
# Role identifiers used in the `{'role': ..., 'content': ...}` message dicts
# that the *_prompt_formatter functions below consume.
ROLE_SYSTEM = 'system'
ROLE_USER = 'user'
ROLE_ASSISTANT = 'assistant'

# Single-turn chat templates, each with a `{prompt}` placeholder for the user
# prompt. Keys look like Hugging Face organization prefixes — presumably
# matched against model ids by code outside this view (TODO confirm; not
# referenced in the visible code).
CHAT_FORMATS = {
    "mistralai": "<s>[INST] {prompt} [/INST]",
    "openchat": "GPT4 User: {prompt}<|end_of_turn|>GPT4 Assistant:",
    # Llama-2 style with a fixed system instruction.
    "meta-llama": """[INST] <<SYS>>
You answer questions directly.
<</SYS>>
{prompt}[/INST]""",
    # ChatML-style turns with a fixed system message.
    "mosaicml": """<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant""",
    "lmsys": "USER: {prompt}\nASSISTANT:",
}

# System + user templates with `{system_prompt}` and `{user_message}`
# placeholders. NOTE(review): not used by the visible formatters, which build
# their strings directly — presumably consumed elsewhere.
LLAMA_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message} [/INST]"""

# Same as LLAMA_TEMPLATE except the user message follows `<</SYS>>` on the same line.
MISTRAL_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>> {user_message} [/INST]"""

# ChatML-style template ending with an open assistant turn (trailing newline included).
YI_34B_TEMPLATE = """<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
"""
def extract_json(text: str) -> Any:
    """Parse ``text`` as a single JSON value, tolerating stray backslashes.

    Model output often contains unescaped backslashes (e.g. LaTeX such as
    ``\\alpha``) that make JSON parsing fail. The text is therefore first
    parsed with every backslash doubled, so each backslash survives as a
    literal character. Doubling breaks legitimately escaped quotes such as
    ``\\"``, so when that attempt fails the text is parsed as-is.

    Args:
        text: Text expected to be one complete JSON value (e.g. an object, or
            a bare ``true``/``false``).

    Returns:
        The decoded value (a dict for JSON objects, ``True`` for ``true``,
        etc.), or ``None`` if neither attempt parses.
    """
    candidates = (text.replace('\\', '\\\\'), text)
    for candidate in candidates:
        try:
            # strict=False additionally accepts raw control characters
            # (e.g. literal newlines) inside strings.
            return json.loads(candidate, strict=False)
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError; best-effort: try the next form.
            continue
    return None
def mixtral_prompt_formatter(messages: List[Dict[str, str]]) -> str:
    """Render a chat history as a Mistral/Mixtral ``[INST]`` prompt string.

    The first exchange is laid out as::

        <s>[INST] <<SYS>>
        {system_prompt}
        <</SYS>>
        {user_prompt} [/INST]

    Later user turns are appended as ``[INST] {content} [/INST]``, prefixed
    with ``<s>`` when they directly follow an assistant turn. Assistant turns
    are appended as ``{content}</s>``.

    Args:
        messages: Role/content dicts. ``messages[0]`` supplies the system
            prompt and ``messages[1]`` the first user prompt (their roles are
            not checked). All later messages must have the user or assistant role.

    Returns:
        The formatted prompt.

    Raises:
        ValueError: If fewer than two messages are given, or a later message
            has the system role or an unrecognized role.
    """
    if len(messages) < 2:
        raise ValueError('messages must contain at least a system and a user message')
    r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n{messages[1]["content"]} [/INST]'
    for msg in messages[2:]:
        role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            # The system prompt can only be supplied once, as messages[0].
            raise ValueError('a system message is only allowed as the first message')
        elif role == ROLE_USER:
            # A new user turn after a closed assistant turn opens a new sequence.
            if r.endswith('</s>'):
                r += '<s>'
            r += f'[INST] {content} [/INST]'
        elif role == ROLE_ASSISTANT:
            r += f'{content}</s>'
        else:
            raise ValueError(f'unknown message role: {role!r}')
    return r
def llama_prompt_formatter(messages: List[Dict[str, str]]) -> str:
    """Render a chat history as a Llama-2 chat ``[INST]`` prompt string.

    The first exchange is laid out as (note the blank line after ``<</SYS>>``)::

        <s>[INST] <<SYS>>
        {system_prompt}
        <</SYS>>

        {user_message} [/INST]

    Later user turns are appended as ``[INST] {content} [/INST]``, prefixed
    with ``<s>`` when they directly follow an assistant turn. Assistant turns
    are appended as ``{content}</s>``.

    Args:
        messages: Role/content dicts. ``messages[0]`` supplies the system
            prompt and ``messages[1]`` the first user prompt (their roles are
            not checked). All later messages must have the user or assistant role.

    Returns:
        The formatted prompt.

    Raises:
        ValueError: If fewer than two messages are given, or a later message
            has the system role or an unrecognized role.
    """
    if len(messages) < 2:
        raise ValueError('messages must contain at least a system and a user message')
    r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n\n{messages[1]["content"]} [/INST]'
    for msg in messages[2:]:
        role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            # The system prompt can only be supplied once, as messages[0].
            raise ValueError('a system message is only allowed as the first message')
        elif role == ROLE_USER:
            # A new user turn after a closed assistant turn opens a new sequence.
            if r.endswith('</s>'):
                r += '<s>'
            r += f'[INST] {content} [/INST]'
        elif role == ROLE_ASSISTANT:
            r += f'{content}</s>'
        else:
            raise ValueError(f'unknown message role: {role!r}')
    return r
def yi_prompt_formatter(messages: List[Dict[str, str]]) -> str:
    """Render a chat history as a ChatML-style (Yi) prompt string.

    Every turn is written as ``<|im_start|>{role}\\n{content}<|im_end|>\\n``::

        <|im_start|>system
        {system_prompt}<|im_end|>
        <|im_start|>user
        {user_message}<|im_end|>
        <|im_start|>assistant

    If the conversation ends on a user turn, the prompt ends with an open
    ``<|im_start|>assistant\\n`` header so the model generates the reply.
    This includes the plain system + user case.

    Args:
        messages: Role/content dicts. ``messages[0]`` supplies the system
            prompt and ``messages[1]`` the first user prompt (their roles are
            not checked). All later messages must have the user or assistant role.

    Returns:
        The formatted prompt.

    Raises:
        ValueError: If fewer than two messages are given, or a later message
            has the system role or an unrecognized role.
    """
    if len(messages) < 2:
        raise ValueError('messages must contain at least a system and a user message')
    parts = [
        f'<|im_start|>system\n{messages[0]["content"]}<|im_end|>\n',
        f'<|im_start|>user\n{messages[1]["content"]}<|im_end|>\n',
    ]
    # messages[1] is always rendered as a user turn, so a reply is pending.
    awaiting_reply = True
    for msg in messages[2:]:
        role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            # The system prompt can only be supplied once, as messages[0].
            raise ValueError('a system message is only allowed as the first message')
        elif role == ROLE_USER:
            parts.append(f'<|im_start|>user\n{content}<|im_end|>\n')
            awaiting_reply = True
        elif role == ROLE_ASSISTANT:
            parts.append(f'<|im_start|>assistant\n{content}<|im_end|>\n')
            awaiting_reply = False
        else:
            raise ValueError(f'unknown message role: {role!r}')
    if awaiting_reply:
        # Open the assistant turn for generation.
        parts.append('<|im_start|>assistant\n')
    return ''.join(parts)
def find_first_valid_json(s: str) -> Optional[Dict]:
    """Return the first JSON object embedded anywhere in ``s``, or ``None``.

    Backslashes that do not begin a valid JSON escape (``\\"``, ``\\\\``,
    ``\\/``, ``\\b``, ``\\f``, ``\\n``, ``\\r``, ``\\t``, ``\\uXXXX``) are dropped
    first. Escape pairs are consumed left to right, so an escaped backslash
    followed by a letter (``\\\\alpha``) is kept intact. Raw control
    characters inside strings are accepted (``strict=False``).

    Args:
        s: Free-form text, e.g. model output, that may contain a JSON object.

    Returns:
        The dict decoded from the earliest ``{`` at which a complete JSON
        object starts, or ``None`` if there is none.
    """
    # Group 1 matches the character(s) completing a valid escape. A lone
    # backslash (group 1 unmatched) is removed.
    s = re.sub(r'\\(["\\/bfnrt]|u[0-9a-fA-F]{4})?',
               lambda m: m.group(0) if m.group(1) is not None else '',
               s)
    decoder = json.JSONDecoder(strict=False)
    start = s.find('{')
    while start != -1:
        try:
            # raw_decode parses exactly one value and ignores trailing text,
            # so it finds the object's matching '}' without trying every end.
            obj, _ = decoder.raw_decode(s[start:])
            return obj
        except json.JSONDecodeError:
            start = s.find('{', start + 1)
    return None
class HFAPIModel:
    """Stateful chat wrapper around a Hugging Face ``InferenceClient`` text-generation model.

    The conversation is kept in ``messages`` (system prompt first) and is
    rendered to one prompt string by a model-family formatter before every
    request.
    """

    model_name: str  # model id passed to InferenceClient
    messages: List[Dict[str, str]]  # conversation history, oldest first
    system_prompt: str  # content of the initial system message
    formatter: Callable[[List[Dict[str, str]]], str]  # renders `messages` to a prompt string

    def __init__(self, system_prompt: str, model_name: str) -> None:
        """Create a wrapper whose history starts with ``system_prompt``.

        Args:
            system_prompt: Content of the initial system message.
            model_name: Model id. The prompt formatter is chosen by a
                case-sensitive substring check: ``'llama'`` selects the Llama
                format, ``'mistral'`` the Mistral/Mixtral format.

        Raises:
            NotImplementedError: If no formatter matches ``model_name``.
        """
        self.system_prompt = system_prompt
        self.model_name = model_name
        if 'llama' in model_name:
            self.formatter = llama_prompt_formatter
        elif 'mistral' in model_name:
            self.formatter = mixtral_prompt_formatter
        else:
            raise NotImplementedError(f'no prompt formatter for model {model_name!r}')
        self.messages = [
            {'role': ROLE_SYSTEM, 'content': system_prompt}
        ]

    def __call__(self, user_prompt: str, use_json: bool = False,
                 temperature: float = 0, timeout: Optional[float] = None,
                 cache: bool = True, max_new_tokens: int = 512) -> Union[str, Dict]:
        """Send one user turn, record the exchange in the history, and return the reply.

        Args:
            user_prompt: Content of the new user message.
            use_json: If True, try to return the first JSON object in the reply.
            temperature: Sampling temperature; ``0`` or below means greedy decoding.
            timeout: Forwarded to ``InferenceClient``.
            cache: If False, sends the ``x-use-cache: 0`` request header.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The reply string, or a dict when ``use_json`` is set and a JSON
            object could be parsed from it.
        """
        self.add_message(ROLE_USER, user_prompt)
        try:
            response = self.get_response(temperature, use_json, timeout, cache,
                                         max_new_tokens=max_new_tokens)
        except BaseException:
            # Roll back the unanswered user turn so a retry does not leave two
            # consecutive user messages in the history.
            self.messages.pop()
            raise
        # The formatters interpolate content with f-strings, which would render
        # a dict as a Python repr; keep parsed replies as JSON text instead.
        content = json.dumps(response, ensure_ascii=False) if isinstance(response, dict) else response
        self.add_message(ROLE_ASSISTANT, content)
        return response

    def get_response(self, temperature: float, use_json: bool, timeout: Optional[float],
                     cache: bool, max_new_tokens: int = 512) -> Union[str, Dict]:
        """Query the model with the current history and return its reply.

        If ``use_json`` is True, the first JSON object found in the reply is
        returned as a dict. If none can be parsed, the reply string is
        returned directly. The history is not modified.

        Args:
            temperature: Sampling temperature; ``0`` or below disables sampling.
            use_json: Whether to try to extract a JSON object from the reply.
            timeout: Forwarded to ``InferenceClient``.
            cache: If False, sets the ``x-use-cache: 0`` header (presumably
                bypasses the API's response cache — confirm against HF docs).
            max_new_tokens: Upper bound on generated tokens.
        """
        client = InferenceClient(self.model_name, timeout=timeout)
        if not cache:
            client.headers["x-use-cache"] = "0"
        sampling = temperature > 0
        r = client.text_generation(self.formatter(self.messages),
                                   do_sample=sampling,
                                   # Greedy decoding takes no temperature.
                                   temperature=temperature if sampling else None,
                                   max_new_tokens=max_new_tokens)
        if use_json:
            obj = find_first_valid_json(r)
            if obj is not None:
                return obj
        return r

    def add_message(self, role: str, message: str) -> None:
        """Append a ``{'role': role, 'content': message}`` entry to the history."""
        self.messages.append({'role': role, 'content': message})
if __name__ == '__main__':
    # Manual smoke test: a single chat turn against a hosted Mixtral model
    # (performs a network request through HFAPIModel).
    model = HFAPIModel(
        system_prompt='You are a helpful assistant.',
        model_name='mistralai/Mixtral-8x7B-Instruct-v0.1',
    )
    reply = model('Who are you?')
    print(reply)