|
import pytest |
|
from app import app |
|
from helper_functions import predict_class, transform_list_of_texts, prepare_text, inference |
|
import torch |
|
from transformers import DistilBertForSequenceClassification, AutoTokenizer |
|
|
|
@pytest.fixture
def client():
    """Provide a Flask test client for ``app`` with testing mode enabled.

    ``TESTING`` is written into the shared ``app.config`` before the client
    is created and is not restored afterwards.

    Yields:
        The test client returned by ``app.test_client()``, kept open for the
        duration of the test that requests this fixture.
    """
    app.config["TESTING"] = True
    with app.test_client() as flask_client:
        yield flask_client
|
|
|
|
|
|
|
def test_predict_class():
    """Check the structure of what ``predict_class`` returns.

    The model comes from the base ``distilbert-base-uncased`` checkpoint,
    which is not fine-tuned for this task, so only the types and the number
    of class probabilities are checked, not the predicted label itself.
    """
    # The assertion below expects 17 class probabilities. A transformers
    # config defaults to num_labels=2, so the base checkpoint would otherwise
    # produce a 2-way classification head. Size the head explicitly so the
    # model's output width agrees with the asserted class count.
    # NOTE(review): assumes predict_class derives its probabilities from the
    # model's logits; confirm against helper_functions.predict_class.
    num_classes = 17
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=num_classes
    )
    text = ["This is a sample text for testing."]

    predicted_class, class_probabilities = predict_class(text, model)

    assert isinstance(predicted_class, tuple)
    assert isinstance(class_probabilities, dict)
    assert len(class_probabilities) == num_classes
|
|
|
def test_transform_list_of_texts():
    """Check that ``transform_list_of_texts`` returns a dict holding
    ``input_ids`` and ``attention_mask`` entries for a small batch of texts.
    """
    sample_texts = ["This is a sample text.", "Another sample text."]
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # NOTE(review): the meaning of the numeric arguments (510, 510, 1, 2550)
    # is defined by helper_functions.transform_list_of_texts; confirm there.
    encoded = transform_list_of_texts(sample_texts, tokenizer, 510, 510, 1, 2550)

    assert isinstance(encoded, dict)
    for required_key in ("input_ids", "attention_mask"):
        assert required_key in encoded
|
|
|
|
|
|
|
def test_pdf_upload(client):
    """Upload ``sample.pdf`` as multipart form data to ``/pdf/upload``.

    The request must succeed with HTTP 200, and the response body must
    mention ``class_probabilities``.
    """
    # The path is resolved relative to the current working directory, so
    # pytest has to be run from the directory that contains sample.pdf.
    with open("sample.pdf", "rb") as pdf_handle:
        response = client.post(
            "/pdf/upload",
            data={"file": (pdf_handle, "sample.pdf")},
            content_type="multipart/form-data",
        )

    assert response.status_code == 200
    assert b"class_probabilities" in response.data
|
|
|
def test_sentence_endpoint(client):
    """POST a single sentence as form field ``text`` to ``/sentence``.

    The request must succeed with HTTP 200, and the response body must
    mention ``predicted_class``.
    """
    response = client.post(
        "/sentence",
        data={"text": "This is a sample sentence for testing."},
    )

    assert response.status_code == 200
    assert b"predicted_class" in response.data
|
|
|
def test_voice_endpoint(client):
    """Upload ``sample_audio.wav`` as multipart form field ``audio`` to ``/voice``.

    The request must succeed with HTTP 200, and the response body must
    mention ``extracted_text``.
    """
    # The path is resolved relative to the current working directory, so
    # pytest has to be run from the directory that contains sample_audio.wav.
    with open("sample_audio.wav", "rb") as wav_handle:
        response = client.post(
            "/voice",
            data={"audio": (wav_handle, "sample_audio.wav")},
            content_type="multipart/form-data",
        )

    assert response.status_code == 200
    assert b"extracted_text" in response.data