Spaces:
Running
Running
File size: 4,382 Bytes
2a0aa5a a4e8fcb 43c8549 2a0aa5a 9e789e7 2a0aa5a a4e8fcb 2a0aa5a a4e8fcb 9e789e7 a4e8fcb 9e789e7 a4e8fcb 9e789e7 2a0aa5a a4e8fcb 2a0aa5a a4e8fcb 2a0aa5a a4e8fcb 2a0aa5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
"""
This module contains functions to interact with the models.
"""
import json
from typing import List
import litellm
# Fallback system instructions used when a Model is created without custom
# ones. The translate template expects ``source_lang`` and ``target_lang``
# to be filled in with ``str.format``.
DEFAULT_SUMMARIZE_INSTRUCTION = (
    "Summarize the given text "
    "without changing the language of it."
)
DEFAULT_TRANSLATE_INSTRUCTION = (
    "Translate the given text "
    "from {source_lang} to {target_lang}."
)
class ContextWindowExceededError(Exception):
    """Raised by this module in place of litellm's ContextWindowExceededError."""
class Model:
    """A chat model called through ``litellm`` that must answer in JSON.

    Stores the connection settings (model name, optional provider prefix,
    API key, API base URL) together with the summarize/translate
    instructions to use with this model.
    """

    # Appended to every system instruction so the reply can be parsed.
    _JSON_FORMAT_HINT = (
        "\nOutput following this JSON format:\n"
        '{"result": "your result here"}'
    )

    def __init__(
        self,
        name: str,
        provider: str = None,
        api_key: str = None,
        api_base: str = None,
        summarize_instruction: str = None,
        translate_instruction: str = None,
    ):
        """Store the model settings.

        Args:
            name: Model name passed to litellm.
            provider: Optional provider; when set, the litellm model id
                becomes ``"<provider>/<name>"``.
            api_key: Forwarded to ``litellm.completion``.
            api_base: Forwarded to ``litellm.completion``.
            summarize_instruction: Falls back to
                ``DEFAULT_SUMMARIZE_INSTRUCTION`` when empty or ``None``.
            translate_instruction: Falls back to
                ``DEFAULT_TRANSLATE_INSTRUCTION`` when empty or ``None``.
        """
        self.name = name
        self.provider = provider
        self.api_key = api_key
        self.api_base = api_base
        self.summarize_instruction = (summarize_instruction
                                      or DEFAULT_SUMMARIZE_INSTRUCTION)
        self.translate_instruction = (translate_instruction
                                      or DEFAULT_TRANSLATE_INSTRUCTION)

    def _litellm_model(self) -> str:
        """Return the model id for litellm (provider-prefixed when set)."""
        if self.provider:
            return self.provider + "/" + self.name
        return self.name

    @staticmethod
    def _extract_result(raw_content):
        """Parse the model's JSON reply and return its ``"result"`` value.

        Raises:
            RuntimeError: If the content is missing, is not valid JSON, or
                is not a JSON object containing a ``"result"`` key.
        """
        if raw_content is None:
            # json.loads(None) would raise an unrelated TypeError.
            raise RuntimeError("Failed to get JSON response: empty content")
        try:
            payload = json.loads(raw_content)
        except json.JSONDecodeError as e:
            raise RuntimeError(f"Failed to get JSON response: {e}") from e
        # Valid JSON is not enough: the reply must follow the requested
        # {"result": ...} shape, otherwise a bare KeyError/TypeError would
        # leak to the caller.
        if not isinstance(payload, dict) or "result" not in payload:
            raise RuntimeError(
                "Failed to get JSON response: "
                f"no \"result\" key in {raw_content!r}")
        return payload["result"]

    def completion(self,
                   instruction: str,
                   prompt: str,
                   max_tokens: float = None) -> str:
        """Run one completion and return the ``"result"`` field of the reply.

        Args:
            instruction: System instruction; the JSON format hint is
                appended to it.
            prompt: User message content.
            max_tokens: Forwarded to ``litellm.completion``.

        Returns:
            The value stored under ``"result"`` in the model's JSON reply.

        Raises:
            ContextWindowExceededError: If litellm raises its own
                ``ContextWindowExceededError``.
            RuntimeError: If the reply cannot be parsed into the expected
                JSON shape.
        """
        messages = [{
            "role": "system",
            "content": instruction + self._JSON_FORMAT_HINT
        }, {
            "role": "user",
            "content": prompt
        }]
        try:
            response = litellm.completion(
                model=self._litellm_model(),
                api_key=self.api_key,
                api_base=self.api_base,
                messages=messages,
                max_tokens=max_tokens,
                # Ref: https://litellm.vercel.app/docs/completion/input#optional-fields # pylint: disable=line-too-long
                response_format={"type": "json_object"})
        except litellm.ContextWindowExceededError as e:
            raise ContextWindowExceededError() from e
        return self._extract_result(response.choices[0].message.content)
