|
import logging |
|
from .base_model import BaseLLMModel |
|
from .. import shared |
|
import requests |
|
from ..presets import * |
|
from ..config import retrieve_proxy, sensitive_id |
|
|
|
class OpenAI_DALLE3_Client(BaseLLMModel):
    """Chat-model adapter that answers each message with a DALL-E 3 image.

    The latest history entry is POSTed as the image prompt to
    ``shared.state.images_completion_url``. The "answer" shown in the chat is an
    HTML ``<a><img></a>`` snippet pointing at the generated image, followed by
    the ``revised_prompt`` text returned by the API (if any).
    """

    # Suffix a user can append to a message to request that the prompt be used
    # verbatim; it is stripped before the prompt is sent to the API.
    _RAW_FLAG = "--raw"
    # Instruction prepended in raw mode to discourage DALL-E 3 from rewriting
    # (expanding) the user's prompt.
    _RAW_PREFIX = (
        "I NEED to test how the tool works with extremely simple prompts. "
        "DO NOT add any detail, just use it AS-IS:"
    )

    def __init__(self, model_name, api_key, user_name="") -> None:
        """Create the client.

        Args:
            model_name: Model name passed through to ``BaseLLMModel``.
            api_key: Key sent as the Bearer token on image-generation requests.
            user_name: Passed to ``BaseLLMModel`` as ``user``.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.api_key = api_key
        self._refresh_header()

    def _get_dalle3_prompt(self):
        """Return the prompt to send to DALL-E 3, built from the latest history entry.

        If the message ends with ``--raw``, the flag is removed and the rest of
        the message is prefixed with an instruction asking the model to use it
        as-is.
        """
        prompt = self.history[-1]["content"]
        if prompt.endswith(self._RAW_FLAG):
            # The flag is a UI-level switch, not part of the image description,
            # so it must not be forwarded to the API.
            text = prompt[: -len(self._RAW_FLAG)].rstrip()
            prompt = f"{self._RAW_PREFIX} {text}"
        return prompt

    def get_answer_at_once(self, stream=False):
        """Generate one image for the latest message and return it as chat output.

        Args:
            stream: Selects the request timeout (``TIMEOUT_STREAMING`` vs.
                ``TIMEOUT_ALL``) and is forwarded to ``requests.post``; the JSON
                response is always read in full before returning.

        Returns:
            A ``(text, token_count)`` tuple. ``text`` is the HTML image snippet
            followed by the API's ``revised_prompt`` on success, or an error
            description on failure (network/HTTP error or unexpected response
            shape). ``token_count`` is always 0.
        """
        prompt = self._get_dalle3_prompt()
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {
            "model": "dall-e-3",
            "prompt": prompt,
            "n": 1,
            "size": "1024x1024",
            "quality": "standard",
        }
        timeout = TIMEOUT_STREAMING if stream else TIMEOUT_ALL

        url = shared.state.images_completion_url
        if url != IMAGES_COMPLETION_URL:
            # Log message reads: "Using custom API URL: ..."
            logging.debug(f"使用自定义API URL: {url}")

        with retrieve_proxy():
            try:
                # Using the response as a context manager guarantees the
                # connection is released even when stream=True and we bail out
                # on an error status without consuming the body.
                with requests.post(
                    url,
                    headers=headers,
                    json=payload,
                    stream=stream,
                    timeout=timeout,
                ) as response:
                    try:
                        response.raise_for_status()
                    except requests.exceptions.HTTPError as e:
                        return self._with_api_error_detail(str(e), response), 0
                    response_data = response.json()
            except requests.exceptions.RequestException as e:
                return str(e), 0
            except ValueError as e:
                # Older requests versions raise a plain ValueError (not a
                # RequestException) when the body is not valid JSON.
                return f"Invalid JSON from image API: {e}", 0

        try:
            image = response_data["data"][0]
            image_url = image["url"]
        except (KeyError, IndexError, TypeError):
            return f"Unexpected response from image API: {response_data}", 0

        img_tag = f'<!-- S O PREFIX --><a data-fancybox="gallery" target="_blank" href="{image_url}"><img src="{image_url}" /></a><!-- E O PREFIX -->'
        # `or ""` also covers an explicit null revised_prompt, which
        # dict.get's default would not.
        revised_prompt = image.get("revised_prompt") or ""
        return img_tag + revised_prompt, 0

    @staticmethod
    def _with_api_error_detail(message, response):
        """Append the API's own error text to *message* when the body carries one.

        Looks for an OpenAI-style ``{"error": {"message": ...}}`` JSON body;
        returns *message* unchanged if the body is missing, not JSON, or shaped
        differently.
        """
        try:
            detail = response.json()["error"]["message"]
        except (ValueError, KeyError, TypeError):
            return message
        return f"{message}: {detail}" if detail else message

    def _refresh_header(self):
        """Rebuild ``self.headers``.

        Note that these headers authenticate with the config-level
        ``sensitive_id``, not ``self.api_key``; ``get_answer_at_once`` builds its
        own headers from ``self.api_key`` and does not use ``self.headers``.
        NOTE(review): presumably ``self.headers`` is consumed by base-class
        helpers (e.g. billing/usage queries) — confirm before changing.
        """
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {sensitive_id}",
        }