|
import json, copy, types |
|
import os |
|
from enum import Enum |
|
import time |
|
from typing import Callable, Optional, Any, Union |
|
import litellm |
|
from litellm.utils import ModelResponse, get_secret, Usage |
|
from .prompt_templates.factory import prompt_factory, custom_prompt |
|
import httpx |
|
|
|
|
|
class BedrockError(Exception):
    """Exception raised for failures while calling AWS Bedrock.

    Exposes ``status_code`` and ``message`` plus placeholder ``httpx`` request and
    response objects (``.request`` / ``.response``), presumably so callers that
    expect httpx-style provider errors can inspect it the same way.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Bedrock traffic in this module goes through boto3, so there is no real
        # httpx request; a fixed placeholder is attached instead.
        placeholder_request = httpx.Request(
            method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
        )
        self.request = placeholder_request
        self.response = httpx.Response(
            status_code=status_code, request=placeholder_request
        )
        super().__init__(self.message)
|
|
|
|
|
class AmazonTitanConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (float) top p for model

    Values passed to ``__init__`` are stored on the class itself (not the
    instance), so they act as module-wide defaults returned by ``get_config``.
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    # top-p is a probability mass in [0, 1], so it is a float, not an int.
    topP: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            # Only explicitly supplied values overwrite the class-level defaults.
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return all non-None data attributes set on the class (no dunders/methods)."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
|
|
|
|
class AmazonAnthropicConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (float) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"

    Values passed to ``__init__`` are stored on the class itself (not the
    instance), so they act as module-wide defaults returned by ``get_config``.
    ``max_tokens_to_sample`` starts out as the value of ``litellm.max_tokens``
    at the time this module is imported.
    """

    # Snapshot of litellm.max_tokens taken at class-definition (import) time.
    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    # top-p is a probability mass in [0, 1], so it is a float, not an int.
    top_p: Optional[float] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            # Only explicitly supplied values overwrite the class-level defaults.
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return all non-None data attributes set on the class (no dunders/methods)."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
|
|
|
|
class AmazonCohereConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported Params for the Amazon / Cohere models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a

    Constructing an instance writes every non-None argument onto the class, so
    the values become shared defaults reported by ``get_config``.
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        supplied = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "return_likelihood": return_likelihood,
        }
        owner = self.__class__
        for name, value in supplied.items():
            if value is None:
                continue
            setattr(owner, name, value)

    @classmethod
    def get_config(cls):
        """Collect the class's non-None data attributes, skipping dunders and callables."""
        skipped_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for name, value in cls.__dict__.items():
            if name.startswith("__"):
                continue
            if isinstance(value, skipped_kinds) or value is None:
                continue
            config[name] = value
        return config
|
|
|
|
|
class AmazonAI21Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported Params for the Amazon / AI21 models:

    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.

    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.

    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.

    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.

    - `frequencyPenalty` (object): Placeholder for frequency penalty object.

    - `presencePenalty` (object): Placeholder for presence penalty object.

    - `countPenalty` (object): Placeholder for count penalty object.

    The misspelled keyword ``frequencePenalty`` is still accepted by ``__init__``
    for backwards compatibility and is stored as ``frequencyPenalty``.

    Values passed to ``__init__`` are stored on the class itself (not the
    instance), so they act as module-wide defaults returned by ``get_config``.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    frequencyPenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
        frequencyPenalty: Optional[dict] = None,
    ) -> None:
        # Legacy misspelled keyword: map it onto the key AI21 actually reads.
        if frequencyPenalty is None:
            frequencyPenalty = frequencePenalty
        supplied = (
            ("maxTokens", maxTokens),
            ("temperature", temperature),
            ("topP", topP),
            ("stopSequences", stopSequences),
            ("frequencyPenalty", frequencyPenalty),
            ("presencePenalty", presencePenalty),
            ("countPenalty", countPenalty),
        )
        for key, value in supplied:
            # Only explicitly supplied values overwrite the class-level defaults.
            if value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return all non-None data attributes set on the class (no dunders/methods)."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