class AnthropicModel(Model):
    """Model variant that asks for a ``<result>...</result>`` wrapped answer.

    Instead of JSON mode, the assistant turn is prefilled with the opening
    tag and the reply is expected to end with the closing tag.
    """

    def completion(self,
                   instruction: str,
                   prompt: str,
                   max_tokens: float = None) -> str:
        """Run one completion and return the text before ``</result>``.

        Args:
            instruction: Instruction placed at the top of the user message.
            prompt: Text placed after the ``Text:`` marker.
            max_tokens: Forwarded to ``litellm.completion``.

        Returns:
            The reply with the closing tag removed and surrounding
            whitespace stripped.

        Raises:
            ContextWindowExceededError: If litellm raises its own
                ``ContextWindowExceededError``.
            RuntimeError: If the reply is empty or does not end with the
                closing tag.
        """
        # Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long
        prefix = "<result>"
        suffix = "</result>"
        user_content = (f"{instruction}\n"
                        "Output following this format:\n"
                        f"{prefix}...{suffix}\n"
                        "Text:\n"
                        f"{prompt}")
        messages = [{
            "role": "user",
            "content": user_content
        }, {
            # Prefilled assistant turn: the reply continues after the tag.
            "role": "assistant",
            "content": prefix
        }]
        try:
            response = litellm.completion(
                model=(self.provider + "/" + self.name
                       if self.provider else self.name),
                api_key=self.api_key,
                api_base=self.api_base,
                messages=messages,
                max_tokens=max_tokens,
            )
        except litellm.ContextWindowExceededError as e:
            raise ContextWindowExceededError() from e
        result = response.choices[0].message.content
        # Strip before checking the closing tag so a trailing newline or
        # space after it does not reject an otherwise well-formed reply;
        # a missing (None) content is treated as an empty reply.
        text = (result or "").strip()
        if not text.endswith(suffix):
            raise RuntimeError(
                f"Failed to get the formatted response: {result}")
        return text.removesuffix(suffix).strip()
# Registry of the models this module knows about, in a fixed order.
supported_models: List[Model] = [
    # Plain Model entries without a provider prefix.
    *(Model(model_name) for model_name in (
        "gpt-4o-2024-05-13",
        "gpt-4-turbo-2024-04-09",
        "gpt-4-0125-preview",
        "gpt-3.5-turbo-0125",
    )),
    # Entries using the <result>-tag based AnthropicModel.
    *(AnthropicModel(model_name) for model_name in (
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307",
    )),
    # Entries routed through the "mistral" provider.
    *(Model(model_name, provider="mistral") for model_name in (
        "mistral-small-2402",
        "mistral-large-2402",
    )),
    # Entries routed through the "groq" provider.
    *(Model(model_name, provider="groq") for model_name in (
        "llama3-8b-8192",
        "llama3-70b-8192",
    )),
]
def check_models(models: List[Model]):
    """Send a tiny test completion to each model, stopping at the first failure.

    Progress is reported with ``print``.

    Args:
        models: Models to probe, in order.

    Raises:
        RuntimeError: If any model's ``completion`` call raises; the
            original exception is attached as the cause.
    """
    for candidate in models:
        label = candidate.name
        print(f"Checking model {label}...")
        try:
            candidate.completion("You are an AI model.", "Hello, world!")
            print(f"Model {label} is available.")
        # The purpose here is purely an availability probe, so any kind of
        # failure is treated the same way and reported as "not available".
        except Exception as err:  # pylint: disable=broad-except
            raise RuntimeError(
                f"Model {label} is not available: {err}") from err
|