# -*- coding: utf-8 -*-
import asyncio
import json
import logging
import os
import random
import re
import time
import unicodedata
import urllib.parse
from dataclasses import dataclass, asdict

import aiohttp
from aiohttp import web
# Module-level logger; configure the root logger once at import time.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class Config:
    """Runtime configuration for upstream APIs.

    NOTE(review): credentials are hard-coded in source; they should be
    loaded from environment variables instead of being committed.
    """
    # SiliconFlow API key (used by generate_image4).
    AUTH_TOKEN: str = 'sk-vqfmwdzfgmbpxyqyqjtpnmocljhlzbprrxcmlovlnpniqjqq'
    # One-API/New-API relay endpoint for chat completions.
    OPENAI_CHAT_API: str = 'https://xxxxxxxxxxxx/v1/chat/completions'
    # API key for the relay endpoint above.
    OPENAI_CHAT_API_KEY: str = 'sk-xxxxxxxxxx'
    # Default model used to translate prompts.
    DEFAULT_TRANSLATE_MODEL: str = 'deepseek-chat'
    # Stronger model used to enhance prompts.
    DEFAULT_PROMPT_MODEL: str = 'Qwen2-72B-Instruct'


# Single shared configuration instance used throughout this module.
config = Config()
# Endpoints for the non-SiliconFlow fallback generators (generate_image1-3).
URLS = {
    'API_FLUX1_API4GPT_COM': 'https://api-flux1.api4gpt.com',
    'FLUXAIWEB_COM_TOKEN': 'https://fluxaiweb.com/flux/getToken',
    'FLUXAIWEB_COM_GENERATE': 'https://fluxaiweb.com/flux/generateImage',
    'FLUXIMG_COM': 'https://fluximg.com/api/image/generateImage',
    'API_SILICONFLOW_CN': 'https://api.siliconflow.cn/v1/chat/completions'
}
# SiliconFlow text-to-image endpoints, keyed by the `--m` model alias
# (see extract_size_and_model_from_prompt / generate_image4).
URL_MAP = {
    'flux': "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image",
    'sd3': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image",
    'sdxl': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image",
    'sd2': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/text-to-image",
    'sdt': "https://api.siliconflow.cn/v1/stabilityai/sd-turbo/text-to-image",
    'sdxlt': "https://api.siliconflow.cn/v1/stabilityai/sdxl-turbo/text-to-image",
    'sdxll': "https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/text-to-image"
}
# SiliconFlow image-to-image endpoints; not referenced elsewhere in this
# chunk — presumably reserved for a future image-to-image feature.
IMG_URL_MAP = {
    'sdxl': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/image-to-image",
    'sd2': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/image-to-image",
    'sdxll': "https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/image-to-image",
    'pm': "https://api.siliconflow.cn/v1/TencentARC/PhotoMaker/image-to-image"
}
# Aspect-ratio alias (from `--ar`) -> pixel dimensions sent to SiliconFlow.
RATIO_MAP = {
    "1:1": "1024x1024",
    "1:2": "1024x2048",
    "3:2": "1536x1024",
    "4:3": "1536x2048",
    "16:9": "2048x1152",
    "9:16": "1152x2048"
}
# System prompt (Chinese, sent verbatim to the LLM) that sets up a
# "Stable Diffusion prompt expert" persona: it instructs the model to turn
# a scene description into an English positive prompt with weighted tags.
# Used as the system message by both translate_prompt() and get_prompt().
SYSTEM_ASSISTANT = """作为 Stable Diffusion Prompt 提示词专家,您将从关键词中创建提示,通常来自 Danbooru 等数据库。
提示通常描述图像,使用常见词汇,按重要性排列,并用逗号分隔。避免使用"-"或".",但可以接受空格和自然语言。避免词汇重复。
为了强调关键词,请将其放在括号中以增加其权重。例如,"(flowers)"将'flowers'的权重增加1.1倍,而"(((flowers)))"将其增加1.331倍。使用"(flowers:1.5)"将'flowers'的权重增加1.5倍。只为重要的标签增加权重。
提示包括三个部分:**前缀**(质量标签+风格词+效果器)+ **主题**(图像的主要焦点)+ **场景**(背景、环境)。
* 前缀影响图像质量。像"masterpiece"、"best quality"、"4k"这样的标签可以提高图像的细节。像"illustration"、"lensflare"这样的风格词定义图像的风格。像"bestlighting"、"lensflare"、"depthoffield"这样的效果器会影响光照和深度。
* 主题是图像的主要焦点,如角色或场景。对主题进行详细描述可以确保图像丰富而详细。增加主题的权重以增强其清晰度。对于角色,描述面部、头发、身体、服装、姿势等特征。
* 场景描述环境。没有场景,图像的背景是平淡的,主题显得过大。某些主题本身包含场景(例如建筑物、风景)。像"花草草地"、"阳光"、"河流"这样的环境词可以丰富场景。你的任务是设计图像生成的提示。请按照以下步骤进行操作:
1. 我会发送给您一个图像场景。需要你生成详细的图像描述
2. 图像描述必须是英文,输出为Positive Prompt。
示例:
我发送:二战时期的护士。
您回复只回复:
A WWII-era nurse in a German uniform, holding a wine bottle and stethoscope, sitting at a table in white attire, with a table in the background, masterpiece, best quality, 4k, illustration style, best lighting, depth of field, detailed character, detailed environment.
"""
async def select_random_image_generator():
    """Pick one of the four backend generators at random.

    Returns a uniform callable with signature (prompt, size, model=None)
    so handle_request can invoke any backend the same way; only
    generate_image4 actually consumes the model argument.

    FIX: the original returned bare lambdas, so the retry logging in
    handle_request (which prints selected_generator.__name__) always said
    '<lambda>'. The named wrappers below carry the real generator name.
    """
    generator = random.choice([generate_image1, generate_image2, generate_image3, generate_image4])
    if generator is generate_image4:
        def wrapper(prompt, size, model):
            return generator(prompt, size, model)
    else:
        def wrapper(prompt, size, model=None):
            # The other backends take no model argument; ignore it.
            return generator(prompt, size)
    wrapper.__name__ = generator.__name__
    return wrapper
