File size: 4,969 Bytes
b585c7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import platform
import pytest
from h2ogpt_client import Client
# NOTE(review): the return value of this call is discarded, so it has no
# visible effect in this file — presumably a leftover; confirm before removing.
platform.python_version()
@pytest.fixture
def client(server_url, h2ogpt_key) -> Client:
    """Build the ``Client`` under test.

    The client is constructed from the ``server_url`` and ``h2ogpt_key``
    values that pytest injects into this fixture.
    """
    h2ogpt_client = Client(server_url, h2ogpt_key=h2ogpt_key)
    return h2ogpt_client
def _create_text_completion(client):
    """Return a text completion created for the last model the client lists."""
    available_models = client.models.list()
    return client.text_completion.create(model=available_models[-1])
@pytest.mark.asyncio
async def test_text_completion(client):
    """Await one text completion and check that the response is truthy."""
    completion = _create_text_completion(client)
    answer = await completion.complete(prompt="Hello world")
    assert answer
    print(answer)
@pytest.mark.asyncio
async def test_text_completion_stream(client):
    """Stream a text completion asynchronously; every yielded token must be truthy."""
    completion = _create_text_completion(client)
    poem_prompt = "Write a poem about the Amazon rainforest. End it with an emoji."
    token_stream = await completion.complete(
        prompt=poem_prompt, enable_streaming=True
    )
    async for piece in token_stream:
        assert piece
        # No newline between pieces, so the streamed output reads as one text.
        print(piece, end="")
def test_text_completion_sync(client):
    """Run one synchronous text completion and check that the response is truthy."""
    completion = _create_text_completion(client)
    answer = completion.complete_sync(prompt="Hello world")
    assert answer
    print(answer)
def test_text_completion_sync_stream(client):
    """Stream a text completion synchronously; every yielded token must be truthy."""
    completion = _create_text_completion(client)
    poem_prompt = "Write a poem about the Amazon rainforest. End it with an emoji."
    token_stream = completion.complete_sync(
        prompt=poem_prompt, enable_streaming=True
    )
    for piece in token_stream:
        assert piece
        # No newline between pieces, so the streamed output reads as one text.
        print(piece, end="")
def _create_chat_completion(client):
    """Return a chat completion created for the last model the client lists."""
    available_models = client.models.list()
    return client.chat_completion.create(model=available_models[-1])
@pytest.mark.asyncio
async def test_chat_completion(client):
    """Send three prompts through one async chat and verify replies and history.

    Each reply must echo the prompt under ``"user"`` and carry a truthy
    ``"gpt"`` entry; the chat history must equal the replies in order.
    """
    conversation = _create_chat_completion(client)
    prompts = (
        "Hey!",
        "What is the capital of USA?",
        "What is the population in there?",
    )

    exchanges = []
    for question in prompts:
        exchange = await conversation.chat(prompt=question)
        assert exchange["user"] == question
        assert exchange["gpt"]
        exchanges.append(exchange)

    history = conversation.chat_history()
    assert history == exchanges
    print(history)
def test_chat_completion_sync(client):
    """Send three prompts through one synchronous chat and verify replies and history.

    Each reply must echo the prompt under ``"user"`` and carry a truthy
    ``"gpt"`` entry; the chat history must equal the replies in order.
    """
    conversation = _create_chat_completion(client)
    prompts = (
        "What is UNESCO?",
        "Is it a part of the UN?",
        "Where is the headquarters?",
    )

    exchanges = []
    for question in prompts:
        exchange = conversation.chat_sync(prompt=question)
        assert exchange["user"] == question
        assert exchange["gpt"]
        exchanges.append(exchange)

    history = conversation.chat_history()
    assert history == exchanges
    print(history)
def test_available_models(client):
    """Check that the client lists at least one model."""
    listed = client.models.list()
    assert len(listed) > 0
    print(listed)
def test_server_properties(client, server_url):
    """Check the server address starts with ``server_url`` and the hash is truthy."""
    address = client.server.address
    assert address.startswith(server_url)

    server_hash = client.server.hash
    assert server_hash
def test_parameters_order(client, eval_func_param_names):
    """Check parameter keys of both completion kinds equal ``eval_func_param_names``.

    Text completion is checked first, then chat completion; each is created
    with default arguments.
    """
    for kind in ("text_completion", "chat_completion"):
        # getattr keeps each attribute lookup right before its create() call.
        completion = getattr(client, kind).create()
        assert eval_func_param_names == list(completion._parameters.keys())
@pytest.mark.parametrize("local_server", [True, False])
def test_readme_example(local_server):
    """Run the README usage example against a local or the public server.

    With ``local_server`` true the client targets ``http://0.0.0.0:7860``.
    Otherwise it targets ``https://gpt.h2o.ai`` using the key read from the
    ``H2OGPT_KEY`` or ``H2OGPT_H2OGPT_KEY`` environment variable; when no
    usable key is set the test is skipped.

    The example exercises text completion and chat completion, each once via
    ``asyncio.run`` and once via the synchronous API, printing the results.
    """
    # self-contained example used for readme,
    # to be copied to client/README.md if changed, setting local_server = True at first
    import asyncio
    import os

    from h2ogpt_client import Client

    if local_server:
        client = Client("http://0.0.0.0:7860")
    else:
        h2ogpt_key = os.getenv("H2OGPT_KEY") or os.getenv("H2OGPT_H2OGPT_KEY")
        if not h2ogpt_key:
            # Skip instead of returning: a bare return is reported as PASSED
            # even though nothing ran. The falsy check also rejects an
            # empty-string key, which `is None` would let through.
            pytest.skip("no H2OGPT_KEY / H2OGPT_H2OGPT_KEY set for public instance")
        # if you have API key for public instance:
        client = Client("https://gpt.h2o.ai", h2ogpt_key=h2ogpt_key)

    # Text completion
    text_completion = client.text_completion.create()
    response = asyncio.run(text_completion.complete("Hello world"))
    print("asyncio text completion response: %s" % response)
    # Text completion: synchronous
    response = text_completion.complete_sync("Hello world")
    print("sync text completion response: %s" % response)

    # Chat completion
    chat_completion = client.chat_completion.create()
    reply = asyncio.run(chat_completion.chat("Hey!"))
    # Label says "chat" (was "text"): this prints a chat reply.
    print("asyncio chat completion user: %s gpt: %s" % (reply["user"], reply["gpt"]))
    chat_history = chat_completion.chat_history()
    print("chat_history: %s" % chat_history)
    # Chat completion: synchronous
    reply = chat_completion.chat_sync("Hey!")
    print("sync chat completion gpt: %s" % reply["gpt"])
|