import time
from dataclasses import dataclass

from base import JobInput
from db import get_db_cursor
from ml import (
    DefaultUrlProcessor,
    HfTransformersSummarizer,
    HfTransformersTagger,
    MlRegistry,
    RawTextProcessor,
)

SLEEP_INTERVAL = 5


def check_pending_jobs() -> list[JobInput]:
    """Check the DB for pending jobs."""
    with get_db_cursor() as cursor:
        # fetch pending jobs, joining author and source from the entries table
        query = """
        SELECT j.entry_id, e.author, e.source
        FROM jobs j
        JOIN entries e ON j.entry_id = e.id
        WHERE j.status = 'pending'
        """
        res = list(cursor.execute(query))
        return [
            JobInput(id=_id, author=author, content=content)
            for _id, author, content in res
        ]


@dataclass
class JobOutput:
    summary: str
    tags: list[str]
    processor_name: str
    summarizer_name: str
    tagger_name: str


def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
    processor = registry.get_processor(job)
    processor_name = processor.get_name()
    processed = processor(job)

    tagger = registry.get_tagger()
    tagger_name = tagger.get_name()
    tags = tagger(processed)

    summarizer = registry.get_summarizer()
    summarizer_name = summarizer.get_name()
    summary = summarizer(processed)

    return JobOutput(
        summary=summary,
        tags=tags,
        processor_name=processor_name,
        summarizer_name=summarizer_name,
        tagger_name=tagger_name,
    )


def store(job: JobInput, output: JobOutput) -> None:
    with get_db_cursor() as cursor:
        # write to the summaries and tags tables
        cursor.execute(
            (
                "INSERT INTO summaries (entry_id, summary, summarizer_name)"
                " VALUES (?, ?, ?)"
            ),
            (job.id, output.summary, output.summarizer_name),
        )
        cursor.executemany(
            "INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
            [(job.id, tag, output.tagger_name) for tag in output.tags],
        )


def process_job(job: JobInput, registry: MlRegistry) -> None:
    tic = time.perf_counter()
    print(f"Processing job (id={job.id[:8]})")
    # care: acquire the cursor (which takes a lock) as late as possible, since
    # processing can take a while and we don't want to block other workers
    # during that time
    try:
        output = _process_job(job, registry)
        store(job, output)
        # update job status to done
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'done' WHERE entry_id = ?", (job.id,)
            )
    except Exception as e:
        # update job status to failed
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'failed' WHERE entry_id = ?", (job.id,)
            )
        print(f"Failed to process job (id={job.id[:8]}): {e}")

    toc = time.perf_counter()
    print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")


def load_mlregistry(model_name: str) -> MlRegistry:
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    config_summarizer = GenerationConfig.from_pretrained(model_name)
    config_summarizer.max_new_tokens = 200
    config_summarizer.min_new_tokens = 100
    config_summarizer.top_k = 5
    config_summarizer.repetition_penalty = 1.5

    config_tagger = GenerationConfig.from_pretrained(model_name)
    config_tagger.max_new_tokens = 50
    config_tagger.min_new_tokens = 25
    # increase the temperature to make the model more creative
    config_tagger.temperature = 1.5

    summarizer = HfTransformersSummarizer(
        model_name, model, tokenizer, config_summarizer
    )
    tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)

    registry = MlRegistry()
    registry.register_processor(DefaultUrlProcessor())
    registry.register_processor(RawTextProcessor())
    registry.register_summarizer(summarizer)
    registry.register_tagger(tagger)

    return registry


def main() -> None:
    model_name = "google/flan-t5-large"
    registry = load_mlregistry(model_name)

    while True:
        jobs = check_pending_jobs()
        if not jobs:
            print("No pending jobs found, sleeping...")
            time.sleep(SLEEP_INTERVAL)
            continue

        print(f"Found {len(jobs)} pending job(s), processing...")
        for job in jobs:
            process_job(job, registry)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("Shutting down...")
        exit(0)