"""Call API providers."""
from json import loads
import os
import random
import time
from fastchat.utils import build_logger
from fastchat.constants import WORKER_API_TIMEOUT
logger = build_logger("gradio_web_server", "gradio_web_server.log")


def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    import openai

    is_azure = False
    if "azure" in model_name:
        is_azure = True
        openai.api_type = "azure"
        openai.api_version = "2023-07-01-preview"
    else:
        openai.api_type = "open_ai"
        openai.api_version = None

    openai.api_base = api_base or "https://api.openai.com/v1"
    openai.api_key = api_key or os.environ["OPENAI_API_KEY"]
    if model_name == "gpt-4-turbo":
        model_name = "gpt-4-1106-preview"

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    if is_azure:
        res = openai.ChatCompletion.create(
            engine=model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            stream=True,
        )
    else:
        res = openai.ChatCompletion.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            stream=True,
        )
    text = ""
    for chunk in res:
        if len(chunk["choices"]) > 0:
            text += chunk["choices"][0]["delta"].get("content", "")
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
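

# Usage sketch (illustrative only, not called anywhere in this module): stream
# a chat completion through openai_api_stream_iter and print the accumulated
# text of each chunk. The model name, messages, and sampling values below are
# placeholder assumptions; real callers pass their own.
def _example_openai_stream():
    for chunk in openai_api_stream_iter(
        model_name="gpt-3.5-turbo",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=64,
    ):
        print(chunk["text"])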


def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    import anthropic

    c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = c.completions.create(
        prompt=prompt,
        stop_sequences=[anthropic.HUMAN_PROMPT],
        max_tokens_to_sample=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )
    text = ""
    for chunk in res:
        text += chunk.completion
        data = {
            "text": text,
            "error_code": 0,
        }
        yield data
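

# Usage sketch (illustrative only, not called anywhere in this module): the
# Anthropic completions API expects the HUMAN_PROMPT / AI_PROMPT markers in
# the prompt string. The model name, prompt text, and sampling values below
# are placeholder assumptions.
def _example_anthropic_stream():
    import anthropic

    prompt = f"{anthropic.HUMAN_PROMPT} Say hello.{anthropic.AI_PROMPT}"
    for chunk in anthropic_api_stream_iter(
        model_name="claude-2.1",  # placeholder model name
        prompt=prompt,
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=64,
    ):
        print(chunk["text"])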


def init_palm_chat(model_name):
    import vertexai  # pip3 install google-cloud-aiplatform
    from vertexai.preview.language_models import ChatModel
    from vertexai.preview.generative_models import GenerativeModel

    project_id = os.environ["GCP_PROJECT_ID"]
    location = "us-central1"
    vertexai.init(project=project_id, location=location)

    if model_name in ["palm-2"]:
        # According to release note, "chat-bison@001" is PaLM 2 for chat.
        # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
        model_name = "chat-bison@001"
        chat_model = ChatModel.from_pretrained(model_name)
        chat = chat_model.start_chat(examples=[])
    elif model_name in ["gemini-pro"]:
        model = GenerativeModel(model_name)
        chat = model.start_chat()
    return chat


def palm_api_stream_iter(model_name, chat, message, temperature, top_p, max_new_tokens):
    if model_name in ["gemini-pro"]:
        max_new_tokens = max_new_tokens * 2
    parameters = {
        "temperature": temperature,
        "top_p": top_p,
        "max_output_tokens": max_new_tokens,
    }
    gen_params = {
        "model": model_name,
        "prompt": message,
    }
    gen_params.update(parameters)
    if model_name == "palm-2":
        response = chat.send_message(message, **parameters)
    else:
        response = chat.send_message(message, generation_config=parameters, stream=True)
    logger.info(f"==== request ====\n{gen_params}")

    try:
        text = ""
        for chunk in response:
            text += chunk.text
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}\nPlease try again or increase the number of max tokens.",
            "error_code": 1,
        }
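

# Usage sketch (illustrative only, not called anywhere in this module):
# init_palm_chat builds the Vertex AI chat session that palm_api_stream_iter
# expects as its `chat` argument (GCP_PROJECT_ID must be set). The message and
# sampling values below are placeholder assumptions.
def _example_palm_stream():
    chat = init_palm_chat("gemini-pro")
    for chunk in palm_api_stream_iter(
        model_name="gemini-pro",
        chat=chat,
        message="Say hello.",  # placeholder message
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=256,
    ):
        print(chunk["text"])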


def ai2_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    from requests import post

    # get keys and needed values
    ai2_key = api_key or os.environ.get("AI2_API_KEY")
    api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
    model_id = "mod_01hhgcga70c91402r9ssyxekan"

    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    res = post(
        api_base,
        stream=True,
        headers={"Authorization": f"Bearer {ai2_key}"},
        json={
            "model_id": model_id,
            # This input format is specific to the Tulu2 model. Other models
            # may require different input formats. See the model's schema
            # documentation on InferD for more information.
            "input": {
                "messages": messages,
                "opts": {
                    "max_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "logprobs": 1,  # increase for more choices
                },
            },
        },
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

    text = ""
    for line in res.iter_lines():
        if line:
            part = loads(line)
            if "result" in part and "output" in part["result"]:
                for t in part["result"]["output"]["text"]:
                    text += t
            else:
                logger.error(f"unexpected part: {part}")
                raise ValueError("empty result in InferD response")

            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
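

# Usage sketch (illustrative only, not called anywhere in this module):
# ai2_api_stream_iter takes OpenAI-style chat messages and streams back the
# accumulated text from InferD. The model name, messages, and sampling values
# below are placeholder assumptions; AI2_API_KEY must be set in the
# environment or passed via api_key.
def _example_ai2_stream():
    for chunk in ai2_api_stream_iter(
        model_name="tulu-2-dpo-70b",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=64,
    ):
        print(chunk["text"])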