def extract_size_and_model_from_prompt(prompt):
    """Parse Midjourney-style `--ar <ratio>` and `--m <model>` flags.

    Returns a dict with:
      size         -- value of --ar, defaulting to '1:1'
      model        -- value of --m, defaulting to ''
      clean_prompt -- the prompt with both flags removed and stripped
    """
    ar_flag = re.search(r'--ar\s+(\S+)', prompt)
    model_flag = re.search(r'--m\s+(\S+)', prompt)
    stripped = re.sub(r'--m\s+\S+', '', re.sub(r'--ar\s+\S+', '', prompt).strip()).strip()
    return {
        'size': ar_flag.group(1) if ar_flag else '1:1',
        'model': model_flag.group(1) if model_flag else '',
        'clean_prompt': stripped,
    }
def is_chinese(char):
    """Return True when the character's Unicode name marks it as CJK
    (e.g. 'CJK UNIFIED IDEOGRAPH-...'); unnamed characters count as non-CJK."""
    name = unicodedata.name(char, '')
    return name.find('CJK') >= 0
async def translate_prompt(prompt):
    """Translate/expand a (possibly Chinese) prompt via the relay chat API.

    If the prompt contains no CJK characters it is returned unchanged.
    On any HTTP or parsing failure the original prompt is returned so
    image generation can still proceed.
    """
    if all(not is_chinese(char) for char in prompt):
        logger.info('Prompt is already in English, skipping translation')
        return prompt
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(config.OPENAI_CHAT_API, json={
                'model': config.DEFAULT_TRANSLATE_MODEL,  # translation model from config
                'messages': [
                    # NOTE(review): SYSTEM_ASSISTANT is the SD prompt-expert
                    # persona, so this call enhances as well as translates.
                    {'role': 'system', 'content': SYSTEM_ASSISTANT},
                    {'role': 'user', 'content': prompt}
                ],
            }, headers={
                'Content-Type': 'application/json',
                'Authorization': f'Bearer {config.OPENAI_CHAT_API_KEY}'
            }) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f'HTTP error! status: {response.status}, body: {error_text}')
                    raise Exception(f'HTTP error! status: {response.status}')
                if 'application/json' not in response.headers.get('Content-Type', ''):
                    error_text = await response.text()
                    logger.error(f'Unexpected content type: {response.headers.get("Content-Type")}, body: {error_text}')
                    raise Exception(f'Unexpected content type: {response.headers.get("Content-Type")}')
                result = await response.json()
                return result['choices'][0]['message']['content']
    except Exception as e:
        # BUG FIX: was logger.error('Translation error:', e) — a positional
        # arg with no %s placeholder makes logging's internal formatting
        # fail, so the exception detail was never logged.
        logger.error('Translation error: %s', e)
        return prompt
