Update app.py
Browse files
app.py
CHANGED
@@ -43,13 +43,13 @@ os.system("wget https://github.com/onnx/models/raw/main/vision/classification/zf
|
|
43 |
|
44 |
ort_session = ort.InferenceSession("zfnet512-12.onnx")
|
45 |
|
46 |
-
|
47 |
def predict(path):
|
48 |
img_batch = preprocess(get_image(path))
|
49 |
|
50 |
outputs = ort_session.run(
|
51 |
None,
|
52 |
-
{"data_0": img_batch.astype(np.float32)},
|
53 |
)
|
54 |
|
55 |
a = np.argsort(-outputs[0].flatten())
|
|
|
43 |
|
44 |
ort_session = ort.InferenceSession("zfnet512-12.onnx")
|
45 |
|
46 |
+
|
47 |
def predict(path):
|
48 |
img_batch = preprocess(get_image(path))
|
49 |
|
50 |
outputs = ort_session.run(
|
51 |
None,
|
52 |
+
{"gpu_0/data_0": img_batch.astype(np.float32)},
|
53 |
)
|
54 |
|
55 |
a = np.argsort(-outputs[0].flatten())
|