gistillery / src /worker.py
Benjamin Bossan
Refactor ml model handling
126a4c6
raw
history blame
4.79 kB
import sys
import time
import traceback
from dataclasses import dataclass

from base import JobInput
from db import get_db_cursor
from ml import (
    DefaultUrlProcessor,
    HfTransformersSummarizer,
    HfTransformersTagger,
    MlRegistry,
    RawTextProcessor,
)
# Number of seconds the worker sleeps (via time.sleep in main) before polling
# the jobs table again when no pending jobs were found.
SLEEP_INTERVAL = 5
def check_pending_jobs() -> list[JobInput]:
    """Return all jobs whose status is 'pending'.

    Each pending job is joined with its row in the ``entries`` table so the
    entry's author and source are available; the ``source`` column is exposed
    as ``JobInput.content`` and ``jobs.entry_id`` as ``JobInput.id``.
    """
    with get_db_cursor() as cursor:
        sql = """
        SELECT j.entry_id, e.author, e.source
        FROM jobs j
        JOIN entries e
        ON j.entry_id = e.id
        WHERE j.status = 'pending'
        """
        # Materialise the rows while the cursor is still open; the JobInput
        # objects are built after the cursor context has been left.
        rows = list(cursor.execute(sql))
    return [
        JobInput(id=entry_id, author=entry_author, content=entry_source)
        for entry_id, entry_author, entry_source in rows
    ]
@dataclass
class JobOutput:
    """Result of running a single job through the ML pipeline (built by _process_job)."""

    # Summary produced by the summarizer.
    summary: str
    # Tags produced by the tagger; store() writes one tags row per entry.
    tags: list[str]
    # get_name() of the processor that pre-processed the job input
    # (not persisted by store()).
    processor_name: str
    # get_name() of the summarizer that produced `summary`.
    summarizer_name: str
    # get_name() of the tagger that produced `tags`.
    tagger_name: str
def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
    """Run ``job`` through the registry's processor, tagger and summarizer.

    The processor the registry selects for ``job`` produces the input that is
    then fed to both the tagger and the summarizer (in that order). The name
    of every component used is recorded alongside the results.
    """
    # Step 1: pre-process the raw job with the processor matching it.
    proc = registry.get_processor(job)
    proc_name = proc.get_name()
    prepared = proc(job)

    # Step 2: tag the pre-processed input.
    tag_model = registry.get_tagger()
    tag_model_name = tag_model.get_name()
    job_tags = tag_model(prepared)

    # Step 3: summarize the same pre-processed input.
    summ_model = registry.get_summarizer()
    summ_model_name = summ_model.get_name()
    job_summary = summ_model(prepared)

    return JobOutput(
        summary=job_summary,
        tags=job_tags,
        processor_name=proc_name,
        summarizer_name=summ_model_name,
        tagger_name=tag_model_name,
    )
def store(job: JobInput, output: JobOutput) -> None:
    """Persist the summary and tags computed for ``job``.

    Inserts one row into ``summaries`` and one row per tag into ``tags``,
    each keyed by ``job.id`` and tagged with the name of the model that
    produced it. Both inserts are issued through the same cursor context.
    """
    summary_sql = (
        "INSERT INTO summaries (entry_id, summary, summarizer_name)"
        " VALUES (?, ?, ?)"
    )
    tag_sql = "INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)"
    with get_db_cursor() as cur:
        cur.execute(summary_sql, (job.id, output.summary, output.summarizer_name))
        tag_rows = [(job.id, tag, output.tagger_name) for tag in output.tags]
        cur.executemany(tag_sql, tag_rows)
