# LLM-model-cards / models.py
import json
import re
from typing import Callable, Dict, List, Optional, Union

from huggingface_hub import InferenceClient
ROLE_SYSTEM = 'system'
ROLE_USER = 'user'
ROLE_ASSISTANT = 'assistant'
CHAT_FORMATS = {
"mistralai": "<s>[INST] {prompt} [/INST]",
"openchat": "GPT4 User: {prompt}<|end_of_turn|>GPT4 Assistant:",
"meta-llama": """[INST] <<SYS>>
You answer questions directly.
<</SYS>>
{prompt}[/INST]""",
"mosaicml": """<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant""",
"lmsys": "USER: {prompt}\nASSISTANT:",
}
LLAMA_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message} [/INST]"""
MISTRAL_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>> {user_message} [/INST]"""
YI_34B_TEMPLATE = """<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
"""
def extract_json(text: str) -> Optional[Dict]:
    """Parse `text` as a JSON object, tolerating stray backslashes.

    The payload is expected to look like {"Guess": "A"} / {"Guess": "B"},
    or, in the current format, {"Guess": true} / {"Guess": false}.
    Returns None when the text cannot be parsed.
    """
    # Double every backslash so escapes JSON would reject (e.g. '\d') survive.
    text = text.replace('\\', '\\\\')
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None
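# A minimal sketch (illustrative only; not called anywhere in this module):
# expected behaviour on the '"Guess": ...' payloads described above.
def _demo_extract_json() -> None:
    assert extract_json('{"Guess": "A"}') == {'Guess': 'A'}
    assert extract_json('{"Guess": true}') == {'Guess': True}
    assert extract_json('not json at all') is None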
def mixtral_prompt_formatter(messages: List[Dict[str, str]]) -> str:
"""
<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_prompt} [/INST]
"""
assert len(messages) >= 2 # must be at least a system and a user
r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n{messages[1]["content"]} [/INST]'
for msg in messages[2:]:
role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            raise ValueError('only the first message may be a system prompt')
elif role == ROLE_USER:
if r.endswith('</s>'):
r += '<s>'
r += f'[INST] {content} [/INST]'
elif role == ROLE_ASSISTANT:
r += f'{content}</s>'
else:
raise ValueError
return r
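# A minimal sketch (illustrative only; not called anywhere in this module):
# what the formatter produces for a three-turn exchange. A follow-up user
# turn reopens a <s>[INST] ... [/INST] block after the closed assistant turn.
def _demo_mixtral_format() -> None:
    messages = [
        {'role': ROLE_SYSTEM, 'content': 'Be terse.'},
        {'role': ROLE_USER, 'content': 'Hi'},
        {'role': ROLE_ASSISTANT, 'content': 'Hello.'},
        {'role': ROLE_USER, 'content': 'Bye'},
    ]
    expected = ('<s>[INST] <<SYS>>\nBe terse.\n<</SYS>>\nHi [/INST]'
                'Hello.</s><s>[INST] Bye [/INST]')
    assert mixtral_prompt_formatter(messages) == expected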
def llama_prompt_formatter(messages: List[Dict[str, str]]) -> str:
"""
<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message} [/INST]
"""
assert len(messages) >= 2 # must be at least a system and a user
r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n\n{messages[1]["content"]} [/INST]'
for msg in messages[2:]:
role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            raise ValueError('only the first message may be a system prompt')
elif role == ROLE_USER:
if r.endswith('</s>'):
r += '<s>'
r += f'[INST] {content} [/INST]'
elif role == ROLE_ASSISTANT:
r += f'{content}</s>'
else:
raise ValueError
return r
def yi_prompt_formatter(messages: List[Dict[str, str]]) -> str:
"""
<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
"""
    assert len(messages) >= 2  # must be at least a system and a user
    r = f'<|im_start|>system\n{messages[0]["content"]}<|im_end|>\n<|im_start|>user\n{messages[1]["content"]}<|im_end|>\n'
    for msg in messages[2:]:
        role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            raise ValueError('only the first message may be a system prompt')
        elif role == ROLE_USER:
            r += f'<|im_start|>user\n{content}<|im_end|>\n'
        elif role == ROLE_ASSISTANT:
            r += f'<|im_start|>assistant\n{content}<|im_end|>\n'
        else:
            raise ValueError
    # Open the assistant turn so generation continues from here; this also
    # covers the minimal system+user conversation, where the loop never runs.
    if messages[-1]['role'] == ROLE_USER:
        r += '<|im_start|>assistant\n'
    return r
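# A minimal sketch (illustrative only; not called anywhere in this module):
# for the smallest valid conversation, the output ends by opening the
# assistant turn so the model generates from there.
def _demo_yi_format() -> None:
    messages = [
        {'role': ROLE_SYSTEM, 'content': 'Be terse.'},
        {'role': ROLE_USER, 'content': 'Hi'},
    ]
    expected = ('<|im_start|>system\nBe terse.<|im_end|>\n'
                '<|im_start|>user\nHi<|im_end|>\n'
                '<|im_start|>assistant\n')
    assert yi_prompt_formatter(messages) == expected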
def find_first_valid_json(s: str) -> Optional[Dict]:
    """Return the first substring of `s` that parses as a JSON object."""
    # Drop the backslash from any escape sequence JSON does not recognise
    # (e.g. '\d' -> 'd') so json.loads does not reject the candidate string.
    s = re.sub(r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', lambda m: m.group(0)[1:], s)
for i in range(len(s)):
if s[i] != '{':
continue
for j in range(i + 1, len(s) + 1):
if s[j - 1] != '}':
continue
try:
potential_json = s[i:j]
json_obj = json.loads(potential_json, strict=False)
return json_obj # Return the first valid JSON object found
except json.JSONDecodeError:
pass # Continue searching if JSON decoding fails
return None # Return None if no valid JSON object is found
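# A minimal sketch (illustrative only; not called anywhere in this module):
# the scanner pulls the first parseable object out of surrounding prose.
def _demo_find_first_valid_json() -> None:
    text = 'Sure! Here is my answer: {"Guess": "A"} Hope that helps.'
    assert find_first_valid_json(text) == {'Guess': 'A'}
    assert find_first_valid_json('no braces here') is None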
class HFAPIModel:
model_name: str
messages: List[Dict[str, str]]
system_prompt: str
formatter: Callable[[List[Dict[str, str]]], str]
def __init__(self, system_prompt: str, model_name: str) -> None:
self.system_prompt = system_prompt
self.model_name = model_name
if 'llama' in model_name:
self.formatter = llama_prompt_formatter
elif 'mistral' in model_name:
self.formatter = mixtral_prompt_formatter
else:
raise NotImplementedError
self.messages = [
{'role': ROLE_SYSTEM, 'content': system_prompt}
]
    def __call__(self, user_prompt: str, use_json: bool = False,
                 temperature: float = 0, timeout: Optional[float] = None,
                 cache: bool = True) -> Union[str, Dict]:
self.add_message(ROLE_USER, user_prompt)
response = self.get_response(temperature, use_json, timeout, cache)
self.add_message(ROLE_ASSISTANT, response)
return response
    def get_response(self, temperature: float, use_json: bool, timeout: Optional[float], cache: bool) -> Union[str, Dict]:
        """
        Return the model's response.

        If use_json is True, tries to parse a JSON dict out of the generation;
        this is best-effort, and the raw response string is returned whenever
        no valid JSON object can be found.
        """
client = InferenceClient(self.model_name, timeout=timeout)
if not cache:
client.headers["x-use-cache"] = "0"
# print(self.formatter(self.messages)) # debug
r = client.text_generation(self.formatter(self.messages),
do_sample=temperature > 0,
temperature=temperature if temperature > 0 else None,
max_new_tokens=512)
if use_json:
obj = find_first_valid_json(r)
if obj is not None:
return obj
return r
def add_message(self, role: str, message: str):
self.messages.append({'role': role, 'content': message})
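# A minimal usage sketch (illustrative only; requires network access to the
# Hugging Face Inference API; the system prompt and model name are the same
# ones used in __main__ below):
def _demo_hf_api_model() -> None:
    model = HFAPIModel(system_prompt='You are a helpful assistant.',
                       model_name='mistralai/Mixtral-8x7B-Instruct-v0.1')
    # With use_json=True the reply is a dict when a JSON object can be parsed
    # out of the generation, and the raw string otherwise.
    reply = model('Answer {"Guess": true} or {"Guess": false}: is the sky blue?',
                  use_json=True)
    print(reply)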
if __name__ == '__main__':
# model = GPTModel(system_prompt='You are an AI developed by OpenAI.', model_name=GPT_4_MODEL_NAME)
model = HFAPIModel(system_prompt='You are a helpful assistant.', model_name='mistralai/Mixtral-8x7B-Instruct-v0.1')
print(model('Who are you?'))