# flux-api / main.py
# Last commit by tianlong12: 8c57676 "Update main.py" (verified)
# -*- coding: utf-8 -*-
import os
import random
import re
import json
import time
import aiohttp
import asyncio
from aiohttp import web
import unicodedata
from dataclasses import dataclass, asdict
import logging
# Root logging at INFO so the prompt/translation progress messages are visible;
# the module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
@dataclass
class Config:
    """Settings for the upstream services this proxy talks to.

    Every field can be overridden by an environment variable with the same
    name. The variables are read once, when this module is imported.
    Secrets must come from the environment: a key previously hard-coded in
    this file should be treated as leaked and rotated.
    """

    # SiliconFlow API key, used by generate_image4. An empty value makes that
    # backend raise, so the caller falls back to another generator.
    AUTH_TOKEN: str = os.environ.get('AUTH_TOKEN', '')
    # One-API/New-API relay endpoint (OpenAI-compatible chat completions URL).
    OPENAI_CHAT_API: str = os.environ.get('OPENAI_CHAT_API', 'https://xxxxxxxxxxxx/v1/chat/completions')
    # API key for the relay above.
    OPENAI_CHAT_API_KEY: str = os.environ.get('OPENAI_CHAT_API_KEY', 'sk-xxxxxxxxxx')
    # Default model used to translate the prompt (translate_prompt).
    DEFAULT_TRANSLATE_MODEL: str = os.environ.get('DEFAULT_TRANSLATE_MODEL', 'deepseek-chat')
    # Model used to enhance the prompt (get_prompt).
    DEFAULT_PROMPT_MODEL: str = os.environ.get('DEFAULT_PROMPT_MODEL', 'Qwen2-72B-Instruct')


config = Config()
# Fixed upstream endpoints used by the image and chat backends.
URLS = dict(
    API_FLUX1_API4GPT_COM='https://api-flux1.api4gpt.com',
    FLUXAIWEB_COM_TOKEN='https://fluxaiweb.com/flux/getToken',
    FLUXAIWEB_COM_GENERATE='https://fluxaiweb.com/flux/generateImage',
    FLUXIMG_COM='https://fluximg.com/api/image/generateImage',
    API_SILICONFLOW_CN='https://api.siliconflow.cn/v1/chat/completions',
)
# "--m <key>" model flag -> SiliconFlow text-to-image endpoint (used by generate_image4).
URL_MAP = dict(
    flux="https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image",
    sd3="https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image",
    sdxl="https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image",
    sd2="https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/text-to-image",
    sdt="https://api.siliconflow.cn/v1/stabilityai/sd-turbo/text-to-image",
    sdxlt="https://api.siliconflow.cn/v1/stabilityai/sdxl-turbo/text-to-image",
    sdxll="https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/text-to-image",
)
# Image-to-image endpoints; not referenced by the code in this module as shown.
IMG_URL_MAP = dict(
    sdxl="https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/image-to-image",
    sd2="https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/image-to-image",
    sdxll="https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/image-to-image",
    pm="https://api.siliconflow.cn/v1/TencentARC/PhotoMaker/image-to-image",
)
# Maps the "--ar W:H" aspect-ratio flag to the "WIDTHxHEIGHT" image_size sent by
# generate_image4 (which falls back to "1024x1024" for unknown ratios).
# Invariant: for every entry, width:height == W:H.
RATIO_MAP = {
    "1:1": "1024x1024",
    "1:2": "1024x2048",
    "3:2": "1536x1024",
    # Was "1536x2048", i.e. portrait 3:4 -- inverted relative to the key.
    "4:3": "2048x1536",
    "16:9": "2048x1152",
    "9:16": "1152x2048",
}
# System prompt sent as the "system" message by both translate_prompt and
# get_prompt. Written in Chinese, it asks the chat model to act as a Stable
# Diffusion prompt expert. The model should build a comma-separated English
# "Positive Prompt" from a prefix (quality/style/effect tags), a subject and a
# scene, using "(tag)" / "(tag:1.5)" weighting, and reply with only that prompt
# (see the example at the end).
SYSTEM_ASSISTANT = """作为 Stable Diffusion Prompt 提示词专家,您将从关键词中创建提示,通常来自 Danbooru 等数据库。
提示通常描述图像,使用常见词汇,按重要性排列,并用逗号分隔。避免使用"-"或".",但可以接受空格和自然语言。避免词汇重复。
为了强调关键词,请将其放在括号中以增加其权重。例如,"(flowers)"将'flowers'的权重增加1.1倍,而"(((flowers)))"将其增加1.331倍。使用"(flowers:1.5)"将'flowers'的权重增加1.5倍。只为重要的标签增加权重。
提示包括三个部分:**前缀**(质量标签+风格词+效果器)+ **主题**(图像的主要焦点)+ **场景**(背景、环境)。
* 前缀影响图像质量。像"masterpiece"、"best quality"、"4k"这样的标签可以提高图像的细节。像"illustration"、"lensflare"这样的风格词定义图像的风格。像"bestlighting"、"lensflare"、"depthoffield"这样的效果器会影响光照和深度。
* 主题是图像的主要焦点,如角色或场景。对主题进行详细描述可以确保图像丰富而详细。增加主题的权重以增强其清晰度。对于角色,描述面部、头发、身体、服装、姿势等特征。
* 场景描述环境。没有场景,图像的背景是平淡的,主题显得过大。某些主题本身包含场景(例如建筑物、风景)。像"花草草地"、"阳光"、"河流"这样的环境词可以丰富场景。你的任务是设计图像生成的提示。请按照以下步骤进行操作:
1. 我会发送给您一个图像场景。需要你生成详细的图像描述
2. 图像描述必须是英文,输出为Positive Prompt。
示例:
我发送:二战时期的护士。
您回复只回复:
A WWII-era nurse in a German uniform, holding a wine bottle and stethoscope, sitting at a table in white attire, with a table in the background, masterpiece, best quality, 4k, illustration style, best lighting, depth of field, detailed character, detailed environment.
"""
async def select_random_image_generator():
    """Pick one image backend at random and return it behind a uniform call shape.

    Returns:
        An async callable ``fn(prompt, size, model)``. Only generate_image4
        uses ``model``. For the other backends the argument is accepted (and
        optional) but ignored.

    The wrapper carries the chosen backend's ``__name__``, so callers that log
    ``fn.__name__`` on failure name the actual backend. The previous lambda
    wrappers always logged ``<lambda>``.
    """
    generators = [generate_image1, generate_image2, generate_image3, generate_image4]
    generator = random.choice(generators)

    if generator is generate_image4:
        async def call(prompt, size, model):
            return await generator(prompt, size, model)
    else:
        async def call(prompt, size, model=None):
            # These backends take no model argument.
            return await generator(prompt, size)

    call.__name__ = generator.__name__
    call.__qualname__ = generator.__qualname__
    return call