def process_job(job: JobInput, registry: MlRegistry) -> None:
    """Process one pending job and record its outcome in the ``jobs`` table.

    The job is run through the ML models, the results are stored, and the
    job's status is set to ``'done'``. If any of these steps raises an
    ``Exception``, the status is set to ``'failed'`` instead and the error
    message plus its traceback are printed. The exception is not re-raised, so
    the caller's loop can continue with the next job.

    Args:
        job: The pending job to process.
        registry: Registry supplying the processor, tagger and summarizer.
    """
    tic = time.perf_counter()
    print(f"Processing job for (id={job.id[:8]})")
    # Care: acquire a DB cursor (which leads to locking) as late as possible.
    # The ML processing can take a while and we don't want to block other
    # workers during that time, so no cursor is held while _process_job runs.
    try:
        output = _process_job(job, registry)
        store(job, output)
        # update job status to done
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'done' WHERE entry_id = ?", (job.id,)
            )
    except Exception as e:
        # update job status to failed
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'failed' WHERE entry_id = ?", (job.id,)
            )
        print(f"Failed to process job for (id={job.id[:8]}): {e}")
        # The one-line message alone rarely shows where in the model/DB
        # pipeline the error came from, so print the full traceback as well.
        # Passed explicitly so it does not depend on the "current" exception.
        traceback.print_exception(type(e), e, e.__traceback__)
    toc = time.perf_counter()
    print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
def load_mlregistry(model_name: str) -> MlRegistry:
    """Build an MlRegistry backed by the seq2seq checkpoint ``model_name``.

    One model and one tokenizer are loaded and shared by the summarizer and
    the tagger. Each of the two gets its own GenerationConfig, loaded from the
    same checkpoint, with task-specific overrides. The URL and raw-text
    processors are registered as well.

    Args:
        model_name: Checkpoint identifier passed to the transformers
            ``from_pretrained`` loaders.

    Returns:
        The populated registry.
    """
    # Imported inside the function, so transformers is only loaded when a
    # registry is actually built.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

    def _generation_config(**overrides) -> GenerationConfig:
        # Fresh generation config from the checkpoint, with the given
        # attributes overridden in the order they were passed.
        cfg = GenerationConfig.from_pretrained(model_name)
        for attr, value in overrides.items():
            setattr(cfg, attr, value)
        return cfg

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    summarizer_cfg = _generation_config(
        max_new_tokens=200,
        min_new_tokens=100,
        top_k=5,
        repetition_penalty=1.5,
    )
    # Increase the temperature to make the tagger more creative.
    # NOTE(review): in transformers, `top_k` (above) and `temperature` only
    # take effect when sampling is enabled (`do_sample=True`). Nothing here
    # enables it; confirm the ml wrappers do, otherwise these are ignored.
    tagger_cfg = _generation_config(
        max_new_tokens=50,
        min_new_tokens=25,
        temperature=1.5,
    )

    summarizer = HfTransformersSummarizer(model_name, model, tokenizer, summarizer_cfg)
    tagger = HfTransformersTagger(model_name, model, tokenizer, tagger_cfg)

    registry = MlRegistry()
    registry.register_processor(DefaultUrlProcessor())
    registry.register_processor(RawTextProcessor())
    registry.register_summarizer(summarizer)
    registry.register_tagger(tagger)
    return registry
def main() -> None:
    """Load the ML registry, then poll for and process pending jobs forever.

    When a poll finds no pending jobs the worker sleeps for SLEEP_INTERVAL
    seconds; otherwise it processes the found jobs one at a time and polls
    again immediately.
    """
    registry = load_mlregistry("google/flan-t5-large")
    while True:
        pending = check_pending_jobs()
        if pending:
            print(f"Found {len(pending)} pending job(s), processing...")
            for pending_job in pending:
                process_job(pending_job, registry)
        else:
            print("No pending jobs found, sleeping...")
            time.sleep(SLEEP_INTERVAL)
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # main() loops forever, so Ctrl-C (KeyboardInterrupt) is how the
        # worker is stopped; treat it as a clean shutdown.
        print("Shutting down...")
        # sys.exit rather than the site-injected `exit`, which is intended
        # for the interactive shell and is missing under `python -S`.
        sys.exit(0)