# NOTE: page-scrape residue, not part of the module: "Spaces: Runtime error"
import datetime as dt
import os
from types import SimpleNamespace

import pytest
from fastapi.testclient import TestClient
def is_roughly_now(datetime_str, tolerance=3.0):
    """Check whether an ISO 8601 timestamp lies within a few seconds of now.

    Args:
        datetime_str: Timestamp in a format accepted by
            ``datetime.fromisoformat``. A value without a UTC offset is
            interpreted as UTC.
        tolerance: Maximum allowed distance from the current time, in
            seconds. Defaults to 3 seconds.

    Returns:
        True if the timestamp is less than ``tolerance`` seconds away from
        the current UTC time, in either direction; False otherwise.

    Raises:
        ValueError: If ``datetime_str`` is not a valid ISO 8601 string.
    """
    timestamp = dt.datetime.fromisoformat(datetime_str)
    if timestamp.tzinfo is None:
        # Naive timestamps are taken to be UTC; making them aware lets both
        # naive and offset-aware inputs be compared without a TypeError.
        timestamp = timestamp.replace(tzinfo=dt.timezone.utc)
    now = dt.datetime.now(dt.timezone.utc)
    # Use the absolute difference so that a timestamp far in the future is
    # rejected instead of passing because the difference is negative.
    return abs((now - timestamp).total_seconds()) < tolerance
