Spaces:
xiao7710
/
Runtime error

ming-v5 / main.py
xiao7710's picture
Upload 2 files
f19100a
raw
history blame
3.32 kB
from __future__ import annotations
import functools
import numpy as np
import onnxruntime as rt
import pandas as pd
import PIL.Image
import requests
import base64
from Utils import dbimutils
from fastapi import FastAPI, File, UploadFile
from io import BytesIO
from Crypto.Cipher import AES
# FastAPI application object; the routes below are registered on it via @app.get.
app = FastAPI()
def load_model(
    model_path: str = "./conv2/model.onnx",
    tags_path: str = "./conv2/selected_tags.csv",
    providers: list[str] | None = None,
):
    """Load the ONNX tagger and its tag table, and return a ready-to-call predictor.

    Args:
        model_path: Path of the ONNX model file. Defaults to the original
            hard-coded location.
        tags_path: Path of the tag CSV; it must contain ``name`` and
            ``category`` columns. Defaults to the original hard-coded location.
        providers: Optional ONNX Runtime execution providers, forwarded to
            ``rt.InferenceSession``. ``None`` (the default) reproduces the
            original call exactly.

    Returns:
        ``predict`` partially applied with the session, every tag name, and the
        row positions whose ``category`` equals 0, so callers only pass an image.
    """
    session = rt.InferenceSession(model_path, providers=providers)
    df = pd.read_csv(tags_path)
    tag_names = df["name"].tolist()
    # Only rows whose category is 0 are reported by predict.
    general_indexes = list(np.where(df["category"] == 0)[0])
    return functools.partial(
        predict,
        tag_names=tag_names,
        general_indexes=general_indexes,
        models=session,
    )
def string_aes_decode_v1(text, key_text):
    """Decrypt a base64 AES-CFB ciphertext and return the plaintext as a str.

    Args:
        text: Base64-encoded ciphertext (a str).
        key_text: Base64-encoded raw AES key.

    Returns:
        The decrypted bytes with trailing NUL bytes removed, decoded as UTF-8.

    Notes:
        The cipher runs in CFB mode with 128-bit segments and a fixed IV
        (bytes 0x00 through 0x0f). The encrypting side must use the same
        parameters.
        NOTE(review): a constant IV makes equal plaintexts encrypt to equal
        ciphertexts; it is kept only for compatibility with existing clients.
    """
    raw_key = base64.b64decode(key_text)
    encrypted = base64.b64decode(text.encode('utf-8'))
    fixed_iv = bytes(range(16))  # 00 01 02 ... 0f
    decryptor = AES.new(raw_key, AES.MODE_CFB, iv=fixed_iv, segment_size=128)
    decrypted = decryptor.decrypt(encrypted)
    return decrypted.rstrip(b'\0').decode('utf-8')
def predict(
    image: PIL.Image.Image,
    models: rt.InferenceSession,
    tag_names: list[str],
    general_indexes: list[np.int64],
    *,
    general_threshold: float = 0.35,
) -> str:
    """Run the ONNX tagger on *image* and return the matching tags as one string.

    Args:
        image: Input picture in any PIL mode. Transparent areas are composited
            onto a white background before inference.
        models: ONNX Runtime session. The second dimension of its first input's
            shape gives the square side length the image is resized to, and its
            first output is read as per-tag scores.
        tag_names: Tag names, paired positionally with the model's output scores.
        general_indexes: Positions (into ``tag_names``) of the tags eligible to
            be reported.
        general_threshold: A tag is kept only if its score is strictly greater
            than this value. Defaults to 0.35, the previously hard-coded value.

    Returns:
        The kept tags ordered by descending score (ties keep their original
        order), joined with ``", "``. Underscores are replaced by spaces and
        parentheses are backslash-escaped. Returns an empty string when no tag
        passes the threshold.
    """
    # The input is fed as (batch, height, width, channels); only the height is
    # needed because the image is squared before resizing.
    _, height, _, _ = models.get_inputs()[0].shape

    # Flatten any transparency onto white, then drop the alpha channel.
    rgba = image.convert("RGBA")
    canvas = PIL.Image.new("RGBA", rgba.size, "WHITE")
    canvas.paste(rgba, mask=rgba)
    pixels = np.asarray(canvas.convert("RGB"))
    # Reverse the channel axis (RGB -> BGR), which is the order this model is fed.
    pixels = pixels[:, :, ::-1]
    pixels = dbimutils.make_square(pixels, height)
    pixels = dbimutils.smart_resize(pixels, height)
    batch = np.expand_dims(pixels.astype(np.float32), 0)

    input_name = models.get_inputs()[0].name
    output_name = models.get_outputs()[0].name
    probs = models.run([output_name], {input_name: batch})[0]

    labels = list(zip(tag_names, probs[0].astype(float)))
    kept = dict(
        labels[i] for i in general_indexes if labels[i][1] > general_threshold
    )
    ranked = sorted(kept.items(), key=lambda item: item[1], reverse=True)
    joined = ", ".join(name for name, _ in ranked)
    # Raw strings: "\(" is an invalid escape sequence in a normal literal
    # (SyntaxWarning on 3.12+); r"\(" yields the same two characters.
    return joined.replace("_", " ").replace("(", r"\(").replace(")", r"\)")
