File size: 2,032 Bytes
fdc1efd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import io
import sys
from pathlib import Path

import soundfile as sf
from fastapi.testclient import TestClient

sys.path.append(".")

from src.api.main import app  # noqa

# Fixture directory, located relative to this test module's own directory.
TEST_FILES_DIR = Path(__file__).parent.joinpath("test_files")
TEST_WAV_FILE = TEST_FILES_DIR.joinpath("test.wav")

# A single in-process client shared by every test in this module.
client = TestClient(app)


def test_health_check():
    """The /health-check endpoint answers 200 with its fixed status payload."""
    expected_body = {"status": "API is running"}
    reply = client.get("/health-check")
    assert reply.status_code == 200
    assert reply.json() == expected_body


def test_predict_valid_cut_file():
    """Re-encode the fixture WAV in memory and POST it to /predict.

    The fixture is decoded with soundfile and written back as WAV into a
    BytesIO buffer, so the upload exercises an in-memory file object rather
    than an on-disk one. Expects a 200 response whose "prediction" mapping
    has a non-None entry keyed by the uploaded filename ("test.wav").
    """
    audio_data, sample_rate = sf.read(TEST_WAV_FILE)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_data, sample_rate, format="wav")
    # sf.write leaves the cursor at the end of the buffer. Rewind it so the
    # whole payload is uploaded: some HTTP client implementations read from
    # the current position instead of seeking to the start themselves.
    wav_buffer.seek(0)

    files = {"file": ("test.wav", wav_buffer)}
    params = {"model_name": "Accuracy"}
    response = client.post("/predict", params=params, files=files)

    assert response.status_code == 200
    assert response.json()["prediction"]["test.wav"] is not None


def test_predict_valid_file():
    """Upload the on-disk fixture WAV as-is and expect a prediction for it."""
    query = {"model_name": "Accuracy"}
    with open(TEST_WAV_FILE, "rb") as wav_handle:
        result = client.post("/predict", params=query, files={"file": wav_handle})
    # The request has fully completed above; the handle is no longer needed.
    assert result.status_code == 200
    predictions = result.json()["prediction"]
    assert predictions["test.wav"] is not None


def test_predict_invalid_file_type():
    """A .txt upload is rejected with 400 and a 'wav only' message."""
    upload = {"file": ("test.txt", io.BytesIO(b"dummy txt data"))}
    reply = client.post("/predict", params={"model_name": "Accuracy"}, files=upload)
    assert reply.status_code == 400
    detail = reply.json()["detail"]
    assert "Only wav files are supported" in detail


def test_predict_invalid_model():
    """Requesting an unknown model name yields 400 with a 'doesn't exist' message.

    The payload is not real audio. NOTE(review): this presumably relies on
    the endpoint validating model_name before decoding the file — confirm.
    """
    upload = {"file": ("test.wav", io.BytesIO(b"dummy wav data"))}
    reply = client.post(
        "/predict", params={"model_name": "InvalidModel"}, files=upload
    )
    assert reply.status_code == 400
    detail = reply.json()["detail"]
    assert "Selected model doesn't exist" in detail