from typing import Optional, Union

import weave
from openai import OpenAI
from openai.types.chat import ChatCompletion


class OpenAIModel(weave.Model):
    """
    A class to interface with OpenAI's language models using the Weave framework.

    This class provides methods to create structured messages and generate predictions
    using OpenAI's chat completion API. It is designed to work with both single and
    multiple user prompts, and optionally includes a system prompt to guide the model's
    responses.

    Args:
        model_name (str): The name of the OpenAI model to be used for predictions.
    """

    model_name: str
    _openai_client: OpenAI

    def __init__(self, model_name: str = "gpt-4o") -> None:
        super().__init__(model_name=model_name)
        self._openai_client = OpenAI()
    @weave.op()
    def create_messages(
        self,
        user_prompts: Union[str, list[str]],
        system_prompt: Optional[str] = None,
        messages: Optional[list[dict]] = None,
    ) -> list[dict]:
        """
        Create a list of messages for the OpenAI chat completion API.

        This function constructs a list of messages in the format required by the
        OpenAI chat completion API. It takes user prompts, an optional system prompt,
        and an optional list of existing messages, and combines them into a single
        list of messages.

        Args:
            user_prompts (Union[str, list[str]]): A single user prompt or a list of
                user prompts to be included in the messages.
            system_prompt (Optional[str]): An optional system prompt to guide the
                model's responses. If provided, it will be added at the beginning
                of the messages list.
            messages (Optional[list[dict]]): An optional list of existing messages
                to which the new prompts will be appended. If not provided, a new
                list will be created.

        Returns:
            list[dict]: A list of messages formatted for the OpenAI chat completion API.
        """
        # Normalize a single prompt string into a list of prompts.
        user_prompts = [user_prompts] if isinstance(user_prompts, str) else user_prompts
        # Copy the caller's existing conversation rather than mutating it;
        # start a fresh list when none is provided.
        messages = list(messages) if messages is not None else []
        for user_prompt in user_prompts:
            messages.append({"role": "user", "content": user_prompt})
        # Prepend the system prompt so it always leads the conversation.
        if system_prompt is not None:
            messages = [{"role": "system", "content": system_prompt}] + messages
        return messages
    @weave.op()
    def predict(
        self,
        user_prompts: Union[str, list[str]],
        system_prompt: Optional[str] = None,
        messages: Optional[list[dict]] = None,
        **kwargs,
    ) -> ChatCompletion:
        """
        Generate a chat completion response using the OpenAI API.

        This function takes user prompts, an optional system prompt, and an optional
        list of existing messages to create a list of messages formatted for the
        OpenAI chat completion API. It then sends these messages to the OpenAI API
        to generate a chat completion response.

        Args:
            user_prompts (Union[str, list[str]]): A single user prompt or a list of
                user prompts to be included in the messages.
            system_prompt (Optional[str]): An optional system prompt to guide the
                model's responses. If provided, it will be added at the beginning
                of the messages list.
            messages (Optional[list[dict]]): An optional list of existing messages
                to which the new prompts will be appended. If not provided, a new
                list will be created.
            **kwargs: Additional keyword arguments to be passed to the OpenAI API
                for chat completion.

        Returns:
            ChatCompletion: The chat completion response from the OpenAI API.
        """
        messages = self.create_messages(user_prompts, system_prompt, messages)
        # Structured-output requests go through the `parse` endpoint; plain
        # completions use the standard `create` endpoint.
        if "response_format" in kwargs:
            response = self._openai_client.beta.chat.completions.parse(
                model=self.model_name, messages=messages, **kwargs
            )
        else:
            response = self._openai_client.chat.completions.create(
                model=self.model_name, messages=messages, **kwargs
            )
        return response
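

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): it assumes the
    # OPENAI_API_KEY environment variable is set and that "openai-model-demo" is a
    # placeholder Weave project name.
    weave.init("openai-model-demo")

    model = OpenAIModel(model_name="gpt-4o")

    # create_messages can be called on its own to inspect the payload it builds.
    msgs = model.create_messages(
        user_prompts=["What is Weave?", "How are ops traced?"],
        system_prompt="Answer briefly.",
    )
    print(msgs)

    # predict builds the messages internally and calls the chat completions API.
    completion = model.predict(
        user_prompts="Summarize the Weave framework in one sentence.",
        system_prompt="You are a concise technical assistant.",
    )
    print(completion.choices[0].message.content)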