class TestWebservice:
    """Tests for the gistillery webservice endpoints.

    The tests talk to the FastAPI app through a ``TestClient`` and emulate the
    background worker by processing pending jobs synchronously with dummy
    summarizer and tagger models.
    """

    # NOTE(review): the fixture decorators were missing from the source; they
    # are reconstructed from how the tests request ``client``, ``cursor`` and
    # ``mlregistry``. ``autouse=True`` is assumed here because no test requests
    # ``db_file`` by name - confirm against the project history.
    @pytest.fixture(autouse=True)
    def db_file(self, tmp_path):
        """Point ``DB_FILE_NAME`` at a sqlite file inside the per-test tmp dir."""
        filename = tmp_path / "test-db.sqlite"
        os.environ["DB_FILE_NAME"] = str(filename)

    @pytest.fixture
    def cursor(self):
        """Yield a database cursor obtained from ``get_db_cursor``."""
        # imported lazily, like the other gistillery imports in this class
        from gistillery.db import get_db_cursor

        with get_db_cursor() as cursor:
            yield cursor

    @pytest.fixture
    def client(self):
        """Return a ``TestClient`` for the app after calling ``/clear``."""
        from gistillery.webservice import app

        client = TestClient(app)
        client.get("/clear")
        return client

    @pytest.fixture
    def mlregistry(self):
        """Return an ``MlRegistry`` populated with fast dummy models."""
        # use dummy models
        from gistillery.ml import Summarizer, Tagger
        from gistillery.preprocessing import RawTextProcessor
        from gistillery.registry import MlRegistry

        class DummySummarizer(Summarizer):
            """Returns the first 10 characters of the input"""

            def __init__(self, *args, **kwargs):
                pass

            def get_name(self):
                return "dummy summarizer"

            def __call__(self, x):
                return x[:10]

        class DummyTagger(Tagger):
            """Returns the first 3 words of the input"""

            def __init__(self, *args, **kwargs):
                pass

            def get_name(self):
                return "dummy tagger"

            def __call__(self, x):
                return ["#" + word for word in x.split(maxsplit=4)[:3]]

        registry = MlRegistry()
        registry.register_processor(RawTextProcessor())

        # arguments don't matter for dummy summarizer and tagger
        summarizer = DummySummarizer(None, None, None, None)
        registry.register_summarizer(summarizer)

        tagger = DummyTagger(None, None, None, None)
        registry.register_tagger(tagger)
        return registry

    def process_jobs(self, registry):
        """Process all pending jobs synchronously using ``registry``."""
        # emulate work of the background worker
        from gistillery.worker import check_pending_jobs, process_job

        jobs = check_pending_jobs()
        for job in jobs:
            process_job(job, registry)

    def test_status(self, client):
        """``/status`` answers 200 with the JSON body ``"OK"``."""
        resp = client.get("/status")
        assert resp.status_code == 200
        assert resp.json() == "OK"

    def test_recent_empty(self, client):
        """``/recent`` returns an empty list when nothing was submitted."""
        resp = client.get("/recent")
        assert resp.json() == []

    def test_recent_tag_empty(self, client):
        """``/recent/<tag>`` returns an empty list when nothing was submitted."""
        resp = client.get("/recent/general")
        assert resp.json() == []

    def test_submitted_job_status_pending(self, client, monkeypatch):
        """A freshly submitted, unprocessed job is reported as pending."""
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        # the timestamp varies between runs, so it is checked separately
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "pending",
        }
        assert is_roughly_now(last_updated)

    def test_submitted_job_status_not_found(self, client, monkeypatch):
        """Asking for an unknown job id yields status "not found"."""
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        other_job_id = "def5678"
        resp = client.get(f"/check_job_status/{other_job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": other_job_id,
            "status": "not found",
        }
        assert last_updated is None

    def test_submitted_job_failed(self, client, mlregistry, monkeypatch):
        """A job whose processing raises is reported as failed."""
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        # patch gistillery.worker._process_job to raise an exception
        def raise_(ex):
            # a lambda cannot contain a raise statement, hence this helper
            raise ex

        # make the job processing fail
        monkeypatch.setattr(
            "gistillery.worker._process_job",
            lambda job, registry: raise_(RuntimeError("something went wrong")),
        )
        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "failed",
        }

    def test_submitted_job_status_done(self, client, mlregistry, monkeypatch):
        """A successfully processed job is reported as done."""
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "done",
        }
        assert is_roughly_now(last_updated)

    def test_recent_with_entries(self, client, mlregistry):
        """``/recent`` returns both processed entries with summary and tags."""
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "this would be something else"},
        )
        self.process_jobs(mlregistry)
        resp = client.get("/recent").json()

        # results are sorted by recency but since dummy models are so fast, the
        # date in the db could be the same, so we sort by author
        resp = sorted(resp, key=lambda x: x["author"])
        assert len(resp) == 2

        resp0 = resp[0]
        assert resp0["author"] == "maxi"
        assert resp0["summary"] == "this is a "
        assert resp0["tags"] == sorted(["#this", "#is", "#a"])

        resp1 = resp[1]
        assert resp1["author"] == "mini"
        assert resp1["summary"] == "this would"
        assert resp1["tags"] == sorted(["#this", "#would", "#be"])

    def test_recent_tag_with_entries(self, client, mlregistry):
        """``/recent/<tag>`` only returns the entries carrying that tag."""
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "this would be something else"},
        )
        self.process_jobs(mlregistry)

        # the "this" tag is in both entries
        resp = client.get("/recent/this").json()
        assert len(resp) == 2

        # the "would" tag is in only one entry
        resp = client.get("/recent/would").json()
        assert len(resp) == 1

        resp0 = resp[0]
        assert resp0["author"] == "mini"
        assert resp0["summary"] == "this would"
        assert resp0["tags"] == sorted(["#this", "#would", "#be"])

    def test_clear(self, client, cursor, mlregistry):
        """``/clear`` removes previously stored entries."""
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(mlregistry)
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1

        client.get("/clear")
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0

    def test_inputs_stored(self, client, cursor, mlregistry):
        """The submitted content is stored in ``inputs`` without outer whitespace."""
        client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
        self.process_jobs(mlregistry)
        rows = cursor.execute("SELECT * FROM inputs").fetchall()
        assert len(rows) == 1
        assert rows[0].input == "this is a test"

    def test_submit_url(self, client, cursor, mlregistry, monkeypatch):
        """A submitted URL is stored together with the text of the fetched page."""

        class MockClient:
            """Mock httpx Client, return www.example.com content"""

            def get(self, url):
                # The continuation lines of the literal are kept flush-left so
                # that no whitespace is added to the string's content.
                return SimpleNamespace(
                    text=''' <!doctype html>\n<html>\n<head>\n <title>Example
Domain</title>\n\n <meta charset="utf-8" />\n <meta
http-equiv="Content-type" content="text/html; charset=utf-8"
/>\n <meta name="viewport" content="width=device-width,
initial-scale=1" />\n <style type="text/css">\n body {\n
background-color: #f0f0f2;\n margin: 0;\n padding: 0;\n
font-family: -apple-system, system-ui, BlinkMacSystemFont,
"Segoe UI", "Open Sans", "Helvetica Neue", Helvetica, Arial,
sans-serif;\n \n }\n div {\n width: 600px;\n margin: 5em
auto;\n padding: 2em;\n background-color: #fdfdff;\n
border-radius: 0.5em;\n box-shadow: 2px 3px 7px 2px
rgba(0,0,0,0.02);\n }\n a:link, a:visited {\n color:
#38488f;\n text-decoration: none;\n }\n @media (max-width:
700px) {\n div {\n margin: 0 auto;\n width: auto;\n }\n }\n
</style> \n</head>\n\n<body>\n<div>\n <h1>Example
Domain</h1>\n <p>This domain is for use in illustrative
examples in documents. You may use this\n domain in
literature without prior coordination or asking for
permission.</p>\n <p><a
href="https://www.iana.org/domains/example">More
information...</a></p>\n</div>\n</body>\n</html>\n'''
                )

        # avoid real network access: the URL processor gets the mock client
        monkeypatch.setattr("gistillery.preprocessing.Client", MockClient)

        from gistillery.preprocessing import DefaultUrlProcessor

        # register url processor, put it before the default processor
        mlregistry.register_processor(DefaultUrlProcessor(), last=False)
        client.post(
            "/submit",
            json={
                "author": "ben",
                "content": "https://en.wikipedia.org/wiki/non-existing-page",
            },
        )
        self.process_jobs(mlregistry)

        rows = cursor.execute("SELECT * FROM inputs").fetchall()
        assert len(rows) == 1

        expected = "\n".join(
            [
                'https://en.wikipedia.org/wiki/non-existing-page',
                '',
                'This domain is for use in illustrative',
                'examples in documents. You may use this',
                'domain in',
                'literature without prior coordination or asking for',
                'permission.',
                'More',
                'information...',
            ]
        )
        assert rows[0].input == expected