Spaces:
Sleeping
Sleeping
molokhovdmitry
commited on
Commit
·
6a63889
1
Parent(s):
b8147d3
Create test_main.py
Browse files- .github/workflows/python-app.yml +2 -0
- main.py +2 -2
- requirements.txt +2 -0
- test_main.py +27 -0
.github/workflows/python-app.yml
CHANGED
@@ -28,6 +28,8 @@ jobs:
|
|
28 |
python -m pip install --upgrade pip
|
29 |
pip install flake8 pytest
|
30 |
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
|
|
|
|
31 |
- name: Lint with flake8
|
32 |
run: |
|
33 |
# stop the build if there are Python syntax errors or undefined names
|
|
|
28 |
python -m pip install --upgrade pip
|
29 |
pip install flake8 pytest
|
30 |
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
31 |
+
# create .env file with an API Key
|
32 |
+
echo "YT_API_KEY=${{ secrets.YT_API_KEY }}" > .env
|
33 |
- name: Lint with flake8
|
34 |
run: |
|
35 |
# stop the build if there are Python syntax errors or undefined names
|
main.py
CHANGED
@@ -8,8 +8,8 @@ from models import init_emotions_model
|
|
8 |
|
9 |
class Settings(BaseSettings):
|
10 |
YT_API_KEY: str
|
11 |
-
PRED_BATCH_SIZE: int
|
12 |
-
MAX_COMMENT_SIZE: int
|
13 |
model_config = SettingsConfigDict(env_file='.env')
|
14 |
|
15 |
|
|
|
8 |
|
9 |
class Settings(BaseSettings):
|
10 |
YT_API_KEY: str
|
11 |
+
PRED_BATCH_SIZE: int = 512
|
12 |
+
MAX_COMMENT_SIZE: int = 300
|
13 |
model_config = SettingsConfigDict(env_file='.env')
|
14 |
|
15 |
|
requirements.txt
CHANGED
@@ -7,3 +7,5 @@ torchvision
|
|
7 |
torchaudio
|
8 |
transformers
|
9 |
pandas
|
|
|
|
|
|
7 |
torchaudio
|
8 |
transformers
|
9 |
pandas
|
10 |
+
pytest
|
11 |
+
httpx
|
test_main.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi.testclient import TestClient
|
2 |
+
from main import app
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
|
6 |
+
client = TestClient(app)
|
7 |
+
|
8 |
+
|
9 |
+
def test_home():
|
10 |
+
"""Test home page."""
|
11 |
+
response = client.get("/")
|
12 |
+
assert response.status_code == 200
|
13 |
+
|
14 |
+
|
15 |
+
def test_predict():
|
16 |
+
"""Test predict method on an example video."""
|
17 |
+
TEST_VIDEO_ID = "0peXnOnDgQ8"
|
18 |
+
response = client.get(
|
19 |
+
"/predict/",
|
20 |
+
params={"video_id": TEST_VIDEO_ID}
|
21 |
+
)
|
22 |
+
df = pd.read_json(response, orient='records')
|
23 |
+
|
24 |
+
# Ensure the DataFrame has the right amount of columns
|
25 |
+
assert df.shape[1] == 39
|
26 |
+
# Ensure there are no NaN values
|
27 |
+
assert df.isna().sum().sum() == 0
|