# h2ogpt-chatbot2 / client_test.py
# (web page residue: "elitecode's picture")
# Duplicate from h2oai/h2ogpt-chatbot2
# Revision (presumably the commit hash): ffb3860
"""
Client test.
Run server:
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
NOTE: For private models, add --use-auth_token=True
NOTE: --infer_devices=True (default) must be used for multi-GPU if you see failures caused by cuda:x vs. cuda:y device mismatches.
Currently, this forces the model onto a single GPU.
Then run this client as:
python client_test.py
For HF spaces:
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
Result:
Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
For demo:
HOST="https://gpt.h2o.ai" python client_test.py
Result:
Loaded as API: https://gpt.h2o.ai ✔
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
NOTE: The raw output from the API in the nochat case is the string form of a Python dict, and will remain so even if other entries are added to the dict:
{'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
"""
import ast
import time
import os
import markdown # pip install markdown
import pytest
from bs4 import BeautifulSoup # pip install beautifulsoup4
from enums import DocumentChoices
# When True, get_client() prints the result of client.view_api(all_endpoints=True) after connecting.
debug = False
# NOTE(review): presumably opts out of Hugging Face Hub telemetry for this process -- confirm against huggingface_hub docs.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
def get_client(serialize=True):
    """
    Create a gradio client connected to the server under test.

    The server URL is taken from the HOST environment variable and falls back
    to http://localhost:7860 when HOST is not set.

    :param serialize: passed through unchanged to the gradio_client ``Client`` constructor
    :return: the constructed ``Client`` instance
    """
    # imported inside the function, so gradio_client is only needed once a client is actually created
    from gradio_client import Client

    host = os.getenv('HOST', "http://localhost:7860")
    api_client = Client(host, serialize=serialize)
    if not debug:
        return api_client
    # debug mode: dump every endpoint the server exposes before handing the client back
    print(api_client.view_api(all_endpoints=True))
    return api_client
