pfe_site / tests /testapp.py
YsnHdn's picture
fix: Resolving the timeout problem for long sequences
cf75fd2
import pytest
from app import app
from helper_functions import predict_class, transform_list_of_texts, prepare_text, inference
import torch
from transformers import DistilBertForSequenceClassification, AutoTokenizer
@pytest.fixture
def client():
    """Yield the application's test client with the TESTING flag enabled.

    The client is created through ``app.test_client()`` used as a context
    manager, so it is torn down once the requesting test has finished.
    """
    app.config.update({'TESTING': True})
    with app.test_client() as test_client:
        yield test_client
# Unit tests
def test_predict_class():
    """Check the return types of ``predict_class`` and the size of its probability map."""
    # This loads the actual pretrained checkpoint rather than a mock.
    # NOTE(review): no ``num_labels`` is passed to ``from_pretrained`` --
    # confirm that ``predict_class`` really yields 17 entries with this model.
    classifier = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    sample = ["This is a sample text for testing."]

    label, probabilities = predict_class(sample, classifier)

    assert isinstance(label, tuple)
    assert isinstance(probabilities, dict)
    assert len(probabilities) == 17  # Assuming 17 classes
def test_transform_list_of_texts():
    """Check that ``transform_list_of_texts`` returns a dict with the expected keys."""
    bert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    samples = ["This is a sample text.", "Another sample text."]

    # NOTE(review): the four positional integers are forwarded unchanged; their
    # meaning is defined by the helper's signature -- verify against it.
    encoded = transform_list_of_texts(samples, bert_tokenizer, 510, 510, 1, 2550)

    assert isinstance(encoded, dict)
    for expected_key in ("input_ids", "attention_mask"):
        assert expected_key in encoded
# Integration tests
def test_pdf_upload(client):
    """Upload ``sample.pdf`` to ``/pdf/upload`` and check the response.

    Expects a 200 status and a body containing ``class_probabilities``.
    """
    # Requires a sample PDF named 'sample.pdf' in the current working directory.
    with open('sample.pdf', 'rb') as pdf_handle:
        # The upload is sent while the file handle is still open.
        resp = client.post(
            '/pdf/upload',
            data={'file': (pdf_handle, 'sample.pdf')},
            content_type='multipart/form-data',
        )

    assert resp.status_code == 200
    assert b'class_probabilities' in resp.data
def test_sentence_endpoint(client):
    """Post a sentence to ``/sentence`` and check the response.

    Expects a 200 status and a body containing ``predicted_class``.
    """
    form_fields = {'text': 'This is a sample sentence for testing.'}
    resp = client.post('/sentence', data=form_fields)

    assert resp.status_code == 200
    assert b'predicted_class' in resp.data
def test_voice_endpoint(client):
    """Upload ``sample_audio.wav`` to ``/voice`` and check the response.

    Expects a 200 status and a body containing ``extracted_text``.
    """
    # Requires a sample audio file named 'sample_audio.wav' in the current
    # working directory.
    with open('sample_audio.wav', 'rb') as wav_handle:
        # The upload is sent while the file handle is still open.
        resp = client.post(
            '/voice',
            data={'audio': (wav_handle, 'sample_audio.wav')},
            content_type='multipart/form-data',
        )

    assert resp.status_code == 200
    assert b'extracted_text' in resp.data