Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- main.py +38 -6
- requirements.txt +2 -1
main.py
CHANGED
@@ -4,8 +4,12 @@ import numpy as np
|
|
4 |
import onnxruntime as rt
|
5 |
import pandas as pd
|
6 |
import PIL.Image
|
|
|
|
|
7 |
from Utils import dbimutils
|
8 |
from fastapi import FastAPI, File, UploadFile
|
|
|
|
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
@@ -18,6 +22,15 @@ def load_model():
|
|
18 |
general_indexes = list(np.where(df["category"] == 0)[0])
|
19 |
return functools.partial(predict, tag_names=tag_names, general_indexes=general_indexes,models=loaded_models)
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def predict(image: PIL.Image.Image, models: rt.InferenceSession, tag_names: list[str], general_indexes: list[np.int64]):
|
22 |
general_threshold=0.35
|
23 |
_, height, width, _ = models.get_inputs()[0].shape
|
@@ -46,10 +59,29 @@ predict_func = load_model()
|
|
46 |
|
47 |
@app.get("/")
|
48 |
async def read_root():
|
49 |
-
return "
|
50 |
|
51 |
-
@app.
|
52 |
-
async def predict_endpoint(
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import onnxruntime as rt
|
5 |
import pandas as pd
|
6 |
import PIL.Image
|
7 |
+
import requests
|
8 |
+
import base64
|
9 |
from Utils import dbimutils
|
10 |
from fastapi import FastAPI, File, UploadFile
|
11 |
+
from io import BytesIO
|
12 |
+
from Crypto.Cipher import AES
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
|
|
22 |
general_indexes = list(np.where(df["category"] == 0)[0])
|
23 |
return functools.partial(predict, tag_names=tag_names, general_indexes=general_indexes,models=loaded_models)
|
24 |
|
25 |
+
def string_aes_decode_v1(text, key_text):
|
26 |
+
key = base64.b64decode(key_text)
|
27 |
+
ciphertext = base64.b64decode(text.encode('utf-8'))
|
28 |
+
iv = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f'
|
29 |
+
cipher = AES.new(key, AES.MODE_CFB, iv=iv, segment_size=128)
|
30 |
+
plaintext = cipher.decrypt(ciphertext)
|
31 |
+
plaintext = plaintext.rstrip(b'\0')
|
32 |
+
return plaintext.decode('utf-8')
|
33 |
+
|
34 |
def predict(image: PIL.Image.Image, models: rt.InferenceSession, tag_names: list[str], general_indexes: list[np.int64]):
|
35 |
general_threshold=0.35
|
36 |
_, height, width, _ = models.get_inputs()[0].shape
|
|
|
59 |
|
60 |
@app.get("/")
|
61 |
async def read_root():
|
62 |
+
return "from https://replicate.com/"
|
63 |
|
64 |
+
@app.get("/v1")
|
65 |
+
async def predict_endpoint(image_url: str):
|
66 |
+
# decode
|
67 |
+
try:
|
68 |
+
a = base64.b64decode("aHR0cDovL2ltZzJ0eHQtY2RuLnNuc3hpb25nLmNvbS8=").decode('utf-8')
|
69 |
+
b = string_aes_decode_v1(image_url, "S0tnaDFRODc2UDg3VVVtVnpaWHAwN3Rr")
|
70 |
+
image_url = a + b
|
71 |
+
except Exception:
|
72 |
+
image_url = ""
|
73 |
+
if image_url == "":
|
74 |
+
return {"code":1001, "result": "image error, see more: https://replicate.com/docs"}
|
75 |
+
# download
|
76 |
+
response = requests.get(image_url)
|
77 |
+
if response.status_code != 200:
|
78 |
+
return {"code":1002, "result": "download error, see more: https://replicate.com/docs"}
|
79 |
+
# predict
|
80 |
+
try:
|
81 |
+
image = PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
82 |
+
result = predict_func(image)
|
83 |
+
except Exception:
|
84 |
+
result = ""
|
85 |
+
if result == "":
|
86 |
+
return {"code":1003, "result": "predict error, see more: https://replicate.com/docs"}
|
87 |
+
return {"code":200, "result": result}
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ fastapi==0.74.*
|
|
7 |
requests==2.27.*
|
8 |
uvicorn[standard]==0.17.*
|
9 |
pandas
|
10 |
-
python-multipart
|
|
|
|
7 |
requests==2.27.*
|
8 |
uvicorn[standard]==0.17.*
|
9 |
pandas
|
10 |
+
python-multipart
|
11 |
+
pycryptodome
|