chatbot / tests /utils.py
kelvin-t-lu's picture
init
dbd2ac6
import hashlib
import os
import sys
import shutil
from functools import wraps, partial
import pytest
# Make the project's ``src`` directory importable for the imports that follow.
# The guard must look for 'src' itself: os.path.dirname('src') evaluates to ''
# (no directory component), which would test for the empty string instead and
# could either skip the append or add 'src' again on every re-execution.
if 'src' not in sys.path:
    sys.path.append('src')

# Set before the project modules below are imported.
# NOTE(review): presumably read by project code to enable strict assertions - confirm.
os.environ['HARD_ASSERTS'] = "1"
from src.utils import call_subprocess_onetask, makedirs, FakeTokenizer
def get_inf_port():
    """
    Resolve the port of the inference server used by the tests.
    Takes the text after the last ':' of HOST when HOST is set, otherwise
    GRADIO_SERVER_PORT when set, otherwise falls back to 7860.
    :return: port number as int
    """
    host = os.getenv('HOST')
    if host is not None:
        # e.g. "http://localhost:7860" -> "7860"
        return int(host.split(':')[-1])

    gradio_port = os.getenv('GRADIO_SERVER_PORT')
    if gradio_port is not None:
        return int(gradio_port)

    return 7860
def get_inf_server():
    """
    Resolve the address of the inference server used by the tests.
    HOST wins when set; otherwise a localhost URL is built from GRADIO_SERVER_PORT.
    :return: server address string
    :raises ValueError: if neither HOST nor GRADIO_SERVER_PORT is set
    """
    host = os.getenv('HOST')
    if host is not None:
        return host

    gradio_port = os.getenv('GRADIO_SERVER_PORT')
    if gradio_port is None:
        raise ValueError("Expect tests to set HOST or GRADIO_SERVER_PORT")

    return "http://localhost:%s" % gradio_port
def get_mods():
    """
    Read the test-sharding settings from the environment.
    :return: tuple (TESTMODULOTOTAL as int, default 1; TESTMODULO as int, default 0)
    """
    env = os.environ
    total = int(env.get('TESTMODULOTOTAL', '1'))
    index = int(env.get('TESTMODULO', '0'))
    return total, index
def do_skip_test(name):
    """
    Decide whether the named test is skipped in this run.
    The hex digest of the name is reduced modulo the configured total and compared
    with the configured index, so the decision is deterministic for a given name.
    Note that skipping all tests does not fail; doing no tests is what fails.
    :param name: test name to hash
    :return: True if the test should be skipped
    """
    total, index = get_mods()
    bucket = int(get_sha(name), 16) % total
    return bucket != index
def wrap_test_forked(func):
    """Decorate a test function so that its body is called in a subprocess"""

    @wraps(func)
    def forked(*args, **kwargs):
        # Fill in defaults without overriding existing values, so server port
        # or host can be controlled globally for all tests.
        port = int(os.environ.setdefault('GRADIO_SERVER_PORT', str(7860)))
        # (a per-shard port offset based on get_mods() is currently disabled)
        os.environ.setdefault('HOST', "http://localhost:%s" % port)

        test_name = get_test_name()
        if do_skip_test(test_name):
            # Skipping is based on raw name, so deterministic
            pytest.skip("[%s] TEST SKIPPED due to TESTMODULO" % test_name)

        in_subprocess = partial(call_subprocess_onetask, func, args, kwargs)
        return run_test(in_subprocess)

    return forked
def run_test(func, *args, **kwargs):
    """
    Call ``func`` with the given positional and keyword arguments.
    :param func: callable to invoke
    :return: whatever ``func`` returns
    """
    outcome = func(*args, **kwargs)
    return outcome
def get_sha(value):
    """
    Hash the string form of ``value``.
    :param value: any object; it is converted with str() and encoded as UTF-8
    :return: 32-character hex MD5 digest
    """
    encoded = str(value).encode('utf-8')
    digest = hashlib.md5(encoded)
    return digest.hexdigest()
