Spaces:
Running
Running
# Copyright (2024) Bytedance Ltd. and/or its affiliates
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import base64
import hashlib
import hmac
import io
import json
import os
import random
import secrets
import time

import requests
from loguru import logger
from PIL import Image, ImageFilter
# Endpoint of the MagicArena algorithm-processing API; the text-to-image and
# image-editing requests in this module are both posted to this URL.
t2i_url = 'https://magicarena.bytedance.com/api/evaluate/v1/algo/process'

# API credentials, read eagerly from the environment: importing this module
# raises KeyError if either variable is unset.
APP_KEY, SECRET_KEY = (os.environ[name] for name in ("APP_KEY", "SECRET_KEY"))
def get_auth(app_key, secret_key):
    """Build the ``AuthInfo`` payload that authenticates one API request.

    Args:
        app_key: Application key, copied unchanged into ``AppKey``.
        secret_key: Shared secret; it is only used to compute ``Sign`` and is
            not itself included in the returned dict.

    Returns:
        dict with keys ``AppKey``, ``Timestamp`` (current Unix time in whole
        seconds, as a string), ``Nonce`` (random integer in [0, 2**31 - 1], as
        a string) and ``Sign`` (the result of ``get_sign``).
    """
    # Per-request nonce from the OS CSPRNG: the `random` module is documented
    # as unsuitable for security-related values. The range matches the former
    # random.randint(0, 2**31 - 1).
    nonce = str(secrets.randbelow(2**31))
    # Current Unix timestamp in whole seconds.
    timestamp = str(int(time.time()))
    return {
        "AppKey": app_key,
        "Timestamp": timestamp,
        "Nonce": nonce,
        "Sign": get_sign(nonce, timestamp, secret_key),
    }
def get_sign(nonce, timestamp, secret_key):
    """Return the request signature for a nonce, timestamp and secret key.

    The three strings are sorted lexicographically, concatenated, UTF-8
    encoded and hashed with SHA-1; the signature is the lowercase hex digest.
    """
    material = "".join(sorted((nonce, secret_key, timestamp)))
    digest = hashlib.sha1(material.encode('utf-8')).hexdigest()
    return digest.lower()
class SeedT2ICaller:
    """Text-to-image client for the MagicArena ``algo/process`` endpoint (AlgoType 1).

    ``cfg`` must support ``cfg['resolution']``: each generated image is resized
    to a square of that many pixels per side.
    """

    # Seconds to wait for the HTTP response. None is requests' default and waits
    # indefinitely; set this on the class or an instance to bound slow requests.
    request_timeout = None

    def __init__(self, cfg, *args, **kwargs):
        self.cfg = cfg

    @staticmethod
    def _first_image_b64(resp):
        """Return the first base64 string in ``resp['data']['BinaryData']``, or None.

        None is returned when ``data`` or ``BinaryData`` is missing, null or
        empty. Assumes ``resp`` is a dict decoded from the JSON response.
        """
        images = (resp.get('data') or {}).get('BinaryData') or []
        return images[0] if images else None

    def generate(self, text, *args, **kwargs):
        """Generate an image for the prompt ``text``.

        Returns:
            ``(image, True)`` with a PIL image on success, or ``(None, False)``
            on any failure: non-200 HTTP status, non-zero API ``code``, a
            response without image data, or an exception (logged with traceback).
        """
        try:
            logger.info("Generate images ...")
            req_json = json.dumps({
                "prompt": str(text),
                "use_sr": True,
                "model_version": "general_v2.0_L",
                "req_schedule_conf": "general_v20_9B_pe",
                # "width": 64,
                # "height": 64
            })
            auth_info = get_auth(APP_KEY, SECRET_KEY)
            logger.info(f"{req_json}")
            # Send the request: model parameters travel as a JSON *string*
            # nested under 'ReqJson' inside the JSON body.
            response = requests.post(
                t2i_url,
                data=json.dumps({
                    'AlgoType': 1,
                    'ReqJson': req_json,
                    'AuthInfo': auth_info,
                }),
                timeout=self.request_timeout,
            )
            logger.info(f"header: {response.headers}")
            if response.status_code != 200:
                # Log instead of failing silently so HTTP errors are diagnosable.
                logger.error(f"HTTP {response.status_code}: {response.text[:500]}")
                return None, False
            resp = response.json()
            if resp.get('code') != 0:
                logger.error(f"response error {resp}")
                return None, False
            binary_data = self._first_image_b64(resp)
            if binary_data is None:
                logger.error(f"response contains no BinaryData: {resp}")
                return None, False
            image = Image.open(io.BytesIO(base64.b64decode(binary_data)))
            # resize() forces decoding, so a corrupt payload fails inside this try.
            image = image.resize((self.cfg['resolution'], self.cfg['resolution']))
            return image, True
        except Exception:
            logger.exception("An error occurred during image generation.")
            return None, False
class SeedEditCaller:
    """Instruction-based image-editing client for the MagicArena endpoint (AlgoType 2).

    ``cfg`` is stored on the instance but not otherwise read by this class.
    """

    # Seconds to wait for the HTTP response. None is requests' default and waits
    # indefinitely; set this on the class or an instance to bound slow requests.
    request_timeout = None

    def __init__(self, cfg, *args, **kwargs):
        self.cfg = cfg

    @staticmethod
    def _first_image_b64(resp):
        """Return the first base64 string in ``resp['data']['BinaryData']``, or None.

        None is returned when ``data`` or ``BinaryData`` is missing, null or
        empty. Assumes ``resp`` is a dict decoded from the JSON response.
        """
        images = (resp.get('data') or {}).get('BinaryData') or []
        return images[0] if images else None

    def edit(self, image, edit, cfg_scale=0.5, *args, **kwargs):
        """Apply the text instruction ``edit`` to ``image``.

        Args:
            image: PIL image to edit; it is uploaded base64-encoded as JPEG.
            edit: Edit instruction, sent as the request ``prompt``.
            cfg_scale: Value sent as the model's ``scale`` parameter.

        Returns:
            ``(image, True)`` with the edited PIL image on success, or
            ``(None, False)`` on any failure: non-200 HTTP status, non-zero API
            ``code``, a response without image data, or an exception (logged
            with traceback).
        """
        try:
            # JPEG cannot store alpha/palette modes such as RGBA, LA or P, so
            # Pillow would raise on save; convert them (alpha is discarded).
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")
            image_bytes = io.BytesIO()
            image.save(image_bytes, format='JPEG')  # or format='PNG'
            logger.info("Edit images ...")
            req_json = json.dumps({
                "prompt": str(edit),
                "model_version": "byteedit_v2.0",
                "scale": cfg_scale,
                "use_sr": True,
            })
            logger.info(f"{req_json}")
            binary = base64.b64encode(image_bytes.getvalue()).decode('utf-8')
            # Send the request: the source image goes in 'BinaryData', model
            # parameters as a JSON *string* under 'ReqJson'.
            response = requests.post(
                t2i_url,
                data=json.dumps({
                    'AlgoType': 2,
                    'ReqJson': req_json,
                    'BinaryData': [binary],
                    'AuthInfo': get_auth(APP_KEY, SECRET_KEY),
                }),
                timeout=self.request_timeout,
            )
            logger.info(f"header: {response.headers}")
            if response.status_code != 200:
                # Log instead of failing silently so HTTP errors are diagnosable.
                logger.error(f"HTTP {response.status_code}: {response.text[:500]}")
                return None, False
            resp = response.json()
            if resp.get('code') != 0:
                logger.error(f"response error {resp}")
                return None, False
            binary_data = self._first_image_b64(resp)
            if binary_data is None:
                logger.error(f"response contains no BinaryData: {resp}")
                return None, False
            edited = Image.open(io.BytesIO(base64.b64decode(binary_data)))
            # Image.open is lazy; decode now so a corrupt payload is reported
            # here as (None, False) instead of raising later in the caller.
            edited.load()
            return edited, True
        except Exception:
            logger.exception("An error occurred during image editing.")
            return None, False
if __name__ == "__main__":
    # Manual smoke test: generate an image, then edit it. Needs network access
    # and the APP_KEY / SECRET_KEY environment variables.
    cfg_t2i = {
        "resolution": 611
    }
    model_t2i = SeedT2ICaller(cfg_t2i)
    image, ok = model_t2i.generate("a beautiful girl")
    if not ok:
        # Stop here rather than passing None into the edit call.
        raise SystemExit("text-to-image generation failed; see log above")
    model_edit = SeedEditCaller(cfg_t2i)
    edited, ok = model_edit.edit(image, edit="please edit to a good man")
    if not ok:
        raise SystemExit("image editing failed; see log above")
    logger.info(f"edited image size: {edited.size}")