"""Call API providers."""
import json
import os
import random
import re
import time
from typing import Optional

import requests
from fastchat.utils import build_logger
logger = build_logger("gradio_web_server", "gradio_web_server.log")
def get_api_provider_stream_iter(
conv,
model_name,
model_api_dict,
temperature,
top_p,
max_new_tokens,
state,
):
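    """Dispatch to the streaming iterator for the configured API backend.

    ``model_api_dict["api_type"]`` selects the provider; the conversation is
    converted to that provider's message format before streaming begins.
    """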
if model_api_dict["api_type"] == "openai":
if model_api_dict["vision-arena"]:
prompt = conv.to_openai_vision_api_messages()
else:
prompt = conv.to_openai_api_messages()
stream_iter = openai_api_stream_iter(
model_api_dict["model_name"],
prompt,
temperature,
top_p,
max_new_tokens,
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "openai_assistant":
last_prompt = conv.messages[-2][1]
stream_iter = openai_assistant_api_stream_iter(
state,
last_prompt,
assistant_id=model_api_dict["assistant_id"],
api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "anthropic":
if model_api_dict["vision-arena"]:
prompt = conv.to_anthropic_vision_api_messages()
else:
prompt = conv.to_openai_api_messages()
stream_iter = anthropic_api_stream_iter(
model_name, prompt, temperature, top_p, max_new_tokens
)
elif model_api_dict["api_type"] == "anthropic_message":
if model_api_dict["vision-arena"]:
prompt = conv.to_anthropic_vision_api_messages()
else:
prompt = conv.to_openai_api_messages()
stream_iter = anthropic_message_api_stream_iter(
model_name, prompt, temperature, top_p, max_new_tokens
)
elif model_api_dict["api_type"] == "anthropic_message_vertex":
if model_api_dict["vision-arena"]:
prompt = conv.to_anthropic_vision_api_messages()
else:
prompt = conv.to_openai_api_messages()
stream_iter = anthropic_message_api_stream_iter(
model_api_dict["model_name"],
prompt,
temperature,
top_p,
max_new_tokens,
vertex_ai=True,
)
elif model_api_dict["api_type"] == "gemini":
prompt = conv.to_gemini_api_messages()
stream_iter = gemini_api_stream_iter(
model_api_dict["model_name"],
prompt,
temperature,
top_p,
max_new_tokens,
api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "bard":
prompt = conv.to_openai_api_messages()
stream_iter = bard_api_stream_iter(
model_api_dict["model_name"],
prompt,
temperature,
top_p,
api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "mistral":
prompt = conv.to_openai_api_messages()
stream_iter = mistral_api_stream_iter(
model_name, prompt, temperature, top_p, max_new_tokens
)
elif model_api_dict["api_type"] == "nvidia":
prompt = conv.to_openai_api_messages()
stream_iter = nvidia_api_stream_iter(
model_name,
prompt,
temperature,
top_p,
max_new_tokens,
model_api_dict["api_base"],
)
elif model_api_dict["api_type"] == "ai2":
prompt = conv.to_openai_api_messages()
stream_iter = ai2_api_stream_iter(
model_name,
model_api_dict["model_name"],
prompt,
temperature,
top_p,
max_new_tokens,
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "vertex":
prompt = conv.to_vertex_api_messages()
stream_iter = vertex_api_stream_iter(
model_name, prompt, temperature, top_p, max_new_tokens
)
elif model_api_dict["api_type"] == "yandexgpt":
# note: top_p parameter is unused by yandexgpt
messages = []
if conv.system_message:
messages.append({"role": "system", "text": conv.system_message})
messages += [
{"role": role, "text": text}
for role, text in conv.messages
if text is not None
]
fixed_temperature = model_api_dict.get("fixed_temperature")
if fixed_temperature is not None:
temperature = fixed_temperature
stream_iter = yandexgpt_api_stream_iter(
model_name=model_api_dict["model_name"],
messages=messages,
temperature=temperature,
max_tokens=max_new_tokens,
api_base=model_api_dict["api_base"],
api_key=model_api_dict.get("api_key"),
folder_id=model_api_dict.get("folder_id"),
)
elif model_api_dict["api_type"] == "cohere":
messages = conv.to_openai_api_messages()
stream_iter = cohere_api_stream_iter(
client_name=model_api_dict.get("client_name", "FastChat"),
model_id=model_api_dict["model_name"],
messages=messages,
temperature=temperature,
top_p=top_p,
max_new_tokens=max_new_tokens,
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "reka":
messages = conv.to_reka_api_messages()
stream_iter = reka_api_stream_iter(
model_name=model_api_dict["model_name"],
messages=messages,
temperature=temperature,
top_p=top_p,
max_new_tokens=max_new_tokens,
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
)
    else:
        raise NotImplementedError(
            f"Unsupported api_type: {model_api_dict['api_type']}"
        )
return stream_iter
def openai_api_stream_iter(
model_name,
messages,
temperature,
top_p,
max_new_tokens,
api_base=None,
api_key=None,
):
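    """Stream chat completions from an OpenAI-compatible or Azure endpoint."""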
import openai
api_key = api_key or os.environ["OPENAI_API_KEY"]
if "azure" in model_name:
client = openai.AzureOpenAI(
api_version="2023-07-01-preview",
azure_endpoint=api_base or "https://api.openai.com/v1",
api_key=api_key,
)
else:
client = openai.OpenAI(
base_url=api_base or "https://api.openai.com/v1",
api_key=api_key,
timeout=180,
)
    # Build a text-only copy of the messages for logging (drop image payloads).
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model: keep only the text parts
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )
gen_params = {
"model": model_name,
"prompt": text_messages,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")
res = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature,
max_tokens=max_new_tokens,
stream=True,
)
text = ""
for chunk in res:
if len(chunk.choices) > 0:
text += chunk.choices[0].delta.content or ""
data = {
"text": text,
"error_code": 0,
}
yield data
def upload_openai_file_to_gcs(file_id):
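    """Fetch a file from the OpenAI Files API and mirror it to a public GCS blob."""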
import openai
from google.cloud import storage
storage_client = storage.Client()
file = openai.files.content(file_id)
# upload file to GCS
bucket = storage_client.get_bucket("arena_user_content")
blob = bucket.blob(f"{file_id}")
blob.upload_from_string(file.read())
blob.make_public()
return blob.public_url
def openai_assistant_api_stream_iter(
state,
prompt,
assistant_id,
api_key=None,
):
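    """Stream an OpenAI Assistants (beta) run over server-sent events.

    Reuses ``state.oai_thread_id`` across turns, rewrites 【n】 citations into
    markdown links, and mirrors generated files/images to public GCS URLs.
    """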
import openai
import base64
api_key = api_key or os.environ["OPENAI_API_KEY"]
client = openai.OpenAI(base_url="https://api.openai.com/v1", api_key=api_key)
if state.oai_thread_id is None:
logger.info("==== create thread ====")
thread = client.beta.threads.create()
state.oai_thread_id = thread.id
logger.info(f"==== thread_id ====\n{state.oai_thread_id}")
thread_message = client.beta.threads.messages.with_raw_response.create(
state.oai_thread_id,
role="user",
content=prompt,
timeout=3,
)
# logger.info(f"header {thread_message.headers}")
thread_message = thread_message.parse()
# Make requests
gen_params = {
"assistant_id": assistant_id,
"thread_id": state.oai_thread_id,
"message": prompt,
}
logger.info(f"==== request ====\n{gen_params}")
res = requests.post(
f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"OpenAI-Beta": "assistants=v1",
},
json={"assistant_id": assistant_id, "stream": True},
timeout=30,
stream=True,
)
list_of_text = []
list_of_raw_text = []
offset_idx = 0
full_ret_text = ""
idx_mapping = {}
for line in res.iter_lines():
if not line:
continue
data = line.decode("utf-8")
# logger.info("data:", data)
if data.endswith("[DONE]"):
break
if data.startswith("event"):
event = data.split(":")[1].strip()
if event == "thread.message.completed":
offset_idx += len(list_of_text)
continue
data = json.loads(data[6:])
if data.get("status") == "failed":
yield {
"text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}",
"error_code": 1,
}
return
if data.get("status") == "completed":
logger.info(f"[debug]: {data}")
if data["object"] != "thread.message.delta":
continue
for delta in data["delta"]["content"]:
text_index = delta["index"] + offset_idx
if len(list_of_text) <= text_index:
list_of_text.append("")
list_of_raw_text.append("")
text = list_of_text[text_index]
raw_text = list_of_raw_text[text_index]
if delta["type"] == "text":
# text, url_citation or file_path
content = delta["text"]
if "annotations" in content and len(content["annotations"]) > 0:
annotations = content["annotations"]
cur_offset = 0
raw_text_copy = raw_text
for anno in annotations:
if anno["type"] == "url_citation":
anno_text = anno["text"]
if anno_text not in idx_mapping:
continue
citation_number = idx_mapping[anno_text]
start_idx = anno["start_index"] + cur_offset
end_idx = anno["end_index"] + cur_offset
url = anno["url_citation"]["url"]
citation = f" [[{citation_number}]]({url})"
raw_text_copy = (
raw_text_copy[:start_idx]
+ citation
+ raw_text_copy[end_idx:]
)
cur_offset += len(citation) - (end_idx - start_idx)
elif anno["type"] == "file_path":
file_public_url = upload_openai_file_to_gcs(
anno["file_path"]["file_id"]
)
raw_text_copy = raw_text_copy.replace(
anno["text"], f"{file_public_url}"
)
text = raw_text_copy
else:
text_content = content["value"]
raw_text += text_content
# re-index citation number
pattern = r"【\d+】"
matches = re.findall(pattern, content["value"])
if len(matches) > 0:
for match in matches:
if match not in idx_mapping:
idx_mapping[match] = len(idx_mapping) + 1
citation_number = idx_mapping[match]
text_content = text_content.replace(
match, f" [{citation_number}]"
)
text += text_content
# yield {"text": text, "error_code": 0}
elif delta["type"] == "image_file":
image_public_url = upload_openai_file_to_gcs(
delta["image_file"]["file_id"]
)
# raw_text += f"![image]({image_public_url})"
text += f"![image]({image_public_url})"
list_of_text[text_index] = text
list_of_raw_text[text_index] = raw_text
full_ret_text = "\n".join(list_of_text)
yield {"text": full_ret_text, "error_code": 0}
def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
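    """Stream from the legacy Anthropic text-completions API."""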
import anthropic
c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
# Make requests
gen_params = {
"model": model_name,
"prompt": prompt,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")
res = c.completions.create(
prompt=prompt,
stop_sequences=[anthropic.HUMAN_PROMPT],
max_tokens_to_sample=max_new_tokens,
temperature=temperature,
top_p=top_p,
model=model_name,
stream=True,
)
text = ""
for chunk in res:
text += chunk.completion
data = {
"text": text,
"error_code": 0,
}
yield data
def anthropic_message_api_stream_iter(
model_name,
messages,
temperature,
top_p,
max_new_tokens,
vertex_ai=False,
):
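    """Stream from the Anthropic Messages API, optionally via Vertex AI."""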
import anthropic
if vertex_ai:
client = anthropic.AnthropicVertex(
region=os.environ["GCP_LOCATION"],
project_id=os.environ["GCP_PROJECT_ID"],
max_retries=5,
)
else:
client = anthropic.Anthropic(
api_key=os.environ["ANTHROPIC_API_KEY"],
max_retries=5,
)
    # Build a text-only copy of the messages for logging (drop image payloads).
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model: keep only the text parts
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )
gen_params = {
"model": model_name,
"prompt": text_messages,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")
system_prompt = ""
if messages[0]["role"] == "system":
if type(messages[0]["content"]) == dict:
system_prompt = messages[0]["content"]["text"]
elif type(messages[0]["content"]) == str:
system_prompt = messages[0]["content"]
# remove system prompt
messages = messages[1:]
text = ""
with client.messages.stream(
temperature=temperature,
top_p=top_p,
max_tokens=max_new_tokens,
messages=messages,
model=model_name,
system=system_prompt,
) as stream:
for chunk in stream.text_stream:
text += chunk
data = {
"text": text,
"error_code": 0,
}
yield data
def gemini_api_stream_iter(
model_name, messages, temperature, top_p, max_new_tokens, api_key=None
):
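    """Stream a chat response from Gemini via google-generativeai.

    All safety categories are set to BLOCK_NONE so responses are not filtered.
    """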
import google.generativeai as genai # pip install google-generativeai
if api_key is None:
api_key = os.environ["GEMINI_API_KEY"]
genai.configure(api_key=api_key)
generation_config = {
"temperature": temperature,
"max_output_tokens": max_new_tokens,
"top_p": top_p,
}
params = {
"model": model_name,
"prompt": messages,
}
params.update(generation_config)
logger.info(f"==== request ====\n{params}")
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
]
history = []
system_prompt = None
for message in messages[:-1]:
if message["role"] == "system":
system_prompt = message["content"]
continue
history.append({"role": message["role"], "parts": message["content"]})
model = genai.GenerativeModel(
model_name=model_name,
system_instruction=system_prompt,
generation_config=generation_config,
safety_settings=safety_settings,
)
convo = model.start_chat(history=history)
response = convo.send_message(messages[-1]["content"], stream=True)
try:
text = ""
for chunk in response:
text += chunk.candidates[0].content.parts[0].text
data = {
"text": text,
"error_code": 0,
}
yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        # Report the candidates' block reason if a chunk arrived; otherwise fall
        # back to the exception itself so `chunk` is never referenced unbound.
        reason = chunk.candidates if "chunk" in locals() else e
yield {
"text": f"**API REQUEST ERROR** Reason: {reason}.",
"error_code": 1,
}
def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None):
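    """Call the generativelanguage ``generateMessage`` endpoint and simulate
    token streaming by yielding growing prefixes of the full response."""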
del top_p # not supported
del temperature # not supported
if api_key is None:
api_key = os.environ["BARD_API_KEY"]
# convert conv to conv_bard
conv_bard = []
for turn in conv:
if turn["role"] == "user":
conv_bard.append({"author": "0", "content": turn["content"]})
elif turn["role"] == "assistant":
conv_bard.append({"author": "1", "content": turn["content"]})
else:
raise ValueError(f"Unsupported role: {turn['role']}")
params = {
"model": model_name,
"prompt": conv_bard,
}
logger.info(f"==== request ====\n{params}")
try:
res = requests.post(
f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}",
json={
"prompt": {
"messages": conv_bard,
},
},
timeout=30,
)
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}.",
            "error_code": 1,
        }
        return
    if res.status_code != 200:
        logger.error(f"==== error ==== ({res.status_code}): {res.text}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.",
            "error_code": 1,
        }
        return
response_json = res.json()
if "candidates" not in response_json:
logger.error(f"==== error ==== response blocked: {response_json}")
reason = response_json["filters"][0]["reason"]
yield {
"text": f"**API REQUEST ERROR** Reason: {reason}.",
"error_code": 1,
}
response = response_json["candidates"][0]["content"]
pos = 0
while pos < len(response):
# simulate token streaming
pos += random.randint(3, 6)
time.sleep(0.002)
data = {
"text": response[:pos],
"error_code": 0,
}
yield data
def ai2_api_stream_iter(
model_name,
model_id,
messages,
temperature,
top_p,
max_new_tokens,
api_key=None,
api_base=None,
):
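    """Stream from AI2's InferD endpoint (vLLM-backed; Tulu2-style input)."""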
# get keys and needed values
ai2_key = api_key or os.environ.get("AI2_API_KEY")
api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
# Make requests
gen_params = {
"model": model_name,
"prompt": messages,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")
# AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
# https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
if temperature == 0.0 and top_p < 1.0:
raise ValueError("top_p must be 1 when temperature is 0.0")
res = requests.post(
api_base,
stream=True,
headers={"Authorization": f"Bearer {ai2_key}"},
json={
"model_id": model_id,
# This input format is specific to the Tulu2 model. Other models
# may require different input formats. See the model's schema
# documentation on InferD for more information.
"input": {
"messages": messages,
"opts": {
"max_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"logprobs": 1, # increase for more choices
},
},
},
timeout=5,
)
if res.status_code != 200:
logger.error(f"unexpected response ({res.status_code}): {res.text}")
raise ValueError("unexpected response from InferD", res)
text = ""
for line in res.iter_lines():
if line:
part = json.loads(line)
if "result" in part and "output" in part["result"]:
for t in part["result"]["output"]["text"]:
text += t
else:
logger.error(f"unexpected part: {part}")
raise ValueError("empty result in InferD response")
data = {
"text": text,
"error_code": 0,
}
yield data
def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
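    """Stream chat completions from the Mistral AI API."""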
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
api_key = os.environ["MISTRAL_API_KEY"]
client = MistralClient(api_key=api_key, timeout=5)
# Make requests
gen_params = {
"model": model_name,
"prompt": messages,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")
new_messages = [
ChatMessage(role=message["role"], content=message["content"])
for message in messages
]
res = client.chat_stream(
model=model_name,
temperature=temperature,
messages=new_messages,
max_tokens=max_new_tokens,
top_p=top_p,
)
text = ""
for chunk in res:
if chunk.choices[0].delta.content is not None:
text += chunk.choices[0].delta.content
data = {
"text": text,
"error_code": 0,
}
yield data
def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base):
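    """Stream chat completions from an NVIDIA endpoint over server-sent events."""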
api_key = os.environ["NVIDIA_API_KEY"]
headers = {
"Authorization": f"Bearer {api_key}",
"accept": "text/event-stream",
"content-type": "application/json",
}
# nvidia api does not accept 0 temperature
if temp == 0.0:
temp = 0.000001
payload = {
"messages": messages,
"temperature": temp,
"top_p": top_p,
"max_tokens": max_tokens,
"seed": 42,
"stream": True,
}
logger.info(f"==== request ====\n{payload}")
response = requests.post(
api_base, headers=headers, json=payload, stream=True, timeout=1
)
text = ""
for line in response.iter_lines():
if line:
data = line.decode("utf-8")
if data.endswith("[DONE]"):
break
            # `content` may be absent or None in role-only delta chunks.
            content = json.loads(data[6:])["choices"][0]["delta"].get("content") or ""
            text += content
yield {"text": text, "error_code": 0}
def yandexgpt_api_stream_iter(
model_name, messages, temperature, max_tokens, api_base, api_key, folder_id
):
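    """Stream completions from the YandexGPT foundation-models API."""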
api_key = api_key or os.environ["YANDEXGPT_API_KEY"]
headers = {
"Authorization": f"Api-Key {api_key}",
"content-type": "application/json",
}
payload = {
"modelUri": f"gpt://{folder_id}/{model_name}",
"completionOptions": {
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
},
"messages": messages,
}
logger.info(f"==== request ====\n{payload}")
# https://llm.api.cloud.yandex.net/foundationModels/v1/completion
response = requests.post(
api_base, headers=headers, json=payload, stream=True, timeout=60
)
text = ""
for line in response.iter_lines():
if line:
data = json.loads(line.decode("utf-8"))
data = data["result"]
top_alternative = data["alternatives"][0]
text = top_alternative["message"]["text"]
yield {"text": text, "error_code": 0}
status = top_alternative["status"]
if status in (
"ALTERNATIVE_STATUS_FINAL",
"ALTERNATIVE_STATUS_TRUNCATED_FINAL",
):
break
def cohere_api_stream_iter(
client_name: str,
model_id: str,
messages: list,
    temperature: Optional[float] = None,  # None values are handled by the SDK/API
top_p: Optional[float] = None,
max_new_tokens: Optional[int] = None,
api_key: Optional[str] = None, # default is env var CO_API_KEY
api_base: Optional[str] = None,
):
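    """Stream chat from the Cohere API, mapping OpenAI roles to Cohere's."""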
import cohere
OPENAI_TO_COHERE_ROLE_MAP = {
"user": "User",
"assistant": "Chatbot",
"system": "System",
}
client = cohere.Client(
api_key=api_key,
base_url=api_base,
client_name=client_name,
)
# prepare and log requests
chat_history = [
dict(
role=OPENAI_TO_COHERE_ROLE_MAP[message["role"]], message=message["content"]
)
for message in messages[:-1]
]
actual_prompt = messages[-1]["content"]
gen_params = {
"model": model_id,
"messages": messages,
"chat_history": chat_history,
"prompt": actual_prompt,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")
# make request and stream response
res = client.chat_stream(
message=actual_prompt,
chat_history=chat_history,
model=model_id,
temperature=temperature,
max_tokens=max_new_tokens,
p=top_p,
)
try:
text = ""
for streaming_item in res:
if streaming_item.event_type == "text-generation":
text += streaming_item.text
yield {"text": text, "error_code": 0}
except cohere.core.ApiError as e:
logger.error(f"==== error from cohere api: {e} ====")
yield {
"text": f"**API REQUEST ERROR** Reason: {e}",
"error_code": 1,
}
def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
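    """Stream from a Vertex AI generative model with safety filters disabled."""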
import vertexai
from vertexai import generative_models
from vertexai.generative_models import (
GenerationConfig,
GenerativeModel,
Image,
)
project_id = os.environ.get("GCP_PROJECT_ID", None)
location = os.environ.get("GCP_LOCATION", None)
vertexai.init(project=project_id, location=location)
    # Keep only plain-text messages in the logged copy.
    text_messages = []
    for message in messages:
        if isinstance(message, str):
            text_messages.append(message)
gen_params = {
"model": model_name,
"prompt": text_messages,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")
safety_settings = [
generative_models.SafetySetting(
category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
),
generative_models.SafetySetting(
category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
),
generative_models.SafetySetting(
category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
),
generative_models.SafetySetting(
category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
),
]
generator = GenerativeModel(model_name).generate_content(
messages,
stream=True,
generation_config=GenerationConfig(
top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature
),
safety_settings=safety_settings,
)
ret = ""
for chunk in generator:
# NOTE(chris): This may be a vertex api error, below is HOTFIX: https://github.com/googleapis/python-aiplatform/issues/3129
ret += chunk.candidates[0].content.parts[0]._raw_part.text
# ret += chunk.text
data = {
"text": ret,
"error_code": 0,
}
yield data
def reka_api_stream_iter(
model_name: str,
messages: list,
    temperature: Optional[float] = None,  # None values are handled by the API
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var REKA_API_KEY
api_base: Optional[str] = None,
):
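    """Stream from the Reka chat API; a ``-online`` model suffix enables search."""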
api_key = api_key or os.environ["REKA_API_KEY"]
use_search_engine = False
if "-online" in model_name:
model_name = model_name.replace("-online", "")
use_search_engine = True
request = {
"model_name": model_name,
"conversation_history": messages,
"temperature": temperature,
"request_output_len": max_new_tokens,
"runtime_top_p": top_p,
"stream": True,
"use_search_engine": use_search_engine,
}
    # Build a text-only copy of the conversation for logging.
text_messages = []
for message in messages:
text_messages.append({"type": message["type"], "text": message["text"]})
logged_request = dict(request)
logged_request["conversation_history"] = text_messages
logger.info(f"==== request ====\n{logged_request}")
response = requests.post(
api_base,
stream=True,
json=request,
headers={
"X-Api-Key": api_key,
},
)
if response.status_code != 200:
error_message = response.text
logger.error(f"==== error from reka api: {error_message} ====")
yield {
"text": f"**API REQUEST ERROR** Reason: {error_message}",
"error_code": 1,
}
return
for line in response.iter_lines():
line = line.decode("utf8")
if not line.startswith("data: "):
continue
gen = json.loads(line[6:])
yield {"text": gen["text"], "error_code": 0}