Spaces:
Runtime error
Runtime error
| import datetime as dt | |
| import os | |
| from types import SimpleNamespace | |
| import pytest | |
| from fastapi.testclient import TestClient | |
def is_roughly_now(datetime_str, max_seconds=3):
    """Check whether an ISO-8601 datetime string lies within ``max_seconds`` of now.

    Naive timestamps are compared against the current UTC wall-clock time
    (as the previous ``utcnow()``-based version did). Timezone-aware
    timestamps are compared as absolute instants, so any UTC offset is
    handled. The check is symmetric: a timestamp lying more than
    ``max_seconds`` in the future is rejected as well.

    Args:
        datetime_str: ISO-8601 string accepted by ``datetime.fromisoformat``.
        max_seconds: Maximum allowed distance from the current time, in seconds.

    Returns:
        True if the timestamp is strictly less than ``max_seconds`` away from now.
    """
    timestamp = dt.datetime.fromisoformat(datetime_str)
    now = dt.datetime.now(dt.timezone.utc)
    if timestamp.tzinfo is None:
        # Subtracting aware from naive raises TypeError, so compare naive
        # inputs against the naive UTC wall clock instead.
        now = now.replace(tzinfo=None)
    return abs((now - timestamp).total_seconds()) < max_seconds
class TestWebservice:
    """End-to-end tests for the gistillery web service.

    Every test gets its own ``DB_FILE_NAME`` under ``tmp_path``. Tests that
    need processed jobs use dummy ML models and emulate the background
    worker via ``process_jobs``.
    """

    @pytest.fixture(autouse=True)
    def db_file(self, tmp_path, monkeypatch):
        """Point ``DB_FILE_NAME`` at a per-test SQLite file under ``tmp_path``.

        autouse so it runs for every test. ``monkeypatch.setenv`` restores the
        previous value after the test instead of leaking it into other tests.
        """
        filename = tmp_path / "test-db.sqlite"
        monkeypatch.setenv("DB_FILE_NAME", str(filename))
        return filename

    @pytest.fixture
    def client(self, db_file):
        """Return a ``TestClient`` for the app after calling its ``/clear`` endpoint.

        Requests ``db_file`` explicitly so the environment is configured before
        the app is imported.
        """
        # imported here rather than at module level, presumably so that
        # DB_FILE_NAME is already set when the web service module loads
        from gistillery.webservice import app

        client = TestClient(app)
        client.get("/clear")
        return client

    @pytest.fixture
    def mlregistry(self):
        """Return an ``MlRegistry`` with a raw-text processor and dummy models."""
        # use dummy models so no real ML model has to be loaded
        from gistillery.ml import MlRegistry, RawTextProcessor, Summarizer, Tagger

        class DummySummarizer(Summarizer):
            """Returns the first 10 characters of the input"""

            def __init__(self, *args, **kwargs):
                pass

            def get_name(self):
                return "dummy summarizer"

            def __call__(self, x):
                return x[:10]

        class DummyTagger(Tagger):
            """Returns the first 3 words of the input, each prefixed with '#'"""

            def __init__(self, *args, **kwargs):
                pass

            def get_name(self):
                return "dummy tagger"

            def __call__(self, x):
                return ["#" + word for word in x.split(maxsplit=4)[:3]]

        registry = MlRegistry()
        registry.register_processor(RawTextProcessor())
        # arguments don't matter for dummy summarizer and tagger
        summarizer = DummySummarizer(None, None, None, None)
        registry.register_summarizer(summarizer)
        tagger = DummyTagger(None, None, None, None)
        registry.register_tagger(tagger)
        return registry

    def process_jobs(self, registry):
        """Run every pending job once through ``registry``."""
        # emulate work of the background worker
        from gistillery.worker import check_pending_jobs, process_job

        jobs = check_pending_jobs()
        for job in jobs:
            process_job(job, registry)

    def test_status(self, client):
        resp = client.get("/status")
        assert resp.status_code == 200
        assert resp.json() == "OK"

    def test_recent_empty(self, client):
        resp = client.get("/recent")
        assert resp.json() == []

    def test_recent_tag_empty(self, client):
        resp = client.get("/recent/general")
        assert resp.json() == []

    def test_submitted_job_status_pending(self, client, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "pending",
        }
        assert is_roughly_now(last_updated)

    def test_submitted_job_status_not_found(self, client, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        other_job_id = "def5678"
        resp = client.get(f"/check_job_status/{other_job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": other_job_id,
            "status": "not found",
        }
        assert last_updated is None

    def test_submitted_job_status_done(self, client, mlregistry, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "done",
        }
        assert is_roughly_now(last_updated)

    def test_recent_with_entries(self, client, mlregistry):
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit", json={"author": "mini", "content": "this would be something else"}
        )
        self.process_jobs(mlregistry)
        resp = client.get("/recent").json()

        # results are sorted by recency but since dummy models are so fast, the
        # date in the db could be the same, so we sort by author
        resp = sorted(resp, key=lambda x: x["author"])
        assert len(resp) == 2

        resp0 = resp[0]
        assert resp0["author"] == "maxi"
        assert resp0["summary"] == "this is a "
        assert resp0["tags"] == sorted(["#this", "#is", "#a"])

        resp1 = resp[1]
        assert resp1["author"] == "mini"
        assert resp1["summary"] == "this would"
        assert resp1["tags"] == sorted(["#this", "#would", "#be"])

    def test_recent_tag_with_entries(self, client, mlregistry):
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit", json={"author": "mini", "content": "this would be something else"}
        )
        self.process_jobs(mlregistry)

        # the "this" tag is in both entries
        resp = client.get("/recent/this").json()
        assert len(resp) == 2

        # the "would" tag is in only one entry
        resp = client.get("/recent/would").json()
        assert len(resp) == 1

        resp0 = resp[0]
        assert resp0["author"] == "mini"
        assert resp0["summary"] == "this would"
        assert resp0["tags"] == sorted(["#this", "#would", "#be"])

    def test_submitted_job_failed(self, client, mlregistry, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        # patch gistillery.worker._process_job to raise an exception
        def raise_(ex):
            raise ex

        monkeypatch.setattr(
            "gistillery.worker._process_job",
            lambda job, registry: raise_(RuntimeError("something went wrong")),
        )
        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "failed",
        }