def get_args(prompt, prompt_type, chat=False, stream_output=False,
             max_new_tokens=50,
             top_k_docs=3,
             langchain_mode='Disabled'):
    """
    Assemble the generation arguments sent to the server.

    The mapping is ordered because its values are also returned as a positional
    list: the order of the entries below is the order of the positional arguments.

    :param prompt: user prompt; stored under 'instruction' when chat is True,
        otherwise under 'instruction_nochat' (the other field is left empty)
    :param prompt_type: stored under 'prompt_type'
    :param chat: selects chat mode; when True a trailing empty 'chatbot' list is added
    :param stream_output: stored under 'stream_output'
    :param max_new_tokens: stored under 'max_new_tokens'
    :param top_k_docs: stored under 'top_k_docs'
    :param langchain_mode: stored under 'langchain_mode'
    :return: tuple ``(kwargs, args)`` where ``args`` is ``list(kwargs.values())``
    """
    from collections import OrderedDict

    # the prompt travels in exactly one of the two instruction fields
    if chat:
        chat_prompt, nochat_prompt = prompt, ''
    else:
        chat_prompt, nochat_prompt = '', prompt

    entries = [
        ('instruction', chat_prompt),  # only for chat=True
        ('iinput', ''),  # only for chat=True
        ('context', ''),
        # streaming output is supported, loops over and outputs each generation in streaming mode
        # but leave stream_output=False for simple input/output mode
        ('stream_output', stream_output),
        ('prompt_type', prompt_type),
        ('prompt_dict', ''),
        ('temperature', 0.1),
        ('top_p', 0.75),
        ('top_k', 40),
        ('num_beams', 1),
        ('max_new_tokens', max_new_tokens),
        ('min_new_tokens', 0),
        ('early_stopping', False),
        ('max_time', 20),
        ('repetition_penalty', 1.0),
        ('num_return_sequences', 1),
        ('do_sample', True),
        ('chat', chat),
        ('instruction_nochat', nochat_prompt),
        ('iinput_nochat', ''),  # only for chat=False
        ('langchain_mode', langchain_mode),
        ('top_k_docs', top_k_docs),
        ('chunk', True),
        ('chunk_size', 512),
        ('document_choice', [DocumentChoices.All_Relevant.name]),
    ]
    kwargs = OrderedDict(entries)
    if chat:
        # add chatbot output on end.  Assumes serialize=False
        kwargs['chatbot'] = []

    return kwargs, list(kwargs.values())
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic():
    """Manual check: send 'Who are you?' via run_client_nochat and return its result dict."""
    request = dict(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
    return run_client_nochat(**request)
def run_client_nochat(prompt, prompt_type, max_new_tokens):
    """
    Send one prompt to the '/submit_nochat' endpoint using positional arguments.

    :param prompt: instruction text sent as 'instruction_nochat'
    :param prompt_type: prompt template name forwarded to get_args
    :param max_new_tokens: generation length limit forwarded to get_args
    :return: dict with keys 'prompt', 'iinput' and 'response' (response passed through md_to_text)
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)

    client = get_client(serialize=True)
    # the endpoint takes every argument positionally, in get_args order
    res = client.predict(*args, api_name='/submit_nochat')
    print("Raw client result: %s" % res, flush=True)

    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'iinput': kwargs['iinput_nochat'],
        'response': md_to_text(res),
    }
    print(res_dict)
    return res_dict
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api():
    """Manual check: send 'Who are you?' via run_client_nochat_api and return its result dict."""
    request = dict(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
    return run_client_nochat_api(**request)
def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
    """
    Send one prompt to the '/submit_nochat_api' endpoint as a stringified dict.

    The full argument set from get_args is sent as ``str(dict(kwargs))``; the
    server's reply is itself the string form of a Python dict, which is parsed
    with ``ast.literal_eval`` (literals only, no code execution).

    :param prompt: instruction text sent as 'instruction_nochat'
    :param prompt_type: prompt template name forwarded to get_args
    :param max_new_tokens: generation length limit forwarded to get_args
    :return: dict with keys 'prompt', 'iinput', 'response' (passed through md_to_text) and 'sources'
    :raises KeyError: if the parsed reply lacks 'response' or 'sources'
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=True)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # parse the reply once and reuse it for both fields
    parsed = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
                    response=md_to_text(parsed['response']),
                    sources=parsed['sources'])
    print(res_dict)
    return res_dict
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean():
    """Manual check: send 'Who are you?' via run_client_nochat_api_lean and return its result dict."""
    request = dict(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
    return run_client_nochat_api_lean(**request)
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
    """
    Send one prompt to the '/submit_nochat_api' endpoint with the smallest possible payload.

    Only 'instruction_nochat' is sent; every other setting is left for the
    server to decide.  The reply is the string form of a Python dict and is
    parsed with ``ast.literal_eval`` (literals only, no code execution).

    :param prompt: instruction text sent as 'instruction_nochat'
    :param prompt_type: accepted for signature parity with the other run_client_* helpers; not sent
    :param max_new_tokens: accepted for signature parity with the other run_client_* helpers; not sent
    :return: dict with keys 'prompt', 'response' (passed through md_to_text) and 'sources'
    :raises KeyError: if the parsed reply lacks 'response' or 'sources'
    """
    kwargs = dict(instruction_nochat=prompt)

    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=True)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # parse the reply once and reuse it for both fields
    parsed = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(parsed['response']),
                    sources=parsed['sources'])
    print(res_dict)
    return res_dict
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean_morestuff():
    """Manual check: send 'Who are you?' via run_client_nochat_api_lean_morestuff and return its result dict."""
    request = dict(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
    return run_client_nochat_api_lean_morestuff(**request)
def run_client_nochat_api_lean_morestuff(prompt, prompt_type, max_new_tokens):
    """
    Send one prompt to the '/submit_nochat_api' endpoint with an explicitly spelled-out payload.

    Unlike the fully lean variant, the common generation settings are listed
    by hand instead of being left to the server.  The caller's ``prompt_type``
    and ``max_new_tokens`` are sent as given (earlier code ignored both and
    always sent 'human_bot' and 256).  The reply is the string form of a
    Python dict and is parsed with ``ast.literal_eval`` (literals only, no
    code execution).

    :param prompt: instruction text sent as 'instruction_nochat'
    :param prompt_type: prompt template name sent as 'prompt_type'
    :param max_new_tokens: generation length limit sent as 'max_new_tokens'
    :return: dict with keys 'prompt', 'response' (passed through md_to_text) and 'sources'
    :raises KeyError: if the parsed reply lacks 'response' or 'sources'
    """
    kwargs = dict(
        instruction='',
        iinput='',
        context='',
        stream_output=False,
        prompt_type=prompt_type,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=max_new_tokens,
        min_new_tokens=0,
        early_stopping=False,
        max_time=20,
        repetition_penalty=1.0,
        num_return_sequences=1,
        do_sample=True,
        chat=False,
        instruction_nochat=prompt,
        iinput_nochat='',
        langchain_mode='Disabled',
        top_k_docs=4,
        # NOTE(review): get_args sends DocumentChoices.All_Relevant.name here; confirm the server still accepts 'All'
        document_choice=['All'],
    )

    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=True)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # parse the reply once and reuse it for both fields
    parsed = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(parsed['response']),
                    sources=parsed['sources'])
    print(res_dict)
    return res_dict
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat():
    """Manual check: run one non-streaming chat turn asking 'Who are you?' and return run_client_chat's result."""
    request = dict(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50,
                   langchain_mode='Disabled')
    return run_client_chat(**request)
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
    """
    Run a single chat turn against the server.

    Creates a client with serialize=False, builds the chat-mode arguments with
    get_args and delegates the exchange to run_client.

    :param prompt: user message for this turn
    :param prompt_type: prompt template name forwarded to get_args
    :param stream_output: whether run_client should stream the answer
    :param max_new_tokens: generation length limit forwarded to get_args
    :param langchain_mode: forwarded to get_args
    :return: whatever run_client returns, i.e. the tuple ``(res_dict, client)``
    """
    # the client is created before the arguments are built, as in the original call order
    client = get_client(serialize=False)

    chat_options = dict(chat=True, stream_output=stream_output,
                        max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
    kwargs, args = get_args(prompt, prompt_type, **chat_options)
    return run_client(client, prompt, args, kwargs)
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """
    Perform one chat turn: register the user message, then fetch the bot answer.

    The function first calls '/instruction' and appends the last item of its
    result to ``args[-1]`` (assumed to be the chatbot history list, as built by
    get_args with chat=True), then calls '/instruction_bot' either blocking or
    streaming depending on ``kwargs['stream_output']``.

    Side effects: ``args[-1]`` is extended in place, and ``kwargs`` itself is
    updated with 'prompt' and 'response' and returned (no copy is made).

    :param client: gradio client; NOTE(review): expected to be created with serialize=False -- confirm
    :param prompt: user message, stored in the result under 'prompt'
    :param args: positional argument list for the endpoints; its last item is the chatbot history
    :param kwargs: argument mapping; must contain 'stream_output'
    :param do_md_to_text: forwarded to md_to_text for printed (and, when streaming, stored) text
    :param verbose: when streaming, also print the full list of job outputs
    :return: tuple ``(res_dict, client)`` where ``res_dict`` is the updated ``kwargs``
    """
    res = client.predict(*tuple(args), api_name='/instruction')
    args[-1] += [res[-1]]

    res_dict = kwargs
    res_dict['prompt'] = prompt
    if not kwargs['stream_output']:
        res = client.predict(*tuple(args), api_name='/instruction_bot')
        # 0 = chatbot history, -1 = latest turn, 1 = bot answer (stored unconverted)
        res_dict['response'] = res[0][-1][1]
        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
        return res_dict, client

    job = client.submit(*tuple(args), api_name='/instruction_bot')
    while not job.done():
        # take one snapshot per iteration through the public accessor,
        # instead of reading the job's internal output list twice
        outputs_list = job.outputs()
        if outputs_list:
            partial = outputs_list[-1][0][-1][-1]
            print(md_to_text(partial, do_md_to_text=do_md_to_text))
        time.sleep(0.1)

    # ensure get ending to avoid race
    full_outputs = job.outputs()
    if verbose:
        print('job.outputs: %s' % str(full_outputs))
    # -1 means last response if streaming
    # 0 means get text_output, ignore exception_text
    # -1 means get the latest [prompt, answer] pair of the chat history
    #    (same turn the polling loop and the non-streaming branch read)
    # 1 means get bot answer, so will have last bot answer
    res_dict['response'] = md_to_text(full_outputs[-1][0][-1][1], do_md_to_text=do_md_to_text)
    return res_dict, client
def md_to_text(md, do_md_to_text=True):
    """
    Convert markdown to plain text.

    The markdown is rendered to HTML with ``markdown.markdown`` and the text
    content is then extracted with BeautifulSoup's 'html.parser'.

    :param md: markdown string to convert
    :param do_md_to_text: when False, ``md`` is returned unchanged (no validation, no conversion)
    :return: plain text, or ``md`` itself when ``do_md_to_text`` is False
    :raises AssertionError: if conversion is requested and ``md`` is None
    """
    if not do_md_to_text:
        return md
    if md is None:
        # raised explicitly rather than via `assert`, so the check survives `python -O`
        # while callers still see the same exception type and message
        raise AssertionError("Markdown is None")
    html = markdown.markdown(md)
    soup = BeautifulSoup(html, features='html.parser')
    return soup.get_text()
if __name__ == '__main__':
    # Run the manual no-chat checks in order.  Calling the functions directly
    # means the pytest skip marker does not apply here.
    for manual_check in (
            test_client_basic,
            test_client_basic_api,
            test_client_basic_api_lean,
            test_client_basic_api_lean_morestuff,
    ):
        manual_check()