Spaces:
Runtime error
Runtime error
from __future__ import annotations

import base64
import functools
import logging
from io import BytesIO

import numpy as np
import onnxruntime as rt
import pandas as pd
import PIL.Image
import requests
from Crypto.Cipher import AES
from fastapi import FastAPI, File, UploadFile

from Utils import dbimutils
# Module-level FastAPI application instance.
# NOTE(review): no route decorators are visible on read_root / predict_endpoint
# in this listing -- confirm how those handlers are registered on `app`.
app = FastAPI()
def load_model(model_dir: str = "./conv2"):
    """Load the ONNX tagger and its tag list, returning a ready-to-call predictor.

    Args:
        model_dir: Directory holding ``model.onnx`` and ``selected_tags.csv``.
            Defaults to ``./conv2``, which is resolved against the current
            working directory, as before.

    Returns:
        ``predict`` with ``models``, ``tag_names`` and ``general_indexes``
        bound via ``functools.partial``, so callers only pass the PIL image.
    """
    session = rt.InferenceSession(f"{model_dir}/model.onnx")
    tags = pd.read_csv(f"{model_dir}/selected_tags.csv")
    tag_names = tags["name"].tolist()
    # Row positions whose "category" column is 0; predict() only reports these.
    general_indexes = list(np.where(tags["category"] == 0)[0])
    return functools.partial(
        predict,
        tag_names=tag_names,
        general_indexes=general_indexes,
        models=session,
    )
def string_aes_decode_v1(text, key_text):
    """Decrypt a base64 AES-CFB ciphertext and return the plaintext as a str.

    Args:
        text: Base64-encoded ciphertext.
        key_text: Base64-encoded AES key.

    Returns:
        The UTF-8 decoded plaintext with trailing NUL bytes removed.

    Raises:
        ValueError (including binascii.Error / UnicodeDecodeError) on malformed
        base64, a bad key length, or non-UTF-8 plaintext.
    """
    raw_key = base64.b64decode(key_text)
    encrypted = base64.b64decode(text.encode('utf-8'))
    # Fixed IV 00 01 02 ... 0f; must match whatever produced the ciphertext.
    # NOTE(review): a constant IV reused with one key weakens CFB -- only
    # changeable together with the encoder side.
    decryptor = AES.new(raw_key, AES.MODE_CFB, iv=bytes(range(16)), segment_size=128)
    # Trailing NULs are stripped, presumably zero padding added by the encoder.
    return decryptor.decrypt(encrypted).rstrip(b'\0').decode('utf-8')
def _prepare_input(image: PIL.Image.Image, size: int) -> np.ndarray:
    """Convert a PIL image into a float32 BGR array with a leading batch axis of 1.

    Transparent pixels are composited onto a white background first.
    """
    rgba = image.convert("RGBA")
    canvas = PIL.Image.new("RGBA", rgba.size, "WHITE")
    canvas.paste(rgba, mask=rgba)
    pixels = np.asarray(canvas.convert("RGB"))
    # Reverse the channel axis: RGB -> BGR.
    pixels = pixels[:, :, ::-1]
    # Presumably pads to a square and resizes to `size` -- see Utils.dbimutils.
    pixels = dbimutils.make_square(pixels, size)
    pixels = dbimutils.smart_resize(pixels, size)
    return np.expand_dims(pixels.astype(np.float32), 0)


def _format_tags(scored_tags: dict[str, float]) -> str:
    """Join tag names by descending score, with spaces for '_' and escaped parens."""
    ordered = sorted(scored_tags.items(), key=lambda item: item[1], reverse=True)
    joined = ", ".join(name for name, _ in ordered)
    # Raw strings: "\(" in a normal literal is an invalid escape sequence.
    return joined.replace("_", " ").replace("(", r"\(").replace(")", r"\)")