def sanitize_filename(name):
    """
    Sanitize file *base* names. Also used to generate valid class names.
    Each listed character sequence is replaced by "_"; a result longer than 250
    characters is shortened to its head and tail joined by the digest of the
    full sanitized name.
    :param name: raw name
    :return: sanitized name
    """
    for bad in ('[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^'):
        name = name.replace(bad, "_")

    max_length = 250  # bit smaller than 256 for safety
    total = len(name)
    if total <= max_length:
        return name

    # Too long: keep equal-sized head and tail around "_<digest>_"
    digest_length = 32
    keep = max(1, (max_length - (digest_length + 2)) // 2)
    digest = get_sha(name)
    return name[:keep] + "_" + digest + "_" + name[total - keep:]
def get_test_name():
    """
    Build a sanitized name for the currently running test from PYTEST_CURRENT_TEST.
    :return: sanitized test name
    """
    current = os.environ['PYTEST_CURRENT_TEST']
    # keep only what follows the last ':'
    after_colon = current.rpartition(':')[2]
    words = after_colon.split(' ')
    # skip (call) at end
    return sanitize_filename("_".join(words[:-1]))
def make_user_path_test():
    """
    Reset the 'user_path_test' and 'db_dir_UserData' directories and copy a few
    sample documents into the user directory. Sample paths are relative, so this
    expects to run from the directory that contains them.
    :return: path of the user directory, as returned by makedirs
    """
    user_dir_name = 'user_path_test'
    # makedirs is called first only to resolve the path; anything left there is then removed
    user_path = makedirs(user_dir_name, use_base=True)
    if os.path.isdir(user_path):
        shutil.rmtree(user_path)
    user_path = makedirs(user_dir_name, use_base=True)

    db_dir = makedirs("db_dir_UserData", use_base=True)
    if os.path.isdir(db_dir):
        shutil.rmtree(db_dir)
    # re-created from the already resolved path
    makedirs(db_dir, use_base=True)

    for sample in ('data/pexels-evg-kowalievska-1170986_small.jpg', 'README.md', 'docs/FAQ.md'):
        shutil.copy(sample, user_path)
    return user_path
def get_llama(llama_type=2):
    """
    Make sure a GGML llama model file is present locally, downloading it from the
    'h2oai/ggml' Hugging Face repo when the file is missing.
    huggingface_hub is only imported when a download is actually needed, so
    argument validation and the already-downloaded case work without it.
    :param llama_type: 1 for 'ggml-model-q4_0_7b.bin' (stored under 'models/7B/'),
                       2 for 'WizardLM-7B-uncensored.ggmlv3.q8_0.bin' (stored under './')
    :return: tuple (prompt_type, full_path of the model file)
    :raises ValueError: if llama_type is neither 1 nor 2
    """
    # FIXME: Pass into main()
    if llama_type == 1:
        file_name = 'ggml-model-q4_0_7b.bin'
        dest = 'models/7B/'
        prompt_type = 'plain'
    elif llama_type == 2:
        file_name = 'WizardLM-7B-uncensored.ggmlv3.q8_0.bin'
        dest = './'
        prompt_type = 'wizard2'
    else:
        raise ValueError("unknown llama_type=%s" % llama_type)

    makedirs(dest, exist_ok=True)
    full_path = os.path.join(dest, file_name)
    if not os.path.isfile(full_path):
        # Deferred import: only required when the model has to be fetched
        from huggingface_hub import hf_hub_download

        # True for case when locally already logged in with correct token, so don't have to set key
        token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
        out_path = hf_hub_download('h2oai/ggml', file_name, token=token, repo_type='model')
        # out_path will look like '/home/jon/.cache/huggingface/hub/models--h2oai--ggml/snapshots/57e79c71bb0cee07e3e3ffdea507105cd669fa96/ggml-model-q4_0_7b.bin'
        shutil.copy(out_path, dest)
    return prompt_type, full_path
def kill_weaviate(db_type):
    """
    weaviate launches a detached server, which accumulates entries in db, but we want to start freshly.
    Does nothing unless db_type is 'weaviate'.
    :param db_type: name of the database backend in use
    """
    if db_type != 'weaviate':
        return
    # SIGKILL any process whose command line matches the embedded weaviate binary
    os.system('pkill --signal 9 -f weaviate-embedded/weaviate')
def count_tokens_llm(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b', tokenizer=None):
    """
    Count the tokens of ``prompt`` with an LLM tokenizer and print the count and elapsed time.
    :param prompt: text to tokenize
    :param base_model: model name used to load a tokenizer when none is passed
    :param tokenizer: optional callable returning a mapping with an 'input_ids' entry
    :return: dict with key 'llm' holding the token count
    """
    import time

    if tokenizer is None:
        assert base_model is not None
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(base_model)

    started = time.time()
    encoded = tokenizer(prompt)
    num_tokens = len(encoded['input_ids'])
    print('llm: ', num_tokens, time.time() - started)
    return {'llm': num_tokens}
def count_tokens(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b'):
    """
    Count the tokens of ``prompt`` with several tokenizers, printing each count
    together with the time the counting took.
    :param prompt: text to tokenize
    :param base_model: model name whose tokenizer provides the 'llm' count
    :return: dict with keys 'reg', 'nltk', 'tiktoken', 'llm' and 'instructor'
    """
    # Quick estimate first; it is only printed, not returned
    fake_count = FakeTokenizer().num_tokens_from_string(prompt)
    print(fake_count)

    from transformers import AutoTokenizer
    gpt2_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model_tokenizer = AutoTokenizer.from_pretrained(base_model)

    from InstructorEmbedding import INSTRUCTOR
    instructor = INSTRUCTOR('hkunlp/instructor-large')

    import nltk
    import re
    import time

    word_pattern = re.compile(r'\w+')

    # (printed label, result key, counting function), evaluated in this order
    measurements = (
        ('reg: ', 'reg', lambda text: len(word_pattern.findall(text))),
        ('nltk: ', 'nltk', lambda text: len(nltk.word_tokenize(text))),
        # this entry is labelled 'tiktoken' but counts with the distilgpt2 tokenizer
        ('tiktoken: ', 'tiktoken', lambda text: len(gpt2_tokenizer(text)['input_ids'])),
        ('llm: ', 'llm', lambda text: len(model_tokenizer(text)['input_ids'])),
        ('instructor-large: ', 'instructor', lambda text: instructor.tokenize([text])['input_ids'].shape[1]),
    )

    counts = {}
    for label, key, measure in measurements:
        started = time.time()
        total = measure(prompt)
        print(label, total, time.time() - started)
        counts[key] = total
    return counts