# graphgen/models/llm/openai_model.py
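"""Async OpenAI chat-completion wrapper with per-token top-k logprobs,
token-usage tracking, and optional RPM/TPM rate limiting."""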
import math
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError
from openai.types.chat import ChatCompletion
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from graphgen.models.llm.topk_token_model import TopkTokenModel, Token
from graphgen.models.llm.tokenizer import Tokenizer
from graphgen.models.llm.limitter import RPM, TPM


def get_top_response_tokens(response: ChatCompletion) -> List[Token]:
    """Convert the logprobs of the response's first choice into Token objects."""
    token_logprobs = response.choices[0].logprobs.content
    tokens = []
    for token_prob in token_logprobs:
        # logprobs are natural-log probabilities; exponentiate to recover them
        prob = math.exp(token_prob.logprob)
        candidate_tokens = [
            Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
        ]
        tokens.append(Token(token_prob.token, prob, top_candidates=candidate_tokens))
    return tokens
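
# Example: if the first choice emitted the single token "Hi" with logprob -0.1
# and top_logprobs [("Hi", -0.1), ("Hello", -2.3)], the helper returns
# [Token("Hi", ~0.905, top_candidates=[Token("Hi", ~0.905), Token("Hello", ~0.100)])].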


@dataclass
class OpenAIModel(TopkTokenModel):
    model_name: str = "gpt-4o-mini"
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    system_prompt: str = ""
    json_mode: bool = False
    seed: Optional[int] = None
    token_usage: List[Dict] = field(default_factory=list)
    request_limit: bool = False
    rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
    tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))

    def __post_init__(self):
        assert self.api_key is not None, "Please provide an API key to access the OpenAI API."
        if self.api_key == "":
            # Use a placeholder so the client accepts an empty key
            self.api_key = "none"
        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    def _pre_generate(self, text: str, history: Optional[List[Dict]]) -> Dict:
        kwargs = {
            "temperature": self.temperature,
            "top_p": self.topp,
            "max_tokens": self.max_tokens,
        }
        if self.seed is not None:
            kwargs["seed"] = self.seed
        if self.json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        if history:
            # History is expected as alternating user/assistant message dicts,
            # inserted after the system prompt and before the new user message
            assert len(history) % 2 == 0, "History should have an even number of elements."
            messages.extend(history)
        messages.append({"role": "user", "content": text})
        kwargs["messages"] = messages
        return kwargs

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, APITimeoutError)
        ),
    )
    async def generate_topk_per_token(
        self, text: str, history: Optional[List[Dict]] = None
    ) -> List[Token]:
        kwargs = self._pre_generate(text, history)
        if self.topk_per_token > 0:
            kwargs["logprobs"] = True
            kwargs["top_logprobs"] = self.topk_per_token
        # Limit max_tokens to 1 to avoid long completions
        kwargs["max_tokens"] = 1

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name, **kwargs
        )
        return get_top_response_tokens(completion)

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, APITimeoutError)
        ),
    )
    async def generate_answer(
        self, text: str, history: Optional[List[Dict]] = None, temperature: float = 0.0
    ) -> str:
        kwargs = self._pre_generate(text, history)
        kwargs["temperature"] = temperature

        # Estimate token usage so the rate limiter can budget this request
        prompt_tokens = 0
        for message in kwargs["messages"]:
            prompt_tokens += len(Tokenizer().encode_string(message["content"]))
        estimated_tokens = prompt_tokens + kwargs["max_tokens"]

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(estimated_tokens, silent=True)

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name, **kwargs
        )
        if hasattr(completion, "usage"):
            self.token_usage.append(
                {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }
            )
        return completion.choices[0].message.content

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[Dict]] = None
    ) -> List[Token]:
        raise NotImplementedError
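

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumes OPENAI_API_KEY is set in the environment and that TopkTokenModel
    # provides the temperature/topp/max_tokens defaults read by _pre_generate.
    import asyncio
    import os

    async def _demo():
        model = OpenAIModel(api_key=os.environ.get("OPENAI_API_KEY", ""))
        print(await model.generate_answer("Say hello in one short sentence."))

    asyncio.run(_demo())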