def predict(
    image: PIL.Image.Image,
    models: rt.InferenceSession,
    tag_names: list[str],
    general_indexes: list[np.int64],
    *,
    general_threshold: float = 0.35,
) -> str:
    """Tag *image* and return the matching tags as one comma-separated string.

    Args:
        image: Input image in any mode PIL can convert to RGBA.
        models: ONNX session; only its first input and first output are used.
        tag_names: Tag name per model output position.
        general_indexes: Output positions eligible to be reported.
        general_threshold: Scores must be strictly greater than this value.

    Returns:
        Tags ordered by descending score, joined with ", ", with underscores
        replaced by spaces and parentheses backslash-escaped. This is "" when
        no eligible tag exceeds the threshold.
    """
    model_input = models.get_inputs()[0]
    # Only the height is needed; the image is resized to a square of that size.
    _, height, _, _ = model_input.shape
    batch = _prepare_input(image, height)

    output_name = models.get_outputs()[0].name
    probs = models.run([output_name], {model_input.name: batch})[0]
    labels = list(zip(tag_names, probs[0].astype(float)))

    selected = {
        name: score
        for name, score in (labels[i] for i in general_indexes)
        if score > general_threshold
    }
    return _format_tags(selected)
# Load the ONNX model and tag list once, at import time. The resulting callable
# takes a PIL image and returns the formatted tag string (see predict).
predict_func = load_model()
async def read_root():
    """Return a fixed informational string."""
    message = "from https://replicate.com/"
    return message
# (connect, read) timeout in seconds for fetching the source image. requests
# applies no timeout by default, so an unresponsive host would otherwise keep
# the request open indefinitely.
_DOWNLOAD_TIMEOUT = (10, 30)

logger = logging.getLogger(__name__)


def _has_transparency(image: PIL.Image.Image) -> bool:
    """Return True when *image* carries an alpha channel or palette/colour-key transparency."""
    if image.mode in ("RGBA", "LA", "PA"):
        return True
    return image.mode in ("P", "L", "RGB") and "transparency" in image.info


async def predict_endpoint(image_url: str):
    """Decrypt an image reference, download the image and return its tags.

    *image_url* is not a full URL: it is a base64 AES ciphertext whose
    plaintext is appended to a fixed (base64-stored) base URL.

    Returns a dict ``{"code": ..., "result": ...}``:
        200  -- ``result`` is the tag string from ``predict_func``;
        1001 -- the reference could not be decrypted, or decrypted to nothing;
        1002 -- the download failed or returned a non-200 status;
        1003 -- the image could not be opened or tagged.

    NOTE(review): ``requests.get`` and ``predict_func`` are synchronous and
    block the event loop while they run; consider offloading them (e.g.
    ``loop.run_in_executor``) or declaring this handler with plain ``def``.
    """
    # decode
    try:
        base_url = base64.b64decode("aHR0cDovL2ltZzJ0eHQtY2RuLnNuc3hpb25nLmNvbS8=").decode('utf-8')
        # NOTE(review): the AES key is hard-coded; consider loading it from config.
        path = string_aes_decode_v1(image_url, "S0tnaDFRODc2UDg3VVVtVnpaWHAwN3Rr")
        # An empty plaintext would leave only the bare base URL, which is not an image.
        image_url = base_url + path if path else ""
    except Exception:
        image_url = ""
    if image_url == "":
        return {"code":1001, "result": "image error, see more: https://replicate.com/docs"}

    # download -- connection errors and timeouts map to 1002 rather than a 500.
    try:
        response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT)
    except requests.RequestException:
        logger.warning("image download failed", exc_info=True)
        return {"code":1002, "result": "download error, see more: https://replicate.com/docs"}
    if response.status_code != 200:
        return {"code":1002, "result": "download error, see more: https://replicate.com/docs"}

    # predict
    try:
        image = PIL.Image.open(BytesIO(response.content))
        # Keep alpha when present so predict() can composite it onto white;
        # converting straight to RGB would discard it first.
        image = image.convert("RGBA" if _has_transparency(image) else "RGB")
        result = predict_func(image)
    except Exception:
        logger.exception("image decoding or prediction failed")
        result = ""
    # NOTE(review): predict() also returns "" when no tag clears its threshold,
    # so an image with no confident tags is reported as 1003 as well.
    if result == "":
        return {"code":1003, "result": "predict error, see more: https://replicate.com/docs"}
    return {"code":200, "result": result}