async def handle_request(request):
    """aiohttp handler for POST paths ending in /v1/chat/completions.

    Extracts the newest user message, parses --ar/--m flags, translates the
    prompt when it contains Chinese, then tries up to three randomly chosen
    backends to generate an image. Replies in OpenAI chat-completion format,
    streamed (SSE) or not according to the request's `stream` flag.
    """
    if request.method != 'POST' or not request.url.path.endswith('/v1/chat/completions'):
        return web.Response(text='Not Found', status=404)
    try:
        data = await request.json()
        messages = data.get('messages', [])
        stream = data.get('stream', False)
        # The most recent user message is taken as the image prompt.
        user_message = next((msg['content'] for msg in reversed(messages) if msg['role'] == 'user'), None)
        if not user_message:
            return web.json_response({'error': "未找到用户消息"}, status=400)
        size_and_model = extract_size_and_model_from_prompt(user_message)
        translated_prompt = await translate_prompt(size_and_model['clean_prompt'])
        selected_generator = await select_random_image_generator()
        # On failure, re-pick a random backend and retry, up to 3 attempts.
        attempts = 0
        max_attempts = 3
        while attempts < max_attempts:
            try:
                image_data = await selected_generator(translated_prompt, size_and_model['size'],
                                                      size_and_model['model'])
                break
            except Exception as e:
                logger.error(f"Error generating image with generator {selected_generator.__name__}: {e}")
                selected_generator = await select_random_image_generator()
                attempts += 1
                if attempts == max_attempts:
                    logger.error("Failed to generate image after multiple attempts")
                    return web.json_response({'error': "生成图像失败"}, status=500)
        unique_id = f"chatcmpl-{int(time.time())}"
        created_timestamp = int(time.time())
        model_name = "flux"
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))
        if stream:
            return await handle_stream_response(request, unique_id, image_data, size_and_model['clean_prompt'],
                                                translated_prompt, size_and_model['size'], created_timestamp,
                                                model_name, system_fingerprint)
        else:
            return handle_non_stream_response(unique_id, image_data, size_and_model['clean_prompt'], translated_prompt,
                                              size_and_model['size'], created_timestamp, model_name, system_fingerprint)
    except Exception as e:
        # BUG FIX: was logger.error('Error handling request:', e) — a stray
        # positional arg with no %s placeholder broke log formatting.
        logger.error('Error handling request: %s', e)
        return web.json_response({'error': f"处理请求失败: {str(e)}"}, status=500)
