import hashlib
import os
import sys
import shutil
from functools import wraps, partial

import pytest

# ensure the repo's src package is importable when running tests from the repo root
if 'src' not in sys.path:
    sys.path.append('src')

os.environ['HARD_ASSERTS'] = "1"

from src.utils import call_subprocess_onetask, makedirs, FakeTokenizer


def get_inf_port():
    if os.getenv('HOST') is not None:
        inf_port = os.environ['HOST'].split(':')[-1]
    elif os.getenv('GRADIO_SERVER_PORT') is not None:
        inf_port = os.environ['GRADIO_SERVER_PORT']
    else:
        inf_port = str(7860)
    return int(inf_port)


def get_inf_server():
    if os.getenv('HOST') is not None:
        inf_server = os.environ['HOST']
    elif os.getenv('GRADIO_SERVER_PORT') is not None:
        inf_server = "http://localhost:%s" % os.environ['GRADIO_SERVER_PORT']
    else:
        raise ValueError("Expect tests to set HOST or GRADIO_SERVER_PORT")
    return inf_server
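

# Hedged usage sketch: how the two helpers above read the environment.  The
# values shown are illustrative assumptions, not defaults set by this file
# (only 7860 is a real fallback, in get_inf_port()).
#   HOST=http://localhost:7860            -> get_inf_server() == 'http://localhost:7860'
#                                            get_inf_port()   == 7860
#   GRADIO_SERVER_PORT=7861 (HOST unset)  -> get_inf_server() == 'http://localhost:7861'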


def get_mods():
    testtotalmod = int(os.getenv('TESTMODULOTOTAL', '1'))
    testmod = int(os.getenv('TESTMODULO', '0'))
    return testtotalmod, testmod


def do_skip_test(name):
    """
    Control whether to skip a test.  Note that skipping all tests does not fail;
    running no tests at all is what fails.
    :param name: raw test name, hashed to pick a shard deterministically
    :return: True if this test should be skipped for the current shard
    """
    testtotalmod, testmod = get_mods()
    return int(get_sha(name), 16) % testtotalmod != testmod
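

# Hedged sketch of the sharding scheme do_skip_test() implements: each worker
# runs only tests whose name-hash falls in its modulo bucket.  The shell lines
# below are illustrative; 'tests/' is an assumed test directory.
#   TESTMODULOTOTAL=2 TESTMODULO=0 pytest tests/   # worker 0: ~half the tests
#   TESTMODULOTOTAL=2 TESTMODULO=1 pytest tests/   # worker 1: the other half
# Because get_sha() is a deterministic md5 of the test name, the same test
# always lands on the same worker for a given TESTMODULOTOTAL.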


def wrap_test_forked(func):
    """Decorate a function to test, call in subprocess"""

    @wraps(func)
    def f(*args, **kwargs):
        # set defaults if not already set, so server port or host can be controlled globally for all tests
        gradio_port = os.environ['GRADIO_SERVER_PORT'] = os.getenv('GRADIO_SERVER_PORT', str(7860))
        gradio_port = int(gradio_port)
        # testtotalmod, testmod = get_mods()
        # gradio_port += testmod
        os.environ['HOST'] = os.getenv('HOST', "http://localhost:%s" % gradio_port)
        pytest_name = get_test_name()
        if do_skip_test(pytest_name):
            # Skipping is based on raw name, so deterministic
            pytest.skip("[%s] TEST SKIPPED due to TESTMODULO" % pytest_name)
        func_new = partial(call_subprocess_onetask, func, args, kwargs)
        return run_test(func_new)

    return f
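

# Hedged usage sketch for wrap_test_forked: run the test body in a forked
# subprocess so global state (e.g. CUDA context, gradio servers) cannot leak
# across tests.  test_example and its assertion are hypothetical.
#   @wrap_test_forked
#   def test_example():
#       assert get_inf_port() == 7860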


def run_test(func, *args, **kwargs):
    return func(*args, **kwargs)


def get_sha(value):
    return hashlib.md5(str(value).encode('utf-8')).hexdigest()


def sanitize_filename(name):
    """
    Sanitize file *base* names.  Also used to generate valid class names.
    :param name:
    :return:
    """
    # note: '\\w' and '\\s' are literal two-char strings, unreachable once '\\' itself is replaced
    bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^']
    for char in bad_chars:
        name = name.replace(char, "_")
    length = len(name)
    file_length_limit = 250  # bit smaller than 256 for safety
    sha_length = 32
    real_length_limit = file_length_limit - (sha_length + 2)
    if length > file_length_limit:
        sha = get_sha(name)
        half_real_length_limit = max(1, int(real_length_limit / 2))
        name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length]
    return name
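

# Illustrative behavior of sanitize_filename (inputs are made up):
#   sanitize_filename('test_foo[case 1]')  -> 'test_foo_case_1_'
#   sanitize_filename('a/b:c')             -> 'a_b_c'
# Names longer than 250 chars are truncated around an embedded md5 so the
# result stays unique yet under common filesystem name limits.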


def get_test_name():
    tn = os.environ['PYTEST_CURRENT_TEST'].split(':')[-1]
    tn = "_".join(tn.split(' ')[:-1])  # skip (call) at end
    return sanitize_filename(tn)
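

# PYTEST_CURRENT_TEST looks like 'tests/test_file.py::test_name (call)';
# the path below is a hypothetical example.
#   'tests/test_utils.py::test_foo (call)' -> get_test_name() == 'test_foo'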


def make_user_path_test():
    user_path = makedirs('user_path_test', use_base=True)
    if os.path.isdir(user_path):
        shutil.rmtree(user_path)
    user_path = makedirs('user_path_test', use_base=True)
    db_dir = "db_dir_UserData"
    db_dir = makedirs(db_dir, use_base=True)
    if os.path.isdir(db_dir):
        shutil.rmtree(db_dir)
    db_dir = makedirs(db_dir, use_base=True)
    shutil.copy('data/pexels-evg-kowalievska-1170986_small.jpg', user_path)
    shutil.copy('README.md', user_path)
    shutil.copy('docs/FAQ.md', user_path)
    return user_path


def get_llama(llama_type=2):
    from huggingface_hub import hf_hub_download
    # FIXME: Pass into main()
    if llama_type == 1:
        file = 'ggml-model-q4_0_7b.bin'
        dest = 'models/7B/'
        prompt_type = 'plain'
    elif llama_type == 2:
        file = 'WizardLM-7B-uncensored.ggmlv3.q8_0.bin'
        dest = './'
        prompt_type = 'wizard2'
    else:
        raise ValueError("unknown llama_type=%s" % llama_type)
    makedirs(dest, exist_ok=True)
    full_path = os.path.join(dest, file)
    if not os.path.isfile(full_path):
        # True for case when locally already logged in with correct token, so don't have to set key
        token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
        out_path = hf_hub_download('h2oai/ggml', file, token=token, repo_type='model')
        # out_path will look like '/home/jon/.cache/huggingface/hub/models--h2oai--ggml/snapshots/57e79c71bb0cee07e3e3ffdea507105cd669fa96/ggml-model-q4_0_7b.bin'
        shutil.copy(out_path, dest)
    return prompt_type, full_path
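

# Hedged usage sketch: fetch the default (llama_type=2) GGML weights and pass
# them to whatever entry point consumes a base model path; main() here is an
# assumed consumer, not defined in this file.
#   prompt_type, full_path = get_llama()
#   # e.g. main(base_model=full_path, prompt_type=prompt_type)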


def kill_weaviate(db_type):
    """
    weaviate launches a detached server that accumulates entries in its db,
    but we want each test to start fresh, so kill it.
    """
    if db_type == 'weaviate':
        os.system('pkill --signal 9 -f weaviate-embedded/weaviate')


def count_tokens_llm(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b', tokenizer=None):
    import time
    if tokenizer is None:
        assert base_model is not None
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    t0 = time.time()
    a = len(tokenizer(prompt)['input_ids'])
    print('llm: ', a, time.time() - t0)
    return dict(llm=a)
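

# Hedged usage sketch for count_tokens_llm; the prompt is made up.
#   counts = count_tokens_llm("Why is the sky blue?")
#   # -> {'llm': <token count under the base_model's tokenizer>}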


def count_tokens(prompt, base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b'):
    tokenizer = FakeTokenizer()
    num_tokens = tokenizer.num_tokens_from_string(prompt)
    print(num_tokens)
    from transformers import AutoTokenizer
    # distilgpt2 uses the GPT-2 BPE vocabulary, used here as a tiktoken-style count
    t = AutoTokenizer.from_pretrained("distilgpt2")
    llm_tokenizer = AutoTokenizer.from_pretrained(base_model)
    from InstructorEmbedding import INSTRUCTOR
    emb = INSTRUCTOR('hkunlp/instructor-large')
    import nltk

    def nltkTokenize(text):
        words = nltk.word_tokenize(text)
        return words

    import re
    WORD = re.compile(r'\w+')

    def regTokenize(text):
        words = WORD.findall(text)
        return words

    counts = {}
    import time
    t0 = time.time()
    a = len(regTokenize(prompt))
    print('reg: ', a, time.time() - t0)
    counts.update(dict(reg=a))
    t0 = time.time()
    a = len(nltkTokenize(prompt))
    print('nltk: ', a, time.time() - t0)
    counts.update(dict(nltk=a))
    t0 = time.time()
    a = len(t(prompt)['input_ids'])
    print('tiktoken: ', a, time.time() - t0)
    counts.update(dict(tiktoken=a))
    t0 = time.time()
    a = len(llm_tokenizer(prompt)['input_ids'])
    print('llm: ', a, time.time() - t0)
    counts.update(dict(llm=a))
    t0 = time.time()
    a = emb.tokenize([prompt])['input_ids'].shape[1]
    print('instructor-large: ', a, time.time() - t0)
    counts.update(dict(instructor=a))
    return counts
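

# Hedged usage sketch comparing the counters above; the prompt is made up and
# the relative counts are only indicative.
#   counts = count_tokens("Why is the sky blue?")
#   # -> {'reg': ..., 'nltk': ..., 'tiktoken': ..., 'llm': ..., 'instructor': ...}
# Word-level counts ('reg', 'nltk') typically run lower than subword counts
# ('tiktoken', 'llm', 'instructor') on natural-language prompts.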