import math
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
from openai.types.chat import ChatCompletion
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from graphgen.models.llm.limitter import RPM, TPM
from graphgen.models.llm.tokenizer import Tokenizer
from graphgen.models.llm.topk_token_model import Token, TopkTokenModel


def get_top_response_tokens(response: ChatCompletion) -> List[Token]:
    """Convert the logprobs of a completion's first choice into Token objects."""
    token_logprobs = response.choices[0].logprobs.content
    tokens = []
    for token_prob in token_logprobs:
        # Logprobs are natural logs; exponentiate to recover probabilities.
        prob = math.exp(token_prob.logprob)
        candidate_tokens = [
            Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs
        ]
        token = Token(token_prob.token, prob, top_candidates=candidate_tokens)
        tokens.append(token)
    return tokens
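
# For instance, a logprob of -0.105 becomes math.exp(-0.105) ≈ 0.90, so each
# Token above carries the model's probability (plus its top-k alternatives)
# for that position in the completion.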


def filter_think_tags(text: str) -> str:
    """
    Remove <think> tags from the text.
    If the text contains <think> and </think>, everything between the tags,
    and the tags themselves, is removed.
    """
    think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
    filtered_text = think_pattern.sub("", text).strip()
    return filtered_text if filtered_text else text.strip()
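
# Illustrative behavior:
#   filter_think_tags("<think>draft</think>Final answer")  -> "Final answer"
#   filter_think_tags("<think>only reasoning</think>")     -> the original text,
#   because stripping the tags would leave an empty string.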


@dataclass
class OpenAIModel(TopkTokenModel):
    model_name: str = "gpt-4o-mini"
    api_key: Optional[str] = None
    base_url: Optional[str] = None

    system_prompt: str = ""
    json_mode: bool = False
    seed: Optional[int] = None

    token_usage: list = field(default_factory=list)
    request_limit: bool = False
    rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
    tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
    tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)

    def __post_init__(self):
        assert self.api_key is not None, "Please provide api key to access openai api."
        self.client = AsyncOpenAI(
            api_key=self.api_key or "dummy", base_url=self.base_url
        )

    def _pre_generate(self, text: str, history: Optional[List[Dict]] = None) -> Dict:
        kwargs = {
            "temperature": self.temperature,
            "top_p": self.topp,
            "max_tokens": self.max_tokens,
        }
        if self.seed:
            kwargs["seed"] = self.seed
        if self.json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})
        if history:
            assert len(history) % 2 == 0, "History should have an even number of elements."
            messages = history + messages
        kwargs["messages"] = messages
        return kwargs
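
    # Example of the payload _pre_generate builds (illustrative), given a
    # one-turn history:
    #   {
    #       "temperature": ..., "top_p": ..., "max_tokens": ...,
    #       "messages": [
    #           {"role": "user", "content": "<previous question>"},
    #           {"role": "assistant", "content": "<previous answer>"},
    #           {"role": "system", "content": "<system prompt>"},
    #           {"role": "user", "content": "<current text>"},
    #       ],
    #   }
    # Note that history is prepended, so prior turns come before the system prompt.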

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, APITimeoutError)
        ),
    )
    async def generate_topk_per_token(
        self, text: str, history: Optional[List[Dict]] = None
    ) -> List[Token]:
        kwargs = self._pre_generate(text, history)
        if self.topk_per_token > 0:
            kwargs["logprobs"] = True
            kwargs["top_logprobs"] = self.topk_per_token
        # Limit max_tokens to 1 to avoid long completions
        kwargs["max_tokens"] = 1

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name, **kwargs
        )
        return get_top_response_tokens(completion)

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (RateLimitError, APIConnectionError, APITimeoutError)
        ),
    )
    async def generate_answer(
        self, text: str, history: Optional[List[Dict]] = None, temperature: float = 0.0
    ) -> str:
        kwargs = self._pre_generate(text, history)
        kwargs["temperature"] = temperature

        # Estimate the request size so the rate limiters can budget it.
        prompt_tokens = 0
        for message in kwargs["messages"]:
            prompt_tokens += len(
                self.tokenizer_instance.encode_string(message["content"])
            )
        estimated_tokens = prompt_tokens + kwargs["max_tokens"]

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(estimated_tokens, silent=True)

        completion = await self.client.chat.completions.create(  # pylint: disable=E1125
            model=self.model_name, **kwargs
        )
        if hasattr(completion, "usage"):
            self.token_usage.append(
                {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }
            )
        return filter_think_tags(completion.choices[0].message.content)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[Dict]] = None
    ) -> List[Token]:
        raise NotImplementedError
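

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): it assumes an OpenAI-compatible
    # endpoint is reachable with the key stored in OPENAI_API_KEY, and relies on
    # the dataclass defaults above for the model name and sampling parameters.
    import asyncio
    import os

    async def _demo():
        model = OpenAIModel(api_key=os.environ["OPENAI_API_KEY"])
        answer = await model.generate_answer("In one word, say hello.")
        print(answer)

    asyncio.run(_demo())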