# dsp/modules/bedrock.py
from __future__ import annotations

import json
from typing import Any

from dsp.modules.aws_lm import AWSLM


class Bedrock(AWSLM):
def __init__(
self,
region_name: str,
model: str,
input_output_ratio: int = 3,
max_new_tokens: int = 1500,
) -> None:
"""Use an AWS Bedrock language model.
NOTE: You must first configure your AWS credentials with the AWS CLI before using this model!
Args:
region_name (str, optional): The AWS region where this LM is hosted.
model (str, optional): An AWS Bedrock LM name. You can find available models with the AWS CLI as follows: aws bedrock list-foundation-models --query "modelSummaries[*].modelId".
temperature (float, optional): Default temperature for LM. Defaults to 0.
input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3.
max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM.
"""
super().__init__(
model=model,
service_name="bedrock-runtime",
region_name=region_name,
truncate_long_prompts=False,
input_output_ratio=input_output_ratio,
max_new_tokens=max_new_tokens,
batch_n=True, # Bedrock does not support the `n` parameter
)
        self._validate_model(model)

    def _validate_model(self, model: str) -> None:
        if "claude" not in model.lower():
            raise NotImplementedError("Only Claude models are supported as of now.")

    def _create_body(self, prompt: str, **kwargs) -> dict[str, Any]:
        base_args: dict[str, Any] = {
            "max_tokens_to_sample": self._max_new_tokens,
        }
        base_args.update(kwargs)
        query_args: dict[str, Any] = self._sanitize_kwargs(base_args)
        query_args["prompt"] = prompt
        # AWS Bedrock forbids the `max_tokens` key; translate it into the
        # `max_tokens_to_sample` budget that Anthropic models on Bedrock
        # expect, reserving room for the (estimated) input tokens.
        if "max_tokens" in query_args:
            max_tokens: int = query_args.pop("max_tokens")
            input_tokens: int = self._estimate_tokens(prompt)
            query_args["max_tokens_to_sample"] = max_tokens - input_tokens
        return query_args
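
    # For reference (illustrative, and assuming `_sanitize_kwargs` passes
    # these keys through unchanged), `_create_body("\n\nHuman: Hi\n\nAssistant:")`
    # with the defaults above would yield:
    #     {"max_tokens_to_sample": 1500, "prompt": "\n\nHuman: Hi\n\nAssistant:"}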

    def _call_model(self, body: str) -> str:
        response = self.predictor.invoke_model(
            modelId=self._model_name,
            body=body,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response["body"].read())
        return response_body["completion"]

    def _extract_input_parameters(
        self, body: dict[Any, Any]
    ) -> dict[str, str | float | int]:
        return body

    def _format_prompt(self, raw_prompt: str) -> str:
        # Anthropic (Claude) completion models require the Human/Assistant chat framing.
        return "\n\nHuman: " + raw_prompt + "\n\nAssistant:"