async def handle_stream_response(request, unique_id, image_data, original_prompt, translated_prompt, size, created,
                                 model, system_fingerprint):
    """Stream the result as OpenAI-style SSE "chat.completion.chunk" events.

    Emits a fixed sequence of progress-text chunks 0.5 s apart (simulated
    generation time), then a final empty-delta chunk with
    finish_reason="stop", and closes the stream.
    """
    logger.debug("Starting stream response")
    response = web.StreamResponse(
        status=200,
        reason='OK',
        headers={
            'Content-Type': 'text/event-stream',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
    )
    await response.prepare(request)
    logger.debug("Response prepared")
    # Progress narration (Chinese UI strings), ending with the markdown
    # image link taken from the generator's result.
    chunks = [
        f"原始提示词:\n{original_prompt}\n",
        f"翻译后的提示词:\n{translated_prompt}\n",
        f"图像规格:{size}\n",
        "正在根据提示词生成图像...\n",
        "图像正在处理中...\n",
        "即将完成...\n",
        f"生成成功!\n图像生成完毕,以下是结果:\n\n![生成的图像]({image_data['data'][0]['url']})"
    ]
    for i, chunk in enumerate(chunks):
        json_chunk = json.dumps({
            "id": unique_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "system_fingerprint": system_fingerprint,
            "choices": [{
                "index": 0,
                "delta": {"content": chunk},
                "logprobs": None,
                "finish_reason": None
            }]
        })
        try:
            await response.write(f"data: {json_chunk}\n\n".encode('utf-8'))
            logger.debug(f"Chunk {i + 1} sent")
        except Exception as e:
            # A failed write is logged but not fatal; the stream is still
            # closed cleanly below.
            logger.error(f"Error sending chunk {i + 1}: {str(e)}")
        await asyncio.sleep(0.5)  # simulate generation time
    # Terminal chunk: empty delta + finish_reason "stop", per the OpenAI
    # streaming protocol.
    final_chunk = json.dumps({
        "id": unique_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "system_fingerprint": system_fingerprint,
        "choices": [{
            "index": 0,
            "delta": {},
            "logprobs": None,
            "finish_reason": "stop"
        }]
    })
    try:
        await response.write(f"data: {final_chunk}\n\n".encode('utf-8'))
        logger.debug("Final chunk sent")
    except Exception as e:
        logger.error(f"Error sending final chunk: {str(e)}")
    await response.write_eof()
    logger.debug("Stream response completed")
    return response
def handle_non_stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model,
                               system_fingerprint):
    """Build the complete (non-streaming) OpenAI-style chat completion
    describing the generated image.

    Note: `usage` counts are character counts, not real model tokens.
    """
    content = (
        f"原始提示词:{original_prompt}\n"
        f"翻译后的提示词:{translated_prompt}\n"
        f"图像规格:{size}\n"
        f"图像生成成功!\n"
        f"以下是结果:\n\n"
        f"![生成的图像]({image_data['data'][0]['url']})"
    )
    choice = {
        'index': 0,
        'message': {'role': "assistant", 'content': content},
        'finish_reason': "stop"
    }
    usage = {
        'prompt_tokens': len(original_prompt),
        'completion_tokens': len(content),
        'total_tokens': len(original_prompt) + len(content)
    }
    return web.json_response({
        'id': unique_id,
        'object': "chat.completion",
        'created': created,
        'model': model,
        'system_fingerprint': system_fingerprint,
        'choices': [choice],
        'usage': usage
    })
async def generate_image1(prompt, size):
    """Generate an image via the api-flux1 GET endpoint.

    The prompt is first enhanced by get_prompt(), then URL-encoded and
    embedded in the query string. Returns {'data': [{'url': ...}], 'size': ...}.
    """
    enhanced_prompt = await get_prompt(prompt)
    # BUG FIX: the original stripped spaces entirely (replace(" ", "")),
    # fusing the prompt's words together. Percent-encode instead so the
    # query string is valid AND the prompt keeps its meaning.
    encoded_prompt = urllib.parse.quote(enhanced_prompt)
    image_url = f"{URLS['API_FLUX1_API4GPT_COM']}/?prompt={encoded_prompt}&size={size}"
    return {
        'data': [{'url': image_url}],
        'size': size
    }
async def generate_image2(prompt, size):
    """Generate an image via fluxaiweb.com: fetch a token, then request
    generation, spoofing the client IP with X-Forwarded-For on both calls."""
    spoofed_ip = generate_random_ip()
    # Enhance the prompt before sending it to the backend.
    enhanced = await get_prompt(prompt)
    async with aiohttp.ClientSession() as session:
        async with session.get(URLS['FLUXAIWEB_COM_TOKEN'],
                               headers={'X-Forwarded-For': spoofed_ip}) as token_resp:
            token = (await token_resp.json())['data']['token']
        gen_headers = {
            'Content-Type': 'application/json',
            'token': token,
            'X-Forwarded-For': spoofed_ip
        }
        payload = {
            'prompt': enhanced,
            'aspectRatio': size,
            'outputFormat': 'webp',
            'numOutputs': 1,
            'outputQuality': 90
        }
        async with session.post(URLS['FLUXAIWEB_COM_GENERATE'], headers=gen_headers,
                                json=payload) as gen_resp:
            body = await gen_resp.json()
            return {
                'data': [{'url': body['data']['image']}],
                'size': size
            }
async def generate_image3(prompt, size):
    """Generate an image via fluximg.com, retrying connection failures with
    exponential backoff; falls back to a placeholder URL after 3 attempts."""
    payload = json.dumps({
        'textStr': prompt,
        'model': "black-forest-labs/flux-schnell",
        'size': size
    })
    for attempt in range(3):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(URLS['FLUXIMG_COM'], data=payload,
                                        headers={'Content-Type': 'text/plain;charset=UTF-8'}) as resp:
                    if resp.status == 200:
                        # The endpoint returns the image URL as plain text.
                        return {
                            'data': [{'url': await resp.text()}],
                            'size': size
                        }
                    logger.error(
                        f"Unexpected response status: {resp.status}, response text: {await resp.text()}")
        except aiohttp.ClientConnectorError as exc:
            logger.error(f"Connection error on attempt {attempt + 1}: {exc}")
            await asyncio.sleep(2 ** attempt)  # exponential backoff
    logger.error("Failed to generate image after multiple attempts")
    return {
        'data': [{'url': "https://via.placeholder.com/640x480/428675/ffffff?text=Error"}],
        'size': size
    }