|
|
|
|
class AnthropicConstants(Enum):
    """Role prefixes for Anthropic-style "Human: / Assistant:" text prompts."""

    # Prefix that opens a human (user) turn.
    HUMAN_PROMPT = '\n\nHuman: '
    # Prefix that opens the assistant's turn.
    AI_PROMPT = '\n\nAssistant: '
|
|
|
|
|
class AmazonLlamaConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model

    For backwards compatibility ``__init__`` also accepts the Titan-style names
    ``maxTokenCount`` and ``topP``; they are stored as ``max_gen_len`` and
    ``top_p`` (the keys Llama reads) when the new-style name is not given.

    Values passed to ``__init__`` are stored on the class itself (not the
    instance), so they act as module-wide defaults returned by ``get_config``.
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        max_gen_len: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> None:
        # Legacy Titan-style keywords map onto Llama's parameter names.
        if max_gen_len is None:
            max_gen_len = maxTokenCount
        if top_p is None:
            top_p = topP
        supplied = (
            ("max_gen_len", max_gen_len),
            ("temperature", temperature),
            ("top_p", top_p),
        )
        for key, value in supplied:
            # Only explicitly supplied values overwrite the class-level defaults.
            if value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return all non-None data attributes set on the class (no dunders/methods)."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
|
|
|
|
def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
):
    """Build a boto3 ``bedrock-runtime`` client.

    Arguments written as ``"os.environ/<NAME>"`` are resolved through
    ``get_secret`` first.

    Region precedence: ``region_name``, then ``aws_region_name``, then the
    ``AWS_REGION_NAME`` secret, then ``AWS_REGION``.

    Endpoint precedence: ``aws_bedrock_runtime_endpoint``, then the
    ``AWS_BEDROCK_RUNTIME_ENDPOINT`` secret, then the public regional URL.

    Explicit keys are passed to boto3 only when ``aws_access_key_id`` is set.

    Raises:
        BedrockError: (status 401) when no region can be determined.
    """
    env_region_litellm = get_secret("AWS_REGION_NAME", None)
    env_region_standard = get_secret("AWS_REGION", None)

    def _resolve(value):
        # Only "os.environ/..." references are looked up; anything else passes through.
        if value and value.startswith("os.environ/"):
            return get_secret(value)
        return value

    aws_access_key_id = _resolve(aws_access_key_id)
    aws_secret_access_key = _resolve(aws_secret_access_key)
    aws_region_name = _resolve(aws_region_name)
    aws_bedrock_runtime_endpoint = _resolve(aws_bedrock_runtime_endpoint)

    # First truthy candidate wins; if none is truthy the chain yields a falsy value.
    region_name = (
        region_name or aws_region_name or env_region_litellm or env_region_standard
    )
    if not region_name:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    env_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    endpoint_url = (
        aws_bedrock_runtime_endpoint
        or env_endpoint
        or f"https://bedrock-runtime.{region_name}.amazonaws.com"
    )

    import boto3

    client_kwargs = {
        "service_name": "bedrock-runtime",
        "region_name": region_name,
        "endpoint_url": endpoint_url,
    }
    if aws_access_key_id is not None:
        # Explicit credentials; otherwise boto3 falls back to its own credential lookup.
        client_kwargs["aws_access_key_id"] = aws_access_key_id
        client_kwargs["aws_secret_access_key"] = aws_secret_access_key
    return boto3.client(**client_kwargs)
|
|
|
|
|
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
    """Flatten chat ``messages`` into the single prompt string Bedrock expects.

    Resolution order:

    1. A template registered for ``model`` in ``custom_prompt_dict`` is applied,
       whatever the provider (previously only honoured for anthropic).
    2. ``anthropic`` models are formatted by ``prompt_factory``.
    3. Every other provider gets the message contents concatenated with no
       separator (roles are ignored).
    """
    if custom_prompt_dict and model in custom_prompt_dict:
        # A user-registered template always takes precedence.
        model_prompt_details = custom_prompt_dict[model]
        return custom_prompt(
            role_dict=model_prompt_details["roles"],
            initial_prompt_value=model_prompt_details["initial_prompt_value"],
            final_prompt_value=model_prompt_details["final_prompt_value"],
            messages=messages,
        )
    if provider == "anthropic":
        return prompt_factory(
            model=model, messages=messages, custom_llm_provider="anthropic"
        )
    return "".join(f"{message['content']}" for message in messages)