def extract_size_and_model_from_prompt(prompt):
    """Split the ``--ar <ratio>`` and ``--m <model>`` flags out of a user prompt.

    Returns a dict with:
        size: value of the first ``--ar`` flag, or '1:1' when absent.
        model: value of the first ``--m`` flag, or '' when absent.
        clean_prompt: *prompt* with every ``--ar``/``--m`` flag and its value
            removed and outer whitespace stripped. Inner spacing is left as is.
    """
    ratio_hit = re.search(r'--ar\s+(\S+)', prompt)
    model_hit = re.search(r'--m\s+(\S+)', prompt)

    without_ratio = re.sub(r'--ar\s+\S+', '', prompt).strip()
    cleaned = re.sub(r'--m\s+\S+', '', without_ratio).strip()

    return {
        'size': '1:1' if ratio_hit is None else ratio_hit.group(1),
        'model': '' if model_hit is None else model_hit.group(1),
        'clean_prompt': cleaned,
    }
def is_chinese(char):
    """Return True when the Unicode name of *char* contains "CJK".

    That covers e.g. CJK unified/compatibility ideographs. Characters without
    a Unicode name count as non-Chinese.
    """
    char_name = unicodedata.name(char, '')
    return char_name.find('CJK') >= 0
async def translate_prompt(prompt):
    """Turn a prompt that contains Chinese into an English image prompt.

    If *prompt* has no character for which is_chinese is true, it is returned
    unchanged and no request is made. Otherwise it is sent with
    SYSTEM_ASSISTANT as the system message to ``config.OPENAI_CHAT_API``,
    using ``config.DEFAULT_TRANSLATE_MODEL``. The first choice's message
    content is returned.

    Best-effort: on a non-200 status, a non-JSON reply, an unexpected response
    shape or any network error, the problem is logged and *prompt* is
    returned unchanged.
    """
    if not any(is_chinese(char) for char in prompt):
        logger.info('Prompt is already in English, skipping translation')
        return prompt

    api_key = config.OPENAI_CHAT_API_KEY
    # Accept keys configured with or without the "Bearer " prefix (same as get_prompt).
    authorization = api_key if api_key.startswith('Bearer ') else f'Bearer {api_key}'
    payload = {
        'model': config.DEFAULT_TRANSLATE_MODEL,
        'messages': [
            {'role': 'system', 'content': SYSTEM_ASSISTANT},
            {'role': 'user', 'content': prompt},
        ],
    }
    headers = {
        'Content-Type': 'application/json',
        'Authorization': authorization,
    }
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(config.OPENAI_CHAT_API, json=payload, headers=headers) as response:
                content_type = response.headers.get('Content-Type', '')
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f'HTTP error! status: {response.status}, body: {error_text}')
                    return prompt
                if 'application/json' not in content_type:
                    error_text = await response.text()
                    logger.error(f'Unexpected content type: {content_type}, body: {error_text}')
                    return prompt
                result = await response.json()
                return result['choices'][0]['message']['content']
    except Exception as e:
        # Use a %s placeholder: passing ``e`` as a bare extra argument (as before)
        # makes logging report a formatting error instead of the real exception.
        logger.error('Translation error: %s', e)
        return prompt
