Spaces:
xiao7710
/
Runtime error

xiao7710 commited on
Commit
f19100a
1 Parent(s): 0b2d20c

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +38 -6
  2. requirements.txt +2 -1
main.py CHANGED
@@ -4,8 +4,12 @@ import numpy as np
4
  import onnxruntime as rt
5
  import pandas as pd
6
  import PIL.Image
 
 
7
  from Utils import dbimutils
8
  from fastapi import FastAPI, File, UploadFile
 
 
9
 
10
  app = FastAPI()
11
 
@@ -18,6 +22,15 @@ def load_model():
18
  general_indexes = list(np.where(df["category"] == 0)[0])
19
  return functools.partial(predict, tag_names=tag_names, general_indexes=general_indexes,models=loaded_models)
20
 
 
 
 
 
 
 
 
 
 
21
  def predict(image: PIL.Image.Image, models: rt.InferenceSession, tag_names: list[str], general_indexes: list[np.int64]):
22
  general_threshold=0.35
23
  _, height, width, _ = models.get_inputs()[0].shape
@@ -46,10 +59,29 @@ predict_func = load_model()
46
 
47
  @app.get("/")
48
  async def read_root():
49
- return "Hello, world!"
50
 
51
- @app.post("/predict")
52
- async def predict_endpoint(file: UploadFile = File(...)):
53
- image = PIL.Image.open(file.file).convert("RGB")
54
- result = predict_func(image)
55
- return {"result": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import onnxruntime as rt
5
  import pandas as pd
6
  import PIL.Image
7
+ import requests
8
+ import base64
9
  from Utils import dbimutils
10
  from fastapi import FastAPI, File, UploadFile
11
+ from io import BytesIO
12
+ from Crypto.Cipher import AES
13
 
14
  app = FastAPI()
15
 
 
22
  general_indexes = list(np.where(df["category"] == 0)[0])
23
  return functools.partial(predict, tag_names=tag_names, general_indexes=general_indexes,models=loaded_models)
24
 
25
+ def string_aes_decode_v1(text, key_text):
26
+ key = base64.b64decode(key_text)
27
+ ciphertext = base64.b64decode(text.encode('utf-8'))
28
+ iv = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f'
29
+ cipher = AES.new(key, AES.MODE_CFB, iv=iv, segment_size=128)
30
+ plaintext = cipher.decrypt(ciphertext)
31
+ plaintext = plaintext.rstrip(b'\0')
32
+ return plaintext.decode('utf-8')
33
+
34
  def predict(image: PIL.Image.Image, models: rt.InferenceSession, tag_names: list[str], general_indexes: list[np.int64]):
35
  general_threshold=0.35
36
  _, height, width, _ = models.get_inputs()[0].shape
 
59
 
60
  @app.get("/")
61
  async def read_root():
62
+ return "from https://replicate.com/"
63
 
64
+ @app.get("/v1")
65
+ async def predict_endpoint(image_url: str):
66
+ # decode
67
+ try:
68
+ a = base64.b64decode("aHR0cDovL2ltZzJ0eHQtY2RuLnNuc3hpb25nLmNvbS8=").decode('utf-8')
69
+ b = string_aes_decode_v1(image_url, "S0tnaDFRODc2UDg3VVVtVnpaWHAwN3Rr")
70
+ image_url = a + b
71
+ except Exception:
72
+ image_url = ""
73
+ if image_url == "":
74
+ return {"code":1001, "result": "image error, see more: https://replicate.com/docs"}
75
+ # download
76
+ response = requests.get(image_url)
77
+ if response.status_code != 200:
78
+ return {"code":1002, "result": "download error, see more: https://replicate.com/docs"}
79
+ # predict
80
+ try:
81
+ image = PIL.Image.open(BytesIO(response.content)).convert("RGB")
82
+ result = predict_func(image)
83
+ except Exception:
84
+ result = ""
85
+ if result == "":
86
+ return {"code":1003, "result": "predict error, see more: https://replicate.com/docs"}
87
+ return {"code":200, "result": result}
requirements.txt CHANGED
@@ -7,4 +7,5 @@ fastapi==0.74.*
7
  requests==2.27.*
8
  uvicorn[standard]==0.17.*
9
  pandas
10
- python-multipart
 
 
7
  requests==2.27.*
8
  uvicorn[standard]==0.17.*
9
  pandas
10
+ python-multipart
11
+ pycryptodome