|
|
|
|
|
""" |
|
BEDROCK AUTH Keys/Vars |
|
os.environ['AWS_ACCESS_KEY_ID'] = "" |
|
os.environ['AWS_SECRET_ACCESS_KEY'] = "" |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
# Anthropic text-completion stop reasons -> OpenAI-style finish reasons.
_ANTHROPIC_FINISH_REASONS = {"stop_sequence": "stop", "max_tokens": "length"}


def _merge_provider_config(inference_params: dict, config: dict) -> None:
    """Fill provider-config defaults into ``inference_params`` in place.

    Keys already present (passed explicitly on the call) are left untouched, so
    per-call values override the module-level config.
    """
    for key, value in config.items():
        inference_params.setdefault(key, value)


def _log_bedrock_pre_call(logging_obj, prompt, data, modelId, method_name):
    """Report an outgoing Bedrock request to ``logging_obj.pre_call``.

    ``method_name`` is the boto3 client method about to be invoked; it is only
    used to render the human-readable ``request_str``.
    """
    request_str = f"""
response = client.{method_name}(
    body={data},
    modelId={modelId},
    accept=accept,
    contentType=contentType
)
"""
    logging_obj.pre_call(
        input=prompt,
        api_key="",
        additional_args={
            "complete_input_dict": data,
            "request_str": request_str,
        },
    )


def completion(
    model: str,
    messages: list,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    """Run a text completion against an AWS Bedrock model.

    The provider is the prefix of ``model`` before the first ``.``: anthropic,
    ai21, cohere, meta or amazon. Provider config defaults
    (``litellm.Amazon*Config``) fill in any parameter not given in
    ``optional_params``.

    ``optional_params`` is mutated. The AWS connection keys,
    ``aws_bedrock_client`` and ``model_id`` are popped from it so they are
    not sent as inference parameters.

    Returns:
        * Non-streaming: ``model_response`` with content, finish reason
          (anthropic only), ``created``, ``model`` and token ``usage`` filled in.
        * Streaming, non-AI21 providers: the ``body`` of
          ``invoke_model_with_response_stream``.
        * Streaming, AI21: the raw bytes of a regular ``invoke_model`` body;
          this code does not use the streaming API for AI21.

    Raises:
        BedrockError: for every failure.
            * 404: invalid model id.
            * 400: other validation errors.
            * 500: anything else, including the HTTP status from the response
              metadata when it is >= 400.

    ``litellm_params`` and ``logger_fn`` are accepted for interface
    compatibility and are unused here.
    """
    try:
        if optional_params is None:
            optional_params = {}

        # Pull AWS connection settings out so they are never forwarded to the model.
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )

        # A pre-built client may be supplied; otherwise build one from the settings.
        client = optional_params.pop("aws_bedrock_client", None)
        if client is None:
            client = init_bedrock_client(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_region_name=aws_region_name,
                aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            )

        # `model_id` (if given) is the id sent to Bedrock; `model` still selects
        # the provider and the prompt format.
        modelId = optional_params.pop("model_id", None) or model
        provider = model.split(".")[0]
        prompt = convert_messages_to_prompt(
            model, messages, provider, custom_prompt_dict
        )
        inference_params = copy.deepcopy(optional_params)
        stream = inference_params.pop("stream", False)

        if provider == "anthropic":
            _merge_provider_config(
                inference_params, litellm.AmazonAnthropicConfig.get_config()
            )
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "ai21":
            _merge_provider_config(
                inference_params, litellm.AmazonAI21Config.get_config()
            )
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "cohere":
            _merge_provider_config(
                inference_params, litellm.AmazonCohereConfig.get_config()
            )
            # Cohere takes the stream flag inside the request body.
            if stream:
                inference_params["stream"] = True
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "meta":
            _merge_provider_config(
                inference_params, litellm.AmazonLlamaConfig.get_config()
            )
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "amazon":
            _merge_provider_config(
                inference_params, litellm.AmazonTitanConfig.get_config()
            )
            # Titan nests its sampling parameters under textGenerationConfig.
            data = json.dumps(
                {
                    "inputText": prompt,
                    "textGenerationConfig": inference_params,
                }
            )
        else:
            data = json.dumps({})

        accept = "application/json"
        contentType = "application/json"

        if stream:
            if provider == "ai21":
                # No response streaming for AI21 here: a regular call is made and
                # the raw body bytes are handed back to the caller.
                _log_bedrock_pre_call(
                    logging_obj, prompt, data, modelId, "invoke_model"
                )
                response = client.invoke_model(
                    body=data, modelId=modelId, accept=accept, contentType=contentType
                )
                return response.get("body").read()

            _log_bedrock_pre_call(
                logging_obj, prompt, data, modelId, "invoke_model_with_response_stream"
            )
            response = client.invoke_model_with_response_stream(
                body=data, modelId=modelId, accept=accept, contentType=contentType
            )
            return response.get("body")

        try:
            _log_bedrock_pre_call(logging_obj, prompt, data, modelId, "invoke_model")
            response = client.invoke_model(
                body=data, modelId=modelId, accept=accept, contentType=contentType
            )
        except client.exceptions.ValidationException as e:
            if "The provided model identifier is invalid" in str(e):
                raise BedrockError(status_code=404, message=str(e)) from e
            raise BedrockError(status_code=400, message=str(e)) from e
        except Exception as e:
            raise BedrockError(status_code=500, message=str(e)) from e

        response_body = json.loads(response.get("body").read())

        logging_obj.post_call(
            input=prompt,
            api_key="",
            original_response=json.dumps(response_body),
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response}")

        if provider == "ai21":
            outputText = response_body.get("completions")[0].get("data").get("text")
        elif provider == "anthropic":
            outputText = response_body["completion"]
            stop_reason = response_body.get("stop_reason")
            if stop_reason is not None:
                # The finish reason belongs on the choice (OpenAI layout), not on
                # the top-level response object.
                model_response["choices"][0]["finish_reason"] = (
                    _ANTHROPIC_FINISH_REASONS.get(stop_reason, stop_reason)
                )
        elif provider == "cohere":
            outputText = response_body["generations"][0]["text"]
        elif provider == "meta":
            outputText = response_body["generation"]
        else:
            outputText = response_body.get("results")[0].get("outputText")

        response_metadata = response.get("ResponseMetadata", {})
        status_code = response_metadata.get("HTTPStatusCode", 500)
        if status_code >= 400:
            raise BedrockError(message=outputText, status_code=status_code)
        try:
            if len(outputText) > 0:
                model_response["choices"][0]["message"]["content"] = outputText
        except Exception:
            # e.g. outputText is None (len() fails) - surface the raw value.
            raise BedrockError(
                message=json.dumps(outputText),
                status_code=status_code,
            )

        # Token counts are computed locally with the supplied encoding.
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response["created"] = int(time.time())
        model_response["model"] = model
        model_response.usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        return model_response
    except BedrockError:
        raise
    except Exception as e:
        import traceback

        raise BedrockError(status_code=500, message=traceback.format_exc()) from e
|
|
|
|
|
def _embedding_func_single(
    model: str,
    input: str,
    client: Any,
    optional_params=None,
    encoding=None,
    logging_obj=None,
):
    """Embed a single input string with a Bedrock embedding model.

    Supported providers (prefix of ``model`` before the first ``.``):
    ``amazon`` (Titan; line separators in the input are replaced by spaces)
    and ``cohere`` (``input_type`` defaults to ``"search_document"``).

    ``optional_params`` is NOT mutated. Callers reuse it for every input of a
    batch, so ``model_id`` and ``user`` are removed from a private copy only.
    ``model_id`` overrides the id sent to Bedrock; ``user`` is dropped from
    the request body.

    Returns:
        The embedding vector from the response (``embedding`` for amazon; the
        flattened ``embeddings`` for cohere).

    Raises:
        BedrockError: 400 for an unsupported provider, 500 if the call or
            response parsing fails.
    """
    provider = model.split(".")[0]
    inference_params = copy.deepcopy(optional_params) if optional_params else {}
    inference_params.pop("user", None)
    # Pop from the copy so model_id neither leaks into the body nor disappears
    # for the next input of the same batch.
    modelId = inference_params.pop("model_id", None) or model

    if provider == "amazon":
        input = input.replace(os.linesep, " ")
        data = {"inputText": input, **inference_params}
    elif provider == "cohere":
        inference_params.setdefault("input_type", "search_document")
        data = {"texts": [input], **inference_params}
    else:
        raise BedrockError(
            message=f"Embedding Error with model {model}: unsupported provider '{provider}'",
            status_code=400,
        )
    body = json.dumps(data).encode("utf-8")

    request_str = f"""
    response = client.invoke_model(
        body={body},
        modelId={modelId},
        accept="*/*",
        contentType="application/json",
    )"""
    logging_obj.pre_call(
        input=input,
        api_key="",
        additional_args={
            "complete_input_dict": {"model": modelId, "texts": input},
            "request_str": request_str,
        },
    )
    try:
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept="*/*",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())

        logging_obj.post_call(
            input=input,
            api_key="",
            additional_args={"complete_input_dict": data},
            original_response=json.dumps(response_body),
        )
        if provider == "cohere":
            # One vector per text was returned; exactly one text was sent, so
            # flattening yields that single vector.
            embeddings = response_body.get("embeddings")
            return [value for vector in embeddings for value in vector]
        return response_body.get("embedding")
    except Exception as e:
        raise BedrockError(
            message=f"Embedding Error with model {model}: {e}", status_code=500
        ) from e
|
|
|
|
|
def embedding(
    model: str,
    input: Union[list, str],
    api_key: Optional[str] = None,
    logging_obj=None,
    model_response=None,
    optional_params=None,
    encoding=None,
):
    """Embed one string or a list of strings with a Bedrock embedding model.

    AWS connection settings are popped from ``optional_params``. A pre-built
    client may be passed as ``aws_bedrock_client``, matching ``completion``.
    Each input string is embedded with a separate Bedrock call.

    Returns:
        ``model_response``, filled in OpenAI embedding-list layout
        (``object="list"``, ``data``, ``model``). Usage has prompt tokens
        counted locally with ``encoding`` and zero completion tokens.

    ``api_key`` is accepted for interface compatibility and is unused here.
    """
    if optional_params is None:
        optional_params = {}

    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )

    # Honour a caller-supplied client. Left in optional_params, it would be
    # deep-copied and merged into every embedding request body.
    client = optional_params.pop("aws_bedrock_client", None)
    if client is None:
        client = init_bedrock_client(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_region_name=aws_region_name,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        )

    # A single string is treated as a batch of one.
    inputs = [input] if isinstance(input, str) else input
    embeddings = [
        _embedding_func_single(
            model,
            text,
            optional_params=optional_params,
            client=client,
            logging_obj=logging_obj,
        )
        for text in inputs
    ]

    model_response["object"] = "list"
    model_response["data"] = [
        {
            "object": "embedding",
            "index": idx,
            "embedding": vector,
        }
        for idx, vector in enumerate(embeddings)
    ]
    model_response["model"] = model

    # Prompt tokens are estimated locally from all inputs joined together.
    input_tokens = len(encoding.encode("".join(inputs)))
    model_response.usage = Usage(
        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
    )
    return model_response
|
|