RTMDet_PRODUCTION / tests /test_image.py
NMPhap
<Test>: Fix test]
1654acc
raw
history blame
No virus
2.9 kB
import os
import mmcv
import numpy as np
from app.routers.image import inferenceImage
import pytest
import pytest
from fastapi.testclient import TestClient
from fastapi.routing import APIRoute
from app.main import app
def endpoints():
endpoints = []
for route in app.routes:
if isinstance(route, APIRoute):
endpoints.append(route.path)
return endpoints
@pytest.fixture
def client():
client = TestClient(app, "http://0.0.0.0:3000")
yield client
@pytest.mark.skipif("/image" not in endpoints(), reason="Route not defined")
class TestImageRoute():
img = mmcv.imread('demo.jpg')
url = "http://0.0.0.0:3000/image"
def test_inferenceImage(self):
bboxes, labels = inferenceImage(mmcv.imread('demo.jpg'), 0.3, True)
assert len(bboxes.tolist()) > 0 and len(labels.tolist()) > 0 and len(bboxes.tolist()) == len(labels.tolist())
result = inferenceImage(self.img, 0.3, False)
assert type(result) is np.ndarray and result.shape == self.img.shape
def test_ImageAPI(self, client):
payload = {}
files=[
('file',('demo.jpg',open('demo.jpg','rb'),'image/jpeg'))
]
headers = {
'accept': 'application/json'
}
response = client.request("POST", "image", headers=headers, data=payload, files=files)
result = mmcv.imfrombytes(response.read())
assert response.status_code == 200 and result.shape == self.img.shape
def test_ImageAPI_one_channel_array(self, client):
np.zeros((1,640,640)).dump("one_channel.jpg")
payload = {}
files=[
('file', ("demo.jpg",open("one_channel.jpg", "rb"),'image/jpeg'))
]
headers = {
'accept': 'application/json'
}
response = client.request("POST", "image", headers=headers, data=payload, files=files)
assert response.status_code != 200
def test_ImageAPIWithThresHold(self, client):
payload = {}
files=[
('file',('demo.jpg',open('demo.jpg','rb'),'image/jpeg'))
]
headers = {
'accept': 'application/json'
}
response = client.request("POST", "image?threshold=1&raw=True", headers=headers, data=payload, files=files)
thresHold = 0.4
assert response.status_code == 200 # The result with threshold equal 0 is 0
# No detected object has 100% accuracy
assert len(response.json()["labels"]) == 0
payload = {}
files=[
('file',('demo.jpg',open('demo.jpg','rb'),'image/jpeg'))
]
headers = {
'accept': 'application/json'
}
response = client.request("POST", "image?threshold=" + str(thresHold) + "&raw=True", headers=headers, data=payload, files=files)
assert response.status_code == 200
for bbox in response.json()['bboxes']:
assert bbox[4] >= thresHold