# Build the prediction callable once, at import time: this loads the ONNX model
# and the tag CSV (see load_model). It must stay after the definition of
# `predict`, which load_model binds via functools.partial.
predict_func = load_model()
@app.get("/")
async def read_root():
    """Root route: reply with a fixed informational string."""
    banner = "from https://replicate.com/"
    return banner
@app.get("/v1")
async def predict_endpoint(image_url: str):
    """Decrypt an image reference, download the image, and return its tags.

    ``image_url`` is not a plain URL. It is a base64 AES ciphertext that
    ``string_aes_decode_v1`` decrypts into a suffix appended to a fixed base URL
    (stored base64-encoded below). The response is always
    ``{"code": ..., "result": ...}``:

    * 1001 - the reference could not be decoded/decrypted;
    * 1002 - the download failed (network error or non-200 status);
    * 1003 - image decoding or tagging failed, or produced an empty result;
    * 200  - ``result`` holds the comma-separated tag string.

    NOTE(review): the AES key is embedded in this source, so the encryption only
    obscures the path; it does not authenticate callers.
    """
    import asyncio
    import logging

    logger = logging.getLogger(__name__)
    # (connect, read) timeouts in seconds; without them a stalled server would
    # hang this request indefinitely.
    download_timeout = (10, 60)

    # decode
    try:
        a = base64.b64decode("aHR0cDovL2ltZzJ0eHQtY2RuLnNuc3hpb25nLmNvbS8=").decode('utf-8')
        b = string_aes_decode_v1(image_url, "S0tnaDFRODc2UDg3VVVtVnpaWHAwN3Rr")
        image_url = a + b
    except Exception:
        image_url = ""
    if image_url == "":
        return {"code":1001, "result": "image error, see more: https://replicate.com/docs"}

    # download: requests is blocking, so run it in the default executor to keep
    # the event loop free; network errors map to the same 1002 response as a
    # bad status instead of escaping as a server error.
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None, functools.partial(requests.get, image_url, timeout=download_timeout)
        )
    except requests.RequestException:
        logger.warning("image download failed: %s", image_url, exc_info=True)
        return {"code":1002, "result": "download error, see more: https://replicate.com/docs"}
    if response.status_code != 200:
        return {"code":1002, "result": "download error, see more: https://replicate.com/docs"}

    # predict
    try:
        image = PIL.Image.open(BytesIO(response.content))
        # Keep alpha when present so predict can composite transparent areas
        # onto white; a plain convert("RGB") would discard the alpha first.
        has_alpha = image.mode in ("RGBA", "LA") or (
            image.mode in ("P", "L", "RGB") and "transparency" in image.info
        )
        image = image.convert("RGBA" if has_alpha else "RGB")
        result = predict_func(image)
    except Exception:
        logger.exception("prediction failed for %s", image_url)
        result = ""
    # NOTE(review): an image with no tag above the threshold also yields "" and
    # is reported as 1003 here; confirm that is the intended client contract.
    if result == "":
        return {"code":1003, "result": "predict error, see more: https://replicate.com/docs"}
    return {"code":200, "result": result}