def _message_text(content):
    """Return the text of an OpenAI chat message ``content`` value, or None.

    Accepts a plain string, or a list of content parts. For a list, the
    ``text`` of every ``{"type": "text"}`` part is joined with single spaces.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [part.get('text') for part in content
                 if isinstance(part, dict) and part.get('type') == 'text']
        joined = ' '.join(text for text in texts if isinstance(text, str))
        return joined or None
    return None


async def handle_request(request):
    """Serve ``POST .../v1/chat/completions`` by generating an image from the last user message.

    Flow:
        1. Parse the OpenAI-style JSON body and take the most recent ``user`` message.
        2. Pull ``--ar``/``--m`` flags out of it (extract_size_and_model_from_prompt).
        3. Translate the remaining prompt (translate_prompt).
        4. Try up to three randomly chosen image backends.

    The result is returned as a ``chat.completion`` JSON body, or as an SSE
    stream when ``"stream"`` is true.

    Returns:
        - 404 for any other method or path.
        - 400 when the body is not a JSON object or has no user message.
        - 500 when every backend attempt fails or an unexpected error occurs.
    """
    if request.method != 'POST' or not request.url.path.endswith('/v1/chat/completions'):
        return web.Response(text='Not Found', status=404)
    try:
        try:
            data = await request.json()
        except json.JSONDecodeError:
            data = None
        if not isinstance(data, dict):
            # Malformed input is a client error, not a server failure.
            return web.json_response({'error': "请求体不是有效的 JSON 对象"}, status=400)

        messages = data.get('messages') or []
        if not isinstance(messages, list):
            messages = []
        stream = data.get('stream', False)
        user_message = next(
            (_message_text(msg.get('content')) for msg in reversed(messages)
             if isinstance(msg, dict) and msg.get('role') == 'user'),
            None,
        )
        if not user_message:
            return web.json_response({'error': "未找到用户消息"}, status=400)

        size_and_model = extract_size_and_model_from_prompt(user_message)
        translated_prompt = await translate_prompt(size_and_model['clean_prompt'])

        max_attempts = 3
        selected_generator = await select_random_image_generator()
        for attempt in range(1, max_attempts + 1):
            try:
                image_data = await selected_generator(translated_prompt, size_and_model['size'],
                                                      size_and_model['model'])
                break
            except Exception as e:
                generator_name = getattr(selected_generator, '__name__', repr(selected_generator))
                logger.error(f"Error generating image with generator {generator_name} "
                             f"(attempt {attempt}/{max_attempts}): {e}")
                # Re-pick at random; the same backend may be chosen again.
                selected_generator = await select_random_image_generator()
        else:
            logger.error("Failed to generate image after multiple attempts")
            return web.json_response({'error': "生成图像失败"}, status=500)

        created_timestamp = int(time.time())
        unique_id = f"chatcmpl-{created_timestamp}"
        model_name = "flux"
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))
        if stream:
            return await handle_stream_response(request, unique_id, image_data, size_and_model['clean_prompt'],
                                                translated_prompt, size_and_model['size'], created_timestamp,
                                                model_name, system_fingerprint)
        return handle_non_stream_response(unique_id, image_data, size_and_model['clean_prompt'], translated_prompt,
                                          size_and_model['size'], created_timestamp, model_name, system_fingerprint)
    except Exception as e:
        # The previous call passed ``e`` without a placeholder, which broke the log record.
        logger.exception('Error handling request: %s', e)
        return web.json_response({'error': f"处理请求失败: {str(e)}"}, status=500)
async def handle_stream_response(request, unique_id, image_data, original_prompt, translated_prompt, size, created,
                                 model, system_fingerprint):
    """Send the result as an OpenAI-compatible ``chat.completion.chunk`` SSE stream.

    The stream contains, in order:
        1. A fixed sequence of content chunks, sent 0.5 s apart. The image is
           already generated, so the delay only simulates progress.
        2. A final chunk with an empty delta and ``finish_reason="stop"``.
        3. The ``data: [DONE]`` terminator expected by OpenAI-style streaming clients.

    If a write fails (e.g. the client went away), the error is logged and
    streaming stops instead of continuing to write and sleep. The prepared
    response is returned either way.
    """
    # Read the URL before sending headers, so a malformed ``image_data`` raises
    # while the caller can still answer with a normal JSON error.
    image_url = image_data['data'][0]['url']
    chunks = [
        f"原始提示词:\n{original_prompt}\n",
        f"翻译后的提示词:\n{translated_prompt}\n",
        f"图像规格:{size}\n",
        "正在根据提示词生成图像...\n",
        "图像正在处理中...\n",
        "即将完成...\n",
        f"生成成功!\n图像生成完毕,以下是结果:\n\n![生成的图像]({image_url})"
    ]

    def sse_event(delta, finish_reason):
        """Encode one chat.completion.chunk object as an SSE ``data:`` event."""
        payload = json.dumps({
            "id": unique_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "system_fingerprint": system_fingerprint,
            "choices": [{
                "index": 0,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason
            }]
        })
        return f"data: {payload}\n\n".encode('utf-8')

    logger.debug("Starting stream response")
    response = web.StreamResponse(
        status=200,
        reason='OK',
        headers={
            'Content-Type': 'text/event-stream',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
    )
    await response.prepare(request)
    logger.debug("Response prepared")

    try:
        for i, chunk in enumerate(chunks, start=1):
            await response.write(sse_event({"content": chunk}, None))
            logger.debug(f"Chunk {i} sent")
            await asyncio.sleep(0.5)  # simulated generation time
        await response.write(sse_event({}, "stop"))
        logger.debug("Final chunk sent")
        await response.write(b"data: [DONE]\n\n")
        await response.write_eof()
    except Exception as e:
        # Any write failure means the stream is unusable; stop rather than keep writing.
        logger.error(f"Error while streaming response: {e}")
        return response
    logger.debug("Stream response completed")
    return response
def handle_non_stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model,
                               system_fingerprint):
    """Build a non-streaming OpenAI-style ``chat.completion`` JSON response.

    The assistant message lists the original prompt, the translated prompt and
    the size, then ends with a markdown image pointing at
    ``image_data['data'][0]['url']``. The ``usage`` figures are character
    counts, not model tokens.
    """
    image_url = image_data['data'][0]['url']
    content = (
        f"原始提示词:{original_prompt}\n"
        f"翻译后的提示词:{translated_prompt}\n"
        f"图像规格:{size}\n"
        "图像生成成功!\n"
        "以下是结果:\n\n"
        f"![生成的图像]({image_url})"
    )
    prompt_chars = len(original_prompt)
    completion_chars = len(content)

    choice = {
        'index': 0,
        'message': {'role': "assistant", 'content': content},
        'finish_reason': "stop",
    }
    usage = {
        'prompt_tokens': prompt_chars,
        'completion_tokens': completion_chars,
        'total_tokens': prompt_chars + completion_chars,
    }
    body = dict(
        id=unique_id,
        object="chat.completion",
        created=created,
        model=model,
        system_fingerprint=system_fingerprint,
        choices=[choice],
        usage=usage,
    )
    return web.json_response(body)
async def generate_image1(prompt, size):
    """Return an api-flux1.api4gpt.com image URL for *prompt*.

    The prompt is first enhanced with get_prompt. No request is made here:
    the returned URL is what the response builders embed in a markdown image
    link.

    The query string is percent-encoded. Previously the prompt was inserted
    raw, with its spaces deleted. That merged words, and characters such as
    '&', '#' or ')' (common in "(tag:1.2)" weights) could truncate the query
    or end the markdown link early.

    Returns:
        ``{'data': [{'url': image_url}], 'size': size}``.
    """
    from urllib.parse import quote, urlencode

    enhanced_prompt = await get_prompt(prompt)
    query = urlencode({'prompt': enhanced_prompt, 'size': size}, quote_via=quote)
    image_url = f"{URLS['API_FLUX1_API4GPT_COM']}/?{query}"
    return {
        'data': [{'url': image_url}],
        'size': size
    }
async def generate_image2(prompt, size):
    """Generate an image through fluxaiweb.com.

    Steps: enhance *prompt* with get_prompt, fetch a token from the token
    endpoint, then post the generation request (webp output, quality 90,
    ``size`` used as the aspect ratio) with that token.

    Returns:
        ``{'data': [{'url': <image url>}], 'size': size}``.

    Missing keys or non-JSON replies surface as exceptions.

    NOTE(review): both requests send a randomly generated X-Forwarded-For
    header. Confirm this is acceptable under the upstream service's terms.
    """
    spoofed_ip = generate_random_ip()
    enhanced_prompt = await get_prompt(prompt)

    generate_headers = {
        'Content-Type': 'application/json',
        'X-Forwarded-For': spoofed_ip,
    }
    generate_body = {
        'prompt': enhanced_prompt,
        'aspectRatio': size,
        'outputFormat': 'webp',
        'numOutputs': 1,
        'outputQuality': 90
    }

    async with aiohttp.ClientSession() as session:
        async with session.get(URLS['FLUXAIWEB_COM_TOKEN'],
                               headers={'X-Forwarded-For': spoofed_ip}) as token_resp:
            token_payload = await token_resp.json()
        generate_headers['token'] = token_payload['data']['token']
        async with session.post(URLS['FLUXAIWEB_COM_GENERATE'],
                                headers=generate_headers, json=generate_body) as image_resp:
            image_payload = await image_resp.json()

    return {'data': [{'url': image_payload['data']['image']}], 'size': size}
async def generate_image3(prompt, size):
    """Generate an image via fluximg.com (flux-schnell), retrying transient failures.

    Each attempt may fail with a non-200 status, an empty body or a
    network/timeout error. Up to three attempts are made, with exponential
    backoff (1 s, then 2 s) between them.

    Returns:
        ``{'data': [{'url': <image url>}], 'size': size}``. The URL is the
        stripped body of the 200 response.

    Raises:
        RuntimeError: when every attempt fails. Previously a placeholder
        "Error" image URL was returned. Callers then showed it as a
        successful generation, and handle_request never tried another
        backend.
    """
    json_body = {
        'textStr': prompt,
        'model': "black-forest-labs/flux-schnell",
        'size': size
    }
    max_retries = 3
    last_error = None
    for attempt in range(max_retries):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(URLS['FLUXIMG_COM'], data=json.dumps(json_body),
                                        headers={'Content-Type': 'text/plain;charset=UTF-8'}) as response:
                    body = await response.text()
                    if response.status == 200:
                        image_url = body.strip()
                        if image_url:
                            return {
                                'data': [{'url': image_url}],
                                'size': size
                            }
                        last_error = 'empty response body'
                        logger.error(f"Empty response body on attempt {attempt + 1}")
                    else:
                        last_error = f'status {response.status}'
                        logger.error(f"Unexpected response status: {response.status}, response text: {body}")
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            last_error = repr(e)
            logger.error(f"Connection error on attempt {attempt + 1}: {e}")
        if attempt < max_retries - 1:
            await asyncio.sleep(2 ** attempt)  # Exponential backoff
    logger.error("Failed to generate image after multiple attempts")
    raise RuntimeError(f"fluximg.com failed after {max_retries} attempts: {last_error}")
async def generate_image4(prompt, size, model):
    """Generate an image through a SiliconFlow text-to-image endpoint.

    Args:
        prompt: Prompt text. Any leftover ``--m <model>`` flag is removed
            before the prompt is enhanced with get_prompt.
        size: Aspect-ratio key looked up in RATIO_MAP. Unknown keys use
            "1024x1024".
        model: A URL_MAP key ("flux", "sd3", "sdxl", ...). Empty, None or
            unknown values fall back to "flux".

    Returns:
        ``{'data': [{'url': <image url>}], 'size': size}``.

    Raises:
        Exception: if ``config.AUTH_TOKEN`` is empty, or if the API answers with
        a non-200 status (the message includes the start of the response body).
    """
    if not config.AUTH_TOKEN:
        raise Exception("AUTH_TOKEN is required for this method")
    # Normalise the model before building the request. Previously an unknown
    # model got the FLUX URL but still received the Stable Diffusion parameters.
    if model not in URL_MAP:
        model = 'flux'
    api_url = URL_MAP[model]
    clean_prompt = re.sub(r'--m\s+\S+', '', prompt).strip()
    enhanced_prompt = await get_prompt(clean_prompt)
    json_body = {
        'prompt': enhanced_prompt,
        'image_size': RATIO_MAP.get(size, "1024x1024"),
        'num_inference_steps': 50
    }
    if model != "flux":
        json_body['batch_size'] = 1
        json_body['guidance_scale'] = 7.5
        # Turbo / Lightning variants are configured with few steps and guidance 1.
        if model in ("sdt", "sdxlt"):
            json_body['num_inference_steps'] = 6
            json_body['guidance_scale'] = 1
        elif model == "sdxll":
            json_body['num_inference_steps'] = 4
            json_body['guidance_scale'] = 1

    token = config.AUTH_TOKEN
    authorization = token if token.startswith('Bearer ') else f'Bearer {token}'
    async with aiohttp.ClientSession() as session:
        async with session.post(api_url, headers={
            'authorization': authorization,
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }, json=json_body) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f'Unexpected response {response.status}: {error_text[:500]}')
            json_response = await response.json()
    return {
        'data': [{'url': json_response['images'][0]['url']}],
        'size': size
    }
def generate_random_ip():
    """Return a random dotted-quad IPv4 string, each octet drawn from 0-255 inclusive."""
    return ".".join(str(random.randint(0, 255)) for _ in range(4))
async def get_prompt(prompt):
    """Ask the configured chat model to expand *prompt* into a detailed English image prompt.

    Sends SYSTEM_ASSISTANT plus *prompt* to ``config.OPENAI_CHAT_API`` using
    ``config.DEFAULT_PROMPT_MODEL``.

    Best-effort: *prompt* is returned unchanged on any of:
        - a non-200 status;
        - a response without choices;
        - an empty reply;
        - any exception.
    """
    logger.info(f"Original Prompt: {prompt}")
    request_body_json = json.dumps({
        'model': config.DEFAULT_PROMPT_MODEL,
        'messages': [
            {
                'role': "system",
                'content': SYSTEM_ASSISTANT
            },
            {
                'role': "user",
                'content': prompt
            }
        ],
        'stream': False,
        'max_tokens': 512,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'frequency_penalty': 0.5,
        'n': 1
    })
    logger.debug(f"Request Body: {request_body_json}")
    api_key = config.OPENAI_CHAT_API_KEY
    request_headers = {
        'accept': 'application/json',
        'authorization': api_key if api_key.startswith('Bearer ') else f'Bearer {api_key}',
        'content-type': 'application/json'
    }
    # Never write the API key to the logs, even at DEBUG level.
    logger.debug(f"Request Headers: {dict(request_headers, authorization='Bearer ***')}")
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(config.OPENAI_CHAT_API, headers=request_headers,
                                    data=request_body_json) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Failed to get response, status code: {response.status}, response: {error_text}")
                    return prompt
                json_response = await response.json()
                logger.debug(f"API Response: {json_response}")
                choices = json_response.get('choices') if isinstance(json_response, dict) else None
                enhanced_prompt = choices[0]['message']['content'] if choices else None
                if enhanced_prompt:
                    logger.info(f"Enhanced Prompt: {enhanced_prompt}")
                    return enhanced_prompt
                # An empty reply would produce an empty image prompt; keep the original instead.
                logger.warning("No enhanced prompt found in the response, returning original prompt.")
                return prompt
    except Exception as e:
        logger.error(f"Exception occurred: {e}")
        return prompt
# aiohttp application exposing the single chat-completions route.
app = web.Application()
app.add_routes([web.post('/hf/v1/chat/completions', handle_request)])

if __name__ == '__main__':
    # Listen on all interfaces, port 7860.
    web.run_app(app, host='0.0.0.0', port=7860)