|
|
import json
import logging
import os
from urllib.parse import urlencode

import requests
from tenacity import retry, stop_after_delay, wait_fixed, retry_if_result
|
|
|
|
|
|
|
|
class MTOpenApiClient:
    """Client for the MTLab OpenAPI at ``https://openapi.mtlab.meitu.com/v1``.

    Credentials are resolved from JSON files in a ``config`` directory next to
    this module and sent as ``api_key`` / ``api_secret`` query parameters on
    every call.
    """

    # Default (connect, read) timeout in seconds for every HTTP call.
    # requests waits indefinitely when no timeout is given; pass
    # ``timeout=None`` to the constructor to restore that behaviour.
    DEFAULT_TIMEOUT = (10, 300)

    # Headers sent with every JSON POST. requests copies headers per request,
    # so sharing this dict between calls is safe.
    _JSON_HEADERS = {"Content-Type": "application/json"}

    def __init__(self, api_name, api_key=None, cost_attribution=None,
                 timeout=DEFAULT_TIMEOUT):
        """Initialize MTOpenApiClient with credentials.

        Args:
            api_name (str): API endpoint name, appended to the base URL.
            api_key (str, optional): API key whose secret is looked up in
                ``ak_sk_mapping.json``. Takes precedence over
                ``cost_attribution`` when both are given.
            cost_attribution (str, optional): Key looked up in
                ``ai_flow_ak_sk_mapping.json`` to obtain ak / sk / token.
            timeout (float | tuple | None, optional): ``requests`` timeout
                used for every HTTP call; ``None`` disables it.

        Raises:
            ValueError: If neither api_key nor cost_attribution is given, if
                the resolved credentials are missing or empty, or if a
                configuration file is not valid JSON.
            FileNotFoundError: If a configuration file is not found.
        """
        self.api_name = api_name
        self.timeout = timeout

        if api_key is not None:
            self._load_credentials_by_api_key(api_key)
        elif cost_attribution is not None:
            self._load_credentials_by_cost_attribution(cost_attribution)
        else:
            raise ValueError("Either api_key or cost_attribution must be provided")

        self._initialize_urls()

    def _get_config_path(self, filename):
        """Return the path of *filename* inside the ``config`` dir next to this module."""
        dir_path = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(dir_path, 'config', filename)

    def _load_json_config(self, filename):
        """Load and parse a JSON configuration file.

        Raises:
            FileNotFoundError: If the file does not exist (message includes
                the full path).
            ValueError: If the file is not valid JSON.
        """
        config_path = self._get_config_path(filename)
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError as err:
            # Chain the original error so the traceback shows the real cause.
            raise FileNotFoundError(f"Configuration file not found: {config_path}") from err
        except json.JSONDecodeError as err:
            raise ValueError(f"Invalid JSON in configuration file {filename}: {err}") from err

    def _validate_credentials(self):
        """Raise ValueError unless ``api_key`` and ``api_secret`` are present and non-empty."""
        if not hasattr(self, 'api_key') or not self.api_key:
            raise ValueError("api_key is missing or empty")
        if not hasattr(self, 'api_secret') or not self.api_secret:
            raise ValueError("api_secret is missing or empty")

    def _load_credentials_by_api_key(self, api_key):
        """Set ``api_key`` and look up its secret in ``ak_sk_mapping.json``."""
        self.api_key = api_key
        data = self._load_json_config('ak_sk_mapping.json')
        self.api_secret = data.get(api_key, "")
        # This mapping carries no token; define the attribute anyway so both
        # credential paths expose the same attributes.
        self.token = ""
        self._validate_credentials()

    def _load_credentials_by_cost_attribution(self, cost_attribution):
        """Look up ak / sk / token for *cost_attribution* in ``ai_flow_ak_sk_mapping.json``."""
        data = self._load_json_config('ai_flow_ak_sk_mapping.json')
        # ``or {}`` also covers an explicit ``null`` entry in the JSON file.
        credentials = data.get(cost_attribution) or {}
        self.api_key = credentials.get("ak", "")
        self.api_secret = credentials.get("sk", "")
        self.token = credentials.get("token", "")
        self._validate_credentials()

    def _initialize_urls(self):
        """Build ``self.url`` (the configured endpoint) and ``self.query_url`` (result polling)."""
        base_url = "https://openapi.mtlab.meitu.com/v1"
        # URL-encode the credentials so secrets containing '&', '+', '=' or
        # '%' cannot corrupt the query string. Plain alphanumeric values
        # produce exactly the same URL as before.
        auth_params = urlencode({"api_key": self.api_key, "api_secret": self.api_secret})

        self.query_url = f"{base_url}/query?{auth_params}"
        self.url = f"{base_url}/{self.api_name}?{auth_params}"

    def fetch_response(self, msg_id):
        """Query the result of an asynchronous job once.

        Args:
            msg_id (str): Job id returned by the asynchronous endpoint. It is
                sent both as a query parameter and in the JSON body.

        Returns:
            The parsed response JSON (also logged at INFO level).

        Raises:
            requests.RequestException: On network errors, timeouts or non-2xx
                HTTP status.
        """
        url = f"{self.query_url}&{urlencode({'msg_id': msg_id})}"
        response = requests.post(
            url,
            json={"msg_id": msg_id},
            headers=self._JSON_HEADERS,
            timeout=self.timeout,
        )
        response.raise_for_status()
        payload = response.json()  # parse once; reused for logging and return
        logging.info("fetch_response: %s", payload)
        return payload

    @retry(
        stop=stop_after_delay(1000),
        wait=wait_fixed(1),
        retry=retry_if_result(lambda res: res.get("error_code") == 4)
    )
    def get_res(self, msg_id):
        """Poll ``fetch_response`` until the job's ``error_code`` is no longer 4.

        Waits 1 second between polls and stops after 1000 seconds.
        error_code 4 presumably means "still processing" (confirm against the
        API docs). The retry predicate only inspects returned values, so
        exceptions from ``fetch_response`` propagate immediately.

        NOTE(review): when the deadline passes, tenacity raises its own
        ``RetryError`` instead of returning the last response. Confirm that
        callers expect this.
        """
        return self.fetch_response(msg_id)

    def async_request(self, data: dict, max_retries: int = 20):
        """Submit an asynchronous job and poll for its final result.

        Flow:
            1. POST *data* to the endpoint to obtain a ``msg_id``.
            2. Poll with that ``msg_id`` (see ``get_res``) for the final response.

        Args:
            data (dict): Request body. It must contain the image URL and the
                other required parameters. The asynchronous endpoint only
                accepts image URLs, not base64-encoded images.
            max_retries (int): Maximum number of submission attempts
                (default 20).

        Returns:
            dict: The polled result, when its ``error_code`` is 0 or absent.
            Returns ``{"error_code": -1, "error_msg": "Max retries exceeded"}``
            if no attempt produced a ``msg_id``.

        Raises:
            RuntimeError: If the polled result has a non-zero ``error_code``.
                The message is the result's ``error_msg``.
            requests.RequestException: If the final attempt fails with a
                network, HTTP or JSON-decoding error.
        """
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    self.url,
                    json=data,
                    headers=self._JSON_HEADERS,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                submission = response.json()

                msg_id = submission.get("msg_id", "")
                if not msg_id:
                    # The job was not accepted. Log the body so the server's
                    # error is not lost, then resubmit.
                    logging.warning(
                        "async_request: no msg_id in response (attempt %d/%d): %s",
                        attempt + 1, max_retries, submission,
                    )
                    continue

                # NOTE(review): a network error while *polling* is also caught
                # below and leads to resubmitting the job rather than
                # re-polling it. This may create duplicate jobs; confirm it is
                # intended.
                result = self.get_res(msg_id=msg_id)
            except requests.RequestException:
                if attempt == max_retries - 1:
                    raise
                continue

            if result.get("error_code", 0) == 0:
                return result
            raise RuntimeError(result.get("error_msg", ""))

        # Every attempt either lacked a msg_id or hit a non-final network error.
        return {"error_code": -1, "error_msg": "Max retries exceeded"}

    def request(self, data: dict):
        """Send a synchronous request and return the immediate response.

        Args:
            data (dict): Request body. It may contain base64-encoded image
                data; the synchronous endpoint accepts base64 images, whereas
                the asynchronous one requires URLs.

        Returns:
            dict: The raw response JSON.

        Raises:
            requests.RequestException: On network errors, timeouts or non-2xx
                HTTP status.
        """
        response = requests.post(
            self.url,
            json=data,
            headers=self._JSON_HEADERS,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()