async def generate_image4(prompt, size, model):
    """Generate an image via SiliconFlow's text-to-image API.

    Selects the endpoint from URL_MAP by model alias (default 'flux') and
    adjusts inference parameters for the turbo/lightning variants.
    Raises if AUTH_TOKEN is missing or the API responds with non-200.
    """
    if not config.AUTH_TOKEN:
        raise Exception("AUTH_TOKEN is required for this method")
    endpoint = URL_MAP.get(model, URL_MAP['flux'])
    # Strip any residual --m flag, then enhance the prompt.
    enhanced = await get_prompt(re.sub(r'--m\s+\S+', '', prompt).strip())
    body = {
        'prompt': enhanced,
        'image_size': RATIO_MAP.get(size, "1024x1024"),
        'num_inference_steps': 50
    }
    if model and model != "flux":
        body['batch_size'] = 1
        body['guidance_scale'] = 7.5
        # Turbo/lightning models need far fewer steps and unit guidance.
        fast_steps = {"sdt": 6, "sdxlt": 6, "sdxll": 4}
        if model in fast_steps:
            body['num_inference_steps'] = fast_steps[model]
            body['guidance_scale'] = 1
    token = config.AUTH_TOKEN
    headers = {
        'authorization': token if token.startswith('Bearer ') else f'Bearer {token}',
        'Accept': 'application/json',
        'Content-Type': 'application/json'
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=headers, json=body) as response:
            if response.status != 200:
                raise Exception(f'Unexpected response {response.status}')
            result = await response.json()
            return {
                'data': [{'url': result['images'][0]['url']}],
                'size': size
            }
def generate_random_ip():
    """Return a random dotted-quad IPv4 string; each octet is 0-255
    inclusive (used to spoof X-Forwarded-For headers)."""
    return '.'.join(str(random.randint(0, 255)) for _ in range(4))
async def get_prompt(prompt):
    """Ask the configured chat model to expand *prompt* into an enhanced
    Stable Diffusion prompt (see SYSTEM_ASSISTANT).

    Returns the enhanced prompt, or the original prompt unchanged on any
    HTTP failure, missing choices, or exception.
    """
    logger.info(f"Original Prompt: {prompt}")  # log the incoming prompt
    payload = json.dumps({
        'model': config.DEFAULT_PROMPT_MODEL,  # enhancement model from config
        'messages': [
            {'role': "system", 'content': SYSTEM_ASSISTANT},
            {'role': "user", 'content': prompt}
        ],
        'stream': False,
        'max_tokens': 512,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'frequency_penalty': 0.5,
        'n': 1
    })
    logger.debug(f"Request Body: {payload}")
    key = config.OPENAI_CHAT_API_KEY
    headers = {
        'accept': 'application/json',
        'authorization': key if key.startswith('Bearer ') else f'Bearer {key}',
        'content-type': 'application/json'
    }
    logger.debug(f"Request Headers: {headers}")
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(config.OPENAI_CHAT_API, headers=headers,
                                    data=payload) as response:
                if response.status != 200:
                    text = await response.text()
                    logger.error(f"Failed to get response, status code: {response.status}, response: {text}")
                    return prompt
                data = await response.json()
                logger.debug(f"API Response: {data}")  # full API response
                choices = data.get('choices') or []
                if choices:
                    enhanced = choices[0]['message']['content']
                    logger.info(f"Enhanced Prompt: {enhanced}")
                    return enhanced
                logger.warning("No enhanced prompt found in the response, returning original prompt.")
                return prompt
    except Exception as e:
        logger.error(f"Exception occurred: {e}")
        return prompt
# Application wiring: handle_request itself also validates that the path
# ends in /v1/chat/completions, so this route is the only one registered.
app = web.Application()
app.router.add_post('/hf/v1/chat/completions', handle_request)
if __name__ == '__main__':
    # Listen on all interfaces, port 7860 (the Hugging Face Spaces default).
    web.run_app(app, host='0.0.0.0', port=7860)