meme-generator / utils.py
ikechan8370's picture
Update utils.py
e7cf850 verified
import asyncio
import hashlib
import inspect
import math
import random
import time
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
List,
Literal,
Optional,
Protocol,
Tuple,
TypeVar,
)
import httpx
from PIL.Image import Image as IMG
from pil_utils import BuildImage, Text2Image
from pil_utils.types import ColorType, FontStyle, FontWeight
from typing_extensions import ParamSpec
from .config import meme_config
from .exception import MemeGeneratorException
if TYPE_CHECKING:
from .meme import Meme
P = ParamSpec("P")
R = TypeVar("R")
def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
"""一个用于包装 sync function 为 async function 的装饰器
参数:
call: 被装饰的同步函数
"""
@wraps(call)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_running_loop()
pfunc = partial(call, *args, **kwargs)
result = await loop.run_in_executor(None, pfunc)
return result
return _wrapper
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
"""检查 call 是否是一个 callable 协程函数"""
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
func_ = getattr(call, "__call__", None)
return inspect.iscoroutinefunction(func_)
def save_gif(frames: List[IMG], duration: float) -> BytesIO:
output = BytesIO()
frames[0].save(
output,
format="GIF",
save_all=True,
append_images=frames[1:],
duration=duration * 1000,
loop=0,
disposal=2,
optimize=False,
)
# 没有超出最大大小,直接返回
nbytes = output.getbuffer().nbytes
if nbytes <= meme_config.gif.gif_max_size * 10**6:
return output
# 超出最大大小,帧数超出最大帧数时,缩减帧数
n_frames = len(frames)
gif_max_frames = meme_config.gif.gif_max_frames
if n_frames > gif_max_frames:
index = range(n_frames)
ratio = n_frames / gif_max_frames
index = (int(i * ratio) for i in range(gif_max_frames))
new_duration = duration * ratio
new_frames = [frames[i] for i in index]
return save_gif(new_frames, new_duration)
# 超出最大大小,帧数没有超出最大帧数时,缩小尺寸
new_frames = [
frame.resize((int(frame.width * 0.9), int(frame.height * 0.9)))
for frame in frames
]
return save_gif(new_frames, duration)
class Maker(Protocol):
def __call__(self, img: BuildImage) -> BuildImage:
...
class GifMaker(Protocol):
def __call__(self, i: int) -> Maker:
...
def get_avg_duration(image: IMG) -> float:
if not getattr(image, "is_animated", False):
return 0
total_duration = 0
for i in range(image.n_frames):
image.seek(i)
total_duration += image.info["duration"]
return total_duration / image.n_frames
def split_gif(image: IMG) -> List[IMG]:
frames: List[IMG] = []
update_mode = "full"
for i in range(image.n_frames):
image.seek(i)
if image.tile: # type: ignore
update_region = image.tile[0][1][2:] # type: ignore
if update_region != image.size:
update_mode = "partial"
break
last_frame: Optional[IMG] = None
for i in range(image.n_frames):
image.seek(i)
frame = image.copy()
if update_mode == "partial" and last_frame:
frame = last_frame.copy().paste(frame)
frames.append(frame)
image.seek(0)
if image.info.__contains__("transparency"):
frames[0].info["transparency"] = image.info["transparency"]
return frames
def make_jpg_or_gif(
img: BuildImage, func: Maker, keep_transparency: bool = False
) -> BytesIO:
"""
制作静图或者动图
:params
* ``img``: 输入图片
* ``func``: 图片处理函数,输入img,返回处理后的图片
* ``keep_transparency``: 传入gif时,是否保留该gif的透明度
"""
image = img.image
if not getattr(image, "is_animated", False):
return func(img).save_jpg()
else:
frames = split_gif(image)
duration = get_avg_duration(image) / 1000
frames = [func(BuildImage(frame)).image for frame in frames]
if keep_transparency:
image.seek(0)
if image.info.__contains__("transparency"):
frames[0].info["transparency"] = image.info["transparency"]
return save_gif(frames, duration)
def make_png_or_gif(
img: BuildImage, func: Maker, keep_transparency: bool = False
) -> BytesIO:
"""
制作静图或者动图
:params
* ``img``: 输入图片
* ``func``: 图片处理函数,输入img,返回处理后的图片
* ``keep_transparency``: 传入gif时,是否保留该gif的透明度
"""
image = img.image
if not getattr(image, "is_animated", False):
return func(img).save_png()
else:
frames = split_gif(image)
duration = get_avg_duration(image) / 1000
frames = [func(BuildImage(frame)).image for frame in frames]
if keep_transparency:
image.seek(0)
if image.info.__contains__("transparency"):
frames[0].info["transparency"] = image.info["transparency"]
return save_gif(frames, duration)
class FrameAlignPolicy(Enum):
"""
要叠加的gif长度大于基准gif时,是否延长基准gif长度以对齐两个gif
"""
no_extend = 0
"""不延长"""
extend_first = 1
"""延长第一帧"""
extend_last = 2
"""延长最后一帧"""
extend_loop = 3
"""以循环方式延长"""
def make_gif_or_combined_gif(
img: BuildImage,
maker: GifMaker,
frame_num: int,
duration: float,
frame_align: FrameAlignPolicy = FrameAlignPolicy.no_extend,
input_based: bool = False,
keep_transparency: bool = False,
) -> BytesIO:
"""
使用静图或动图制作gif
:params
* ``img``: 输入图片,如头像
* ``maker``: 图片处理函数生成,传入第几帧,返回对应的图片处理函数
* ``frame_num``: 目标gif的帧数
* ``duration``: 相邻帧之间的时间间隔,单位为秒
* ``frame_align``: 要叠加的gif长度大于基准gif时,gif长度对齐方式
* ``input_based``: 是否以输入gif为基准合成gif,默认为`False`,即以目标gif为基准
* ``keep_transparency``: 传入gif时,是否保留该gif的透明度
"""
image = img.image
if not getattr(image, "is_animated", False):
return save_gif([maker(i)(img).image for i in range(frame_num)], duration)
frame_num_in = image.n_frames
duration_in = get_avg_duration(image) / 1000
total_duration_in = frame_num_in * duration_in
total_duration = frame_num * duration
if input_based:
frame_num_base = frame_num_in
frame_num_fit = frame_num
duration_base = duration_in
duration_fit = duration
total_duration_base = total_duration_in
total_duration_fit = total_duration
else:
frame_num_base = frame_num
frame_num_fit = frame_num_in
duration_base = duration
duration_fit = duration_in
total_duration_base = total_duration
total_duration_fit = total_duration_in
frame_idxs: List[int] = list(range(frame_num_base))
diff_duration = total_duration_fit - total_duration_base
diff_num = int(diff_duration / duration_base)
if diff_duration >= duration_base:
if frame_align == FrameAlignPolicy.extend_first:
frame_idxs = [0] * diff_num + frame_idxs
elif frame_align == FrameAlignPolicy.extend_last:
frame_idxs += [frame_num_base - 1] * diff_num
elif frame_align == FrameAlignPolicy.extend_loop:
frame_num_total = frame_num_base
# 重复基准gif,直到两个gif总时长之差在1个间隔以内,或总帧数超出最大帧数
while frame_num_total + frame_num_base <= meme_config.gif.gif_max_frames:
frame_num_total += frame_num_base
frame_idxs += list(range(frame_num_base))
multiple = round(frame_num_total * duration_base / total_duration_fit)
if (
math.fabs(
total_duration_fit * multiple - frame_num_total * duration_base
)
<= duration_base
):
break
frames: List[IMG] = []
frame_idx_fit = 0
time_start = 0
for i, idx in enumerate(frame_idxs):
while frame_idx_fit < frame_num_fit:
if (
frame_idx_fit * duration_fit
<= i * duration_base - time_start
< (frame_idx_fit + 1) * duration_fit
):
if input_based:
idx_in = idx
idx_maker = frame_idx_fit
else:
idx_in = frame_idx_fit
idx_maker = idx
func = maker(idx_maker)
image.seek(idx_in)
frames.append(func(BuildImage(image.copy())).image)
break
else:
frame_idx_fit += 1
if frame_idx_fit >= frame_num_fit:
frame_idx_fit = 0
time_start += total_duration_fit
if keep_transparency:
image.seek(0)
if image.info.__contains__("transparency"):
frames[0].info["transparency"] = image.info["transparency"]
return save_gif(frames, duration)
async def translate(text: str, lang_from: str = "auto", lang_to: str = "zh") -> str:
appid = meme_config.translate.baidu_trans_appid
apikey = meme_config.translate.baidu_trans_apikey
if not appid or not apikey:
raise MemeGeneratorException(
"The `baidu_trans_appid` or `baidu_trans_apikey` is not set."
"Please check your config file!"
)
salt = str(round(time.time() * 1000))
sign_raw = appid + text + salt + apikey
sign = hashlib.md5(sign_raw.encode("utf8")).hexdigest()
params = {
"q": text,
"from": lang_from,
"to": lang_to,
"appid": appid,
"salt": salt,
"sign": sign,
}
url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
async with httpx.AsyncClient() as client:
resp = await client.get(url, params=params)
result = resp.json()
return result["trans_result"][0]["dst"]
async def translate_microsoft(text: str, lang_from: str = "zh-CN", lang_to: str = "ja") -> str:
if lang_to == 'jp':
lang_to = 'ja'
params = {
"text": text,
"toLang": lang_to,
}
url = "http://translate.ikechan8370.com/translate"
async with httpx.AsyncClient() as client:
resp = await client.get(url, params=params)
result = resp.json()
return result["translation"]["translation"]
def random_text() -> str:
return random.choice(["刘一", "陈二", "张三", "李四", "王五", "赵六", "孙七", "周八", "吴九", "郑十"])
def random_image() -> BytesIO:
text = random.choice(["😂", "😅", "🤗", "🤤", "🥵", "🥰", "😍", "😭", "😋", "😏"])
return (
BuildImage.new("RGBA", (500, 500), "white")
.draw_text((0, 0, 500, 500), text, max_fontsize=400)
.save_png()
)
@dataclass
class TextProperties:
fill: ColorType = "black"
style: FontStyle = "normal"
weight: FontWeight = "normal"
stroke_width: int = 0
stroke_fill: Optional[ColorType] = None
def default_template(meme: "Meme", number: int) -> str:
return f"{number}. {'/'.join(meme.keywords)}"
def render_meme_list(
meme_list: List[Tuple["Meme", TextProperties]],
*,
template: Callable[["Meme", int], str] = default_template,
order_direction: Literal["row", "column"] = "column",
columns: int = 4,
column_align: Literal["left", "center", "right"] = "left",
item_padding: Tuple[int, int] = (15, 6),
image_padding: Tuple[int, int] = (50, 50),
bg_color: ColorType = "white",
fontsize: int = 30,
fontname: str = "",
fallback_fonts: List[str] = [],
) -> BytesIO:
item_images: List[Text2Image] = []
for i, (meme, properties) in enumerate(meme_list, start=1):
text = template(meme, i)
t2m = Text2Image.from_text(
text,
fontsize=fontsize,
style=properties.style,
weight=properties.weight,
fill=properties.fill,
stroke_width=properties.stroke_width,
stroke_fill=properties.stroke_fill,
fontname=fontname,
fallback_fonts=fallback_fonts,
)
item_images.append(t2m)
char_A = (
Text2Image.from_text(
"A", fontsize=fontsize, fontname=fontname, fallback_fonts=fallback_fonts
)
.lines[0]
.chars[0]
)
num_per_col = math.ceil(len(item_images) / columns)
column_images: List[BuildImage] = []
for col in range(columns):
if order_direction == "column":
images = item_images[col * num_per_col : (col + 1) * num_per_col]
else:
images = [
item_images[num * columns + col]
for num in range((len(item_images) - col - 1) // columns + 1)
]
img_w = max((t2m.width for t2m in images)) + item_padding[0] * 2
img_h = (char_A.ascent + item_padding[1] * 2) * len(images) + char_A.descent
image = BuildImage.new("RGB", (img_w, img_h), bg_color)
y = item_padding[1]
for t2m in images:
if column_align == "left":
x = 0
elif column_align == "center":
x = (img_w - t2m.width - item_padding[0] * 2) // 2
else:
x = img_w - t2m.width - item_padding[0] * 2
t2m.draw_on_image(image.image, (x, y))
y += char_A.ascent + item_padding[1] * 2
column_images.append(image)
img_w = sum((img.width for img in column_images)) + image_padding[0] * 2
img_h = max((img.height for img in column_images)) + image_padding[1] * 2
image = BuildImage.new("RGB", (img_w, img_h), bg_color)
x, y = image_padding
for img in column_images:
image.paste(img, (x, y))
x += img.width
return image.save_jpg()