import os
import math
from typing import Union, Optional

import torch
import logging

# from vllm import LLM, SamplingParams
# from vllm.lora.request import LoRARequest
from transformers import (AutoModelForCausalLM, AutoTokenizer, AutoConfig,
                          set_seed, BitsAndBytesConfig, GenerationConfig)
import openai
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
                          Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_delay, wait_random_exponential,
                      stop_after_attempt)

logger = logging.getLogger(__name__)


class Summarizer:
    """Runs summarization via a local model or the OpenAI chat API, depending on inference_mode."""

    def __init__(self,
                 inference_mode: str,
                 model_id: str,
                 api_key: str,
                 dtype: str = "bfloat16",
                 seed: int = 42,
                 context_size: int = 1024 * 26,
                 gpu_memory_utilization: float = 0.7,
                 tensor_parallel_size: int = 1) -> None:
        self.inference_mode = inference_mode
        self.tokenizer = None
        self.seed = seed
        openai.api_key = api_key
        # For API inference, self.model holds the model id string rather than
        # a loaded model object.
        self.model = model_id

    def get_generation_config(self,
                              repetition_penalty: float = 1.2,
                              do_sample: bool = True,
                              temperature: float = 0.1,
                              top_p: float = 0.9,
                              max_tokens: int = 1024) -> GenerationConfig:
        # Build a transformers GenerationConfig from the sampling arguments;
        # max_tokens maps onto transformers' max_new_tokens.
        return GenerationConfig(repetition_penalty=repetition_penalty,
                                do_sample=do_sample,
                                temperature=temperature,
                                top_p=top_p,
                                max_new_tokens=max_tokens)

    # Retry transient API failures with randomized exponential backoff (up to
    # 10 attempts, waits capped at 60 s). InvalidRequestError is not retried:
    # it is a non-transient client error and is handled inside the method.
    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError,
                                          APIConnectionError)),
           wait=wait_random_exponential(max=60),
           stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def inference_with_gpt(self, prompt: str) -> str:
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(model=self.model,
                                                    messages=prompt_messages,
                                                    temperature=0.1)
            # finish_reason = response.choices[0].finish_reason
            response = response.choices[0].message.content
        except InvalidRequestError:
            # Typically raised when the prompt exceeds the model's context
            # window; return an empty summary instead of crashing.
            response = ''
        return response
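
# Minimal usage sketch (not part of the original module): drives the class
# above in API mode. The "api" mode name, the "gpt-3.5-turbo" model id, and
# the OPENAI_API_KEY environment variable are illustrative assumptions, not
# values taken from the original code.
if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    summarizer = Summarizer(
        inference_mode="api",                  # assumed mode name for API inference
        model_id="gpt-3.5-turbo",              # assumed chat model id
        api_key=os.environ["OPENAI_API_KEY"],  # assumed env var holding the key
    )
    summary = summarizer.inference_with_gpt(
        "Summarize in one sentence: The quick brown fox jumps over the lazy dog."
    )
    print(summary)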