from __future__ import annotations

import json
from typing import Any

from dsp.modules.aws_lm import AWSLM


class Bedrock(AWSLM):
    def __init__(
        self,
        region_name: str,
        model: str,
        input_output_ratio: int = 3,
        max_new_tokens: int = 1500,
    ) -> None:
"""Use an AWS Bedrock language model. |
|
NOTE: You must first configure your AWS credentials with the AWS CLI before using this model! |
|
|
|
Args: |
|
region_name (str, optional): The AWS region where this LM is hosted. |
|
model (str, optional): An AWS Bedrock LM name. You can find available models with the AWS CLI as follows: aws bedrock list-foundation-models --query "modelSummaries[*].modelId". |
|
temperature (float, optional): Default temperature for LM. Defaults to 0. |
|
input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3. |
|
max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM. |
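
        Example (a minimal construction sketch; the region and model ID are
        illustrative and assume your account has Bedrock access to that model):
            >>> bedrock = Bedrock(region_name="us-east-1", model="anthropic.claude-v2")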
        """
        super().__init__(
            model=model,
            service_name="bedrock-runtime",
            region_name=region_name,
            truncate_long_prompts=False,
            input_output_ratio=input_output_ratio,
            max_new_tokens=max_new_tokens,
            batch_n=True,
        )
        self._validate_model(model)

    def _validate_model(self, model: str) -> None:
        # Only Anthropic Claude models are currently supported on Bedrock.
        if "claude" not in model.lower():
            raise NotImplementedError("Only claude models are supported as of now")

    def _create_body(self, prompt: str, **kwargs) -> dict[str, Any]:
        base_args: dict[str, Any] = {
            "max_tokens_to_sample": self._max_new_tokens,
        }
        base_args.update(kwargs)
        query_args: dict[str, Any] = self._sanitize_kwargs(base_args)
        query_args["prompt"] = prompt

        # Claude on Bedrock expects "max_tokens_to_sample" (an output budget),
        # so translate an OpenAI-style "max_tokens" (a total budget) by
        # subtracting the estimated size of the prompt.
        if "max_tokens" in query_args:
            max_tokens: int = query_args["max_tokens"]
            input_tokens: int = self._estimate_tokens(prompt)
            max_tokens_to_sample: int = max_tokens - input_tokens
            del query_args["max_tokens"]
            query_args["max_tokens_to_sample"] = max_tokens_to_sample
        return query_args
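
    # Illustrative shape of the request body built above (a sketch, not the
    # exact wire format; keys vary with the kwargs that pass _sanitize_kwargs):
    #   {"max_tokens_to_sample": 1500, "prompt": "<formatted prompt>"}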

    def _call_model(self, body: str) -> str:
        # Send the JSON body to the bedrock-runtime endpoint and pull the
        # generated text out of the "completion" field of the JSON response.
        response = self.predictor.invoke_model(
            modelId=self._model_name,
            body=body,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response["body"].read())
        completion = response_body["completion"]
        return completion

    def _extract_input_parameters(
        self, body: dict[Any, Any]
    ) -> dict[str, str | float | int]:
        return body

    def _format_prompt(self, raw_prompt: str) -> str:
        return "\n\nHuman: " + raw_prompt + "\n\nAssistant:"