LLM-As-Chatbot / global_vars.py
chansung's picture
update
88f55d9
raw
history blame
6.89 kB
import gc
import yaml
import json
import torch
from transformers import GenerationConfig
from models import alpaca, stablelm, koalpaca, flan_alpaca, mpt
from models import camel, t5_vicuna, vicuna, starchat, redpajama, bloom
from models import baize, guanaco, falcon, kullm, replit, airoboros
from models import samantha_vicuna
from utils import get_chat_interface, get_chat_manager
model_infos = json.load(open("model_cards.json"))
def get_model_type(model_info):
base_url = model_info["hub(base)"]
ft_ckpt_url = model_info["hub(ckpt)"]
model_type_tmp = "alpaca"
if "llms/wizardlm" in base_url.lower():
model_type_tmp = "wizardlm"
elif "chronos" in base_url.lower():
model_type_tmp = "chronos"
elif "lazarus" in base_url.lower():
model_type_tmp = "lazarus"
elif "samantha" in base_url.lower():
model_type_tmp = "samantha-vicuna"
elif "airoboros" in base_url.lower():
model_type_tmp = "airoboros"
elif "replit" in base_url.lower():
model_type_tmp = "replit-instruct"
elif "kullm" in base_url.lower():
model_type_tmp = "kullm-polyglot"
elif "nous-hermes" in base_url.lower():
model_type_tmp = "nous-hermes"
elif "guanaco" in base_url.lower():
model_type_tmp = "guanaco"
elif "wizardlm-uncensored-falcon" in base_url.lower():
model_type_tmp = "wizard-falcon"
elif "falcon" in base_url.lower():
model_type_tmp = "falcon"
elif "baize" in base_url.lower():
model_type_tmp = "baize"
elif "stable-vicuna" in base_url.lower():
model_type_tmp = "stable-vicuna"
elif "vicuna" in base_url.lower():
model_type_tmp = "vicuna"
elif "mpt" in base_url.lower():
model_type_tmp = "mpt"
elif "redpajama-incite-7b-instruct" in base_url.lower():
model_type_tmp = "redpajama-instruct"
elif "redpajama" in base_url.lower():
model_type_tmp = "redpajama"
elif "starchat" in base_url.lower():
model_type_tmp = "starchat"
elif "camel" in base_url.lower():
model_type_tmp = "camel"
elif "flan-alpaca" in base_url.lower():
model_type_tmp = "flan-alpaca"
elif "openassistant/stablelm" in base_url.lower():
model_type_tmp = "os-stablelm"
elif "stablelm" in base_url.lower():
model_type_tmp = "stablelm"
elif "fastchat-t5" in base_url.lower():
model_type_tmp = "t5-vicuna"
elif "koalpaca-polyglot" in base_url.lower():
model_type_tmp = "koalpaca-polyglot"
elif "alpacagpt4" in ft_ckpt_url.lower():
model_type_tmp = "alpaca-gpt4"
elif "alpaca" in ft_ckpt_url.lower():
model_type_tmp = "alpaca"
elif "llama-deus" in ft_ckpt_url.lower():
model_type_tmp = "llama-deus"
elif "vicuna-lora-evolinstruct" in ft_ckpt_url.lower():
model_type_tmp = "evolinstruct-vicuna"
elif "alpacoom" in ft_ckpt_url.lower():
model_type_tmp = "alpacoom"
elif "guanaco" in ft_ckpt_url.lower():
model_type_tmp = "guanaco"
else:
print("unsupported model type")
return model_type_tmp
def initialize_globals():
global models, tokenizers
models = []
model_names = [
"baize-7b",
# "evolinstruct-vicuna-13b",
"guanaco-7b",
# "nous-hermes-13b"
]
for model_name in model_names:
model_info = model_infos[model_name]
model_thumbnail_tiny = model_info["thumb-tiny"]
model_type = get_model_type(model_info)
print(model_type)
load_model = get_load_model(model_type)
model, tokenizer = load_model(
base=model_info["hub(base)"],
finetuned=model_info["hub(ckpt)"],
mode_cpu=False,
mode_mps=False,
mode_full_gpu=True,
mode_8bit=False,
mode_4bit=False,
force_download_ckpt=False
)
gen_config, gen_config_raw = get_generation_config(
model_info["default_gen_config"]
)
models.append(
{
"model_name": model_name,
"model_thumb_tiny": model_thumbnail_tiny,
"model_type": model_type,
"model": model,
"tokenizer": tokenizer,
"gen_config": gen_config,
"gen_config_raw": gen_config_raw,
"chat_interface": get_chat_interface(model_type),
"chat_manager": get_chat_manager(model_type),
}
)
def get_load_model(model_type):
if model_type == "alpaca" or \
model_type == "alpaca-gpt4" or \
model_type == "llama-deus" or \
model_type == "nous-hermes" or \
model_type == "lazarus" or \
model_type == "chronos" or \
model_type == "wizardlm":
return alpaca.load_model
elif model_type == "stablelm" or model_type == "os-stablelm":
return stablelm.load_model
elif model_type == "koalpaca-polyglot":
return koalpaca.load_model
elif model_type == "kullm-polyglot":
return kullm.load_model
elif model_type == "flan-alpaca":
return flan_alpaca.load_model
elif model_type == "camel":
return camel.load_model
elif model_type == "t5-vicuna":
return t5_vicuna.load_model
elif model_type == "stable-vicuna":
return vicuna.load_model
elif model_type == "starchat":
return starchat.load_model
elif model_type == "mpt":
return mpt.load_model
elif model_type == "redpajama" or \
model_type == "redpajama-instruct":
return redpajama.load_model
elif model_type == "vicuna":
return vicuna.load_model
elif model_type == "evolinstruct-vicuna":
return alpaca.load_model
elif model_type == "alpacoom":
return bloom.load_model
elif model_type == "baize":
return baize.load_model
elif model_type == "guanaco":
return guanaco.load_model
elif model_type == "falcon" or model_type == "wizard-falcon":
return falcon.load_model
elif model_type == "replit-instruct":
return replit.load_model
elif model_type == "airoboros":
return airoboros.load_model
elif model_type == "samantha-vicuna":
return samantha_vicuna.load_model
else:
return None
def get_generation_config(path):
with open(path, 'rb') as f:
generation_config = yaml.safe_load(f.read())
generation_config = generation_config["generation_config"]
return GenerationConfig(**generation_config), generation_config
def get_constraints_config(path):
with open(path, 'rb') as f:
constraints_config = yaml.safe_load(f.read())
return ConstraintsConfig(**constraints_config), constraints_config["constraints"]