File size: 4,969 Bytes
b585c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import platform

import pytest

from h2ogpt_client import Client

# NOTE(review): the return value is discarded, so this module-level call has
# no observable effect; it looks like leftover debugging (perhaps meant to be
# printed). Confirm whether it, and the ``platform`` import, can be removed.
platform.python_version()


@pytest.fixture
def client(server_url, h2ogpt_key) -> Client:
    """Provide an h2oGPT ``Client`` connected to the server under test.

    ``server_url`` and ``h2ogpt_key`` are fixtures defined outside this module;
    the key is forwarded as the ``h2ogpt_key`` keyword argument.
    """
    connected_client = Client(server_url, h2ogpt_key=h2ogpt_key)
    return connected_client


def _create_text_completion(client):
    """Create a text completion bound to the last model the server lists."""
    available_models = client.models.list()
    chosen_model = available_models[-1]
    return client.text_completion.create(model=chosen_model)


@pytest.mark.asyncio
async def test_text_completion(client):
    """Run one non-streaming async text completion and check it is non-empty."""
    completion = _create_text_completion(client)
    result = await completion.complete(prompt="Hello world")
    assert result
    print(result)


@pytest.mark.asyncio
async def test_text_completion_stream(client):
    """Stream an async text completion and check the streamed tokens.

    Every token must be truthy, and at least one token must arrive. Without
    the final count check, an empty stream would skip the loop body entirely
    and the test would pass without having verified anything.
    """
    text_completion = _create_text_completion(client)
    response = await text_completion.complete(
        prompt="Write a poem about the Amazon rainforest. End it with an emoji.",
        enable_streaming=True,
    )
    token_count = 0
    async for token in response:
        assert token
        token_count += 1
        print(token, end="")
    assert token_count > 0, "streaming response yielded no tokens"


def test_text_completion_sync(client):
    """Run one blocking (non-streaming) text completion and check it is non-empty."""
    completion = _create_text_completion(client)
    result = completion.complete_sync(prompt="Hello world")
    assert result
    print(result)


def test_text_completion_sync_stream(client):
    """Stream a blocking text completion and check the streamed tokens.

    Every token must be truthy, and at least one token must arrive. Without
    the final count check, an empty stream would skip the loop body entirely
    and the test would pass without having verified anything.
    """
    text_completion = _create_text_completion(client)
    response = text_completion.complete_sync(
        prompt="Write a poem about the Amazon rainforest. End it with an emoji.",
        enable_streaming=True,
    )
    token_count = 0
    for token in response:
        assert token
        token_count += 1
        print(token, end="")
    assert token_count > 0, "streaming response yielded no tokens"


def _create_chat_completion(client):
    """Create a chat completion bound to the last model the server lists."""
    available_models = client.models.list()
    chosen_model = available_models[-1]
    return client.chat_completion.create(model=chosen_model)


@pytest.mark.asyncio
async def test_chat_completion(client):
    """Hold a three-turn async conversation and verify the recorded history."""
    chat_completion = _create_chat_completion(client)

    prompts = (
        "Hey!",
        "What is the capital of USA?",
        "What is the population in there?",
    )
    replies = []
    for prompt in prompts:
        reply = await chat_completion.chat(prompt=prompt)
        # Each turn must echo the prompt and carry a non-empty model answer.
        assert reply["user"] == prompt
        assert reply["gpt"]
        replies.append(reply)

    # The history must contain exactly the replies above, in order.
    chat_history = chat_completion.chat_history()
    assert chat_history == replies
    print(chat_history)


def test_chat_completion_sync(client):
    """Hold a three-turn blocking conversation and verify the recorded history."""
    chat_completion = _create_chat_completion(client)

    prompts = (
        "What is UNESCO?",
        "Is it a part of the UN?",
        "Where is the headquarters?",
    )
    replies = []
    for prompt in prompts:
        reply = chat_completion.chat_sync(prompt=prompt)
        # Each turn must echo the prompt and carry a non-empty model answer.
        assert reply["user"] == prompt
        assert reply["gpt"]
        replies.append(reply)

    # The history must contain exactly the replies above, in order.
    chat_history = chat_completion.chat_history()
    assert chat_history == replies
    print(chat_history)


def test_available_models(client):
    """The server must advertise at least one model."""
    available = client.models.list()
    assert len(available) > 0
    print(available)


def test_server_properties(client, server_url):
    """Check the client's server address starts with ``server_url`` and has a hash."""
    address = client.server.address
    assert address.startswith(server_url)
    assert client.server.hash


def test_parameters_order(client, eval_func_param_names):
    """Both completion kinds must expose parameters in ``eval_func_param_names`` order.

    The text completion is created and checked first, then the chat completion.
    """
    for attribute in ("text_completion", "chat_completion"):
        completion = getattr(client, attribute).create()
        assert eval_func_param_names == list(completion._parameters.keys())


@pytest.mark.parametrize("local_server", [True, False])
def test_readme_example(local_server):
    """Run the self-contained usage example mirrored in client/README.md.

    Copy this example to client/README.md whenever it changes, setting
    ``local_server = True`` when copying.

    :param local_server: when True, connect to ``http://0.0.0.0:7860`` without
        a key; when False, connect to ``https://gpt.h2o.ai`` using the key from
        ``H2OGPT_KEY`` or ``H2OGPT_H2OGPT_KEY``. The public-instance case is
        skipped when neither variable holds a non-empty key.
    """
    import asyncio
    import os

    from h2ogpt_client import Client

    if local_server:
        client = Client("http://0.0.0.0:7860")
    else:
        h2ogpt_key = os.getenv("H2OGPT_KEY") or os.getenv("H2OGPT_H2OGPT_KEY")
        if not h2ogpt_key:
            # Report this variant as skipped rather than letting an early
            # ``return`` count it as a pass; an empty value counts as missing.
            pytest.skip(
                "H2OGPT_KEY / H2OGPT_H2OGPT_KEY not set; "
                "skipping public-instance example"
            )
        # if you have API key for public instance:
        client = Client("https://gpt.h2o.ai", h2ogpt_key=h2ogpt_key)

    # Text completion
    text_completion = client.text_completion.create()
    response = asyncio.run(text_completion.complete("Hello world"))
    print("asyncio text completion response: %s" % response)
    # Text completion: synchronous
    response = text_completion.complete_sync("Hello world")
    print("sync text completion response: %s" % response)

    # Chat completion
    chat_completion = client.chat_completion.create()
    reply = asyncio.run(chat_completion.chat("Hey!"))
    print("asyncio chat completion user: %s gpt: %s" % (reply["user"], reply["gpt"]))
    chat_history = chat_completion.chat_history()
    print("chat_history: %s" % chat_history)
    # Chat completion: synchronous
    reply = chat_completion.chat_sync("Hey!")
    print("sync chat completion gpt: %s" % reply["gpt"])