MobiLlama / tests /test_openai_api.py
Ashmal's picture
Upload folder using huggingface_hub
5472531 verified
"""
Test the OpenAI compatible server
Launch:
python3 launch_openai_api_test_server.py
"""
import warnings
import openai
from fastchat.utils import run_cmd
openai.api_key = "EMPTY" # Not support yet
openai.base_url = "http://localhost:8000/v1/"
def test_list_models():
model_list = openai.models.list()
names = [x.id for x in model_list.data]
return names
def test_completion(model, logprob):
prompt = "Once upon a time"
completion = openai.completions.create(
model=model,
prompt=prompt,
logprobs=logprob,
max_tokens=64,
temperature=0,
)
print(f"full text: {prompt + completion.choices[0].text}", flush=True)
if completion.choices[0].logprobs is not None:
print(
f"logprobs: {completion.choices[0].logprobs.token_logprobs[:10]}",
flush=True,
)
def test_completion_stream(model):
prompt = "Once upon a time"
res = openai.completions.create(
model=model,
prompt=prompt,
max_tokens=64,
stream=True,
temperature=0,
)
print(prompt, end="")
for chunk in res:
content = chunk.choices[0].text
print(content, end="", flush=True)
print()
def test_embedding(model):
embedding = openai.embeddings.create(model=model, input="Hello world!")
print(f"embedding len: {len(embedding.data[0].embedding)}")
print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}")
def test_chat_completion(model):
completion = openai.chat.completions.create(
model=model,
messages=[{"role": "user", "content": "Hello! What is your name?"}],
temperature=0,
)
print(completion.choices[0].message.content)
def test_chat_completion_stream(model):
messages = [{"role": "user", "content": "Hello! What is your name?"}]
res = openai.chat.completions.create(
model=model, messages=messages, stream=True, temperature=0
)
for chunk in res:
try:
content = chunk.choices[0].delta.content
if content is None:
content = ""
except Exception as e:
content = chunk.choices[0].delta.get("content", "")
print(content, end="", flush=True)
print()
def test_openai_curl():
run_cmd("curl http://localhost:8000/v1/models")
run_cmd(
"""
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "vicuna-7b-v1.5",
"messages": [{"role": "user", "content": "Hello! What is your name?"}]
}'
"""
)
run_cmd(
"""
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "vicuna-7b-v1.5",
"prompt": "Once upon a time",
"max_tokens": 41,
"temperature": 0.5
}'
"""
)
run_cmd(
"""
curl http://localhost:8000/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"model": "vicuna-7b-v1.5",
"input": "Hello world!"
}'
"""
)
if __name__ == "__main__":
models = test_list_models()
print(f"models: {models}")
for model in models:
print(f"===== Test {model} ======")
if model in ["fastchat-t5-3b-v1.0"]:
logprob = None
else:
logprob = 1
test_completion(model, logprob)
test_completion_stream(model)
test_chat_completion(model)
test_chat_completion_stream(model)
try:
test_embedding(model)
except openai.APIError as e:
print(f"Embedding error: {e}")
print("===== Test curl =====")
test_openai_curl()