import json
import re
from typing import List, Dict, Callable, Union, Optional

from huggingface_hub import InferenceClient

# from vllm import LLM, SamplingParams
# from config import *

ROLE_SYSTEM = 'system'
ROLE_USER = 'user'
ROLE_ASSISTANT = 'assistant'

CHAT_FORMATS = {
    "mistralai": "[INST] {prompt} [/INST]",
    "openchat": "GPT4 User: {prompt}<|end_of_turn|>GPT4 Assistant:",
    "meta-llama": """[INST] <<SYS>>
You answer questions directly.
<</SYS>>
{prompt}[/INST]""",
    "mosaicml": """<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant""",
    "lmsys": "USER: {prompt}\nASSISTANT:",
}

LLAMA_TEMPLATE = """[INST] <<SYS>>
{system_prompt}
<</SYS>>

{user_message} [/INST]"""

MISTRAL_TEMPLATE = """[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message} [/INST]"""

YI_34B_TEMPLATE = """<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
"""


def extract_json(text: str) -> Optional[Dict]:
    # Assume the payload is going to look like: "Guess": "A" or "Guess": "B"
    # (now it's either true or false).
    text = text.replace('\\', '\\\\')  # escape backslashes so json.loads does not choke
    try:
        rslt = json.loads(text)
    except json.JSONDecodeError:
        rslt = None
    return rslt


def mixtral_prompt_formatter(messages: List[Dict[str, str]]) -> str:
    """
    [INST] <<SYS>>
    {system_prompt}
    <</SYS>>
    {user_prompt} [/INST]
    """
    assert len(messages) >= 2  # must contain at least a system and a user message
    r = f'[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n{messages[1]["content"]} [/INST]'
    for msg in messages[2:]:
        role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            raise ValueError('only one system message is allowed')
        elif role == ROLE_USER:
            # Start a new segment if the previous assistant turn was closed with </s>.
            if r.endswith('</s>'):
                r += '<s>'
            r += f'[INST] {content} [/INST]'
        elif role == ROLE_ASSISTANT:
            r += f'{content}</s>'
        else:
            raise ValueError(f'unknown role: {role}')
    return r


def llama_prompt_formatter(messages: List[Dict[str, str]]) -> str:
    """
    [INST] <<SYS>>
    {system_prompt}
    <</SYS>>

    {user_message} [/INST]
    """
    assert len(messages) >= 2  # must contain at least a system and a user message
    r = f'[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n\n{messages[1]["content"]} [/INST]'
    for msg in messages[2:]:
        role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            raise ValueError('only one system message is allowed')
        elif role == ROLE_USER:
            # Start a new segment if the previous assistant turn was closed with </s>.
            if r.endswith('</s>'):
                r += '<s>'
            r += f'[INST] {content} [/INST]'
        elif role == ROLE_ASSISTANT:
            r += f'{content}</s>'
        else:
            raise ValueError(f'unknown role: {role}')
    return r


def yi_prompt_formatter(messages: List[Dict[str, str]]) -> str:
    """
    <|im_start|>system
    {system_prompt}<|im_end|>
    <|im_start|>user
    {user_message}<|im_end|>
    <|im_start|>assistant
    """
    assert len(messages) >= 2  # must contain at least a system and a user message
    r = f'<|im_start|>system\n{messages[0]["content"]}<|im_end|>\n<|im_start|>user\n{messages[1]["content"]}<|im_end|>\n'
    if len(messages) == 2:
        # The conversation ends with a user turn, so append the generation prompt.
        return r + '<|im_start|>assistant\n'
    for i in range(2, len(messages)):
        msg = messages[i]
        role, content = msg['role'], msg['content']
        if role == ROLE_SYSTEM:
            raise ValueError('only one system message is allowed')
        elif role == ROLE_USER:
            r += f'<|im_start|>user\n{content}<|im_end|>\n'
            if i == len(messages) - 1:
                r += '<|im_start|>assistant\n'
        elif role == ROLE_ASSISTANT:
            r += f'<|im_start|>assistant\n{content}<|im_end|>\n'
        else:
            raise ValueError(f'unknown role: {role}')
    return r


def find_first_valid_json(s: str) -> Optional[Dict]:
    # Drop backslashes that do not start a valid JSON escape sequence.
    s = re.sub(r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', lambda m: m.group(0)[1:], s)
    # Brute-force scan: try every {...} substring until one parses.
    for i in range(len(s)):
        if s[i] != '{':
            continue
        for j in range(i + 1, len(s) + 1):
            if s[j - 1] != '}':
                continue
            try:
                potential_json = s[i:j]
                json_obj = json.loads(potential_json, strict=False)
                return json_obj  # return the first valid JSON object found
            except json.JSONDecodeError:
                pass  # continue searching if JSON decoding fails
    return None  # no valid JSON object found


class HFAPIModel:
    model_name: str
    messages: List[Dict[str, str]]
    system_prompt: str
    formatter: Callable[[List[Dict[str, str]]], str]

    def __init__(self, system_prompt: str, model_name: str) -> None:
        self.system_prompt = system_prompt
        self.model_name = model_name
        if 'llama' in model_name:
            self.formatter = llama_prompt_formatter
        elif 'mistral' in model_name:
            self.formatter = mixtral_prompt_formatter
        else:
            raise NotImplementedError
        self.messages = [
            {'role': ROLE_SYSTEM, 'content': system_prompt}
        ]

    def __call__(self,
                 user_prompt: str,
                 use_json: bool = False,
                 temperature: float = 0,
                 timeout: Optional[float] = None,
                 cache: bool = True) -> Union[str, Dict]:
        self.add_message(ROLE_USER, user_prompt)
        response = self.get_response(temperature, use_json, timeout, cache)
        self.add_message(ROLE_ASSISTANT, response)
        return response

    def get_response(self, temperature: float, use_json: bool,
                     timeout: Optional[float], cache: bool) -> Union[str, Dict]:
        """
        Returns the model's response. If use_json is True, tries its best to return
        a JSON dict, but this is not guaranteed. If the JSON cannot be parsed, the
        raw response string is returned instead.
        """
        client = InferenceClient(self.model_name, timeout=timeout)
        if not cache:
            client.headers["x-use-cache"] = "0"
        # print(self.formatter(self.messages))  # debug
        r = client.text_generation(self.formatter(self.messages),
                                   do_sample=temperature > 0,
                                   temperature=temperature if temperature > 0 else None,
                                   max_new_tokens=512)
        if use_json:
            obj = find_first_valid_json(r)
            if obj is not None:
                return obj
        return r

    def add_message(self, role: str, message: str):
        self.messages.append({'role': role, 'content': message})


if __name__ == '__main__':
    # model = GPTModel(system_prompt='You are an AI developed by OpenAI.', model_name=GPT_4_MODEL_NAME)
    model = HFAPIModel(system_prompt='You are a helpful assistant.',
                       model_name='mistralai/Mixtral-8x7B-Instruct-v0.1')
    print(model('Who are you?'))
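
    # A minimal sketch of the JSON path (illustrative, not part of the original
    # script): with use_json=True, get_response() runs find_first_valid_json()
    # over the completion and returns a dict when a parseable object is found.
    # The system and user prompts below are assumptions chosen to elicit JSON,
    # echoing the '"Guess": "A" or "Guess": "B"' shape mentioned in extract_json.
    judge = HFAPIModel(system_prompt='You answer in JSON only.',
                       model_name='mistralai/Mixtral-8x7B-Instruct-v0.1')
    guess = judge('Reply with {"Guess": "A"} or {"Guess": "B"}.', use_json=True)
    print(guess)  # a dict if parsing succeeded, otherwise the raw response string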