Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI,Request | |
| import os | |
| from Crossword_inf import Crossword | |
| from BPSolver_inf import BPSolver | |
| from Strict_json import json_CA_json_converter | |
| import asyncio | |
| from fastapi.middleware.cors import CORSMiddleware | |
# Artifact locations for each supported model, keyed by model name.
MODEL_CONFIG = dict(
    bert=dict(
        MODEL_PATH="./Inference_components/dpr_biencoder_trained_EPOCH_2_COMPLETE.bin",
        ANS_TSV_PATH="./Inference_components/all_answer_list.tsv",
        DENSE_EMBD_PATH="./Inference_components/embeddings_BERT_EPOCH_2_COMPLETE0.pkl",
    ),
    distilbert=dict(
        MODEL_PATH="./Inference_components/distilbert_EPOCHs_7_COMPLETE.bin",
        ANS_TSV_PATH="./Inference_components/all_answer_list.tsv",
        DENSE_EMBD_PATH="./Inference_components/distilbert_7_epochs_embeddings.pkl",
    ),
    t5_small=dict(
        MODEL_PATH='./Inference_components/t5_small_new_dataset_2EPOCHS/',
    ),
)

# The distilbert artifacts are what solve_puzzle hands to BPSolver as
# model_path / ans_tsv_path / dense_embd_path. The misspelled name
# `choosen_model_path` is kept because other code refers to it.
choosen_model_path, ans_list_path, dense_embedding_path = (
    MODEL_CONFIG['distilbert'][field]
    for field in ('MODEL_PATH', 'ANS_TSV_PATH', 'DENSE_EMBD_PATH')
)
# Passed to BPSolver as reranker_path.
second_pass_model_path = MODEL_CONFIG['t5_small']['MODEL_PATH']
# ASGI application with a fully permissive CORS policy.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True makes
# Starlette echo back whatever Origin the browser sends, so any site can make
# credentialed requests. Nothing visible here uses cookies/auth; consider
# allow_credentials=False or an explicit origin list - confirm with the frontend.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
async def solve_puzzle(json):
    """Solve one puzzle, running every blocking step in a worker thread.

    Converts ``json`` with ``json_CA_json_converter``, wraps the result in a
    ``Crossword`` and builds a ``BPSolver`` from the module-level distilbert /
    t5_small paths. It then runs ``solver.solve`` and passes the first-pass
    grid and the second-pass final grid to ``solver.evaluate1``. Each blocking
    call goes through ``asyncio.to_thread`` so the event loop stays responsive.

    Args:
        json: The parsed request body. The parameter name is kept for
            backward compatibility with existing callers.

    Returns:
        ``(first_pass_grid, evaluation1, second_pass_grid, evaluation2)``.
        If anything raises, the error is printed and a 4-tuple of ``None`` is
        returned, so callers can index positions 0-3 on either path.
    """
    try:
        puzzle = json_CA_json_converter(json, False)
        crossword = Crossword(puzzle)
        # Building the solver and solving are both blocking; run them off-loop.
        solver = await asyncio.to_thread(
            BPSolver,
            crossword,
            model_path=choosen_model_path,
            ans_tsv_path=ans_list_path,
            dense_embd_path=dense_embedding_path,
            reranker_path=second_pass_model_path,
            max_candidates=40000,
            model_type='distilbert',
        )
        solution = await asyncio.to_thread(
            solver.solve, num_iters=60, iterative_improvement_steps=0
        )
        first_grid = solution['first pass model']['grid']
        evaluation1 = await asyncio.to_thread(solver.evaluate1, first_grid)
        second_grid = solution['second pass model']['final grid']
        evaluation2 = await asyncio.to_thread(solver.evaluate1, second_grid)
        return first_grid, evaluation1, second_grid, evaluation2
    except Exception as e:
        print(f"An error occurred: {e}")
        # Same arity as the success path (was a 3-tuple before).
        return None, None, None, None
# Pending work for `worker`, consumed in FIFO order. Each item is a
# (job_id, coroutine function, positional-args list, future) tuple.
# NOTE(review): created at import time; on Python < 3.10 asyncio.Queue binds to
# the loop current at construction - confirm the server imports this module
# inside its running loop.
fifo_queue = asyncio.Queue()
# job_id -> state dict ("status", and "result" once finished). Nothing in
# this listing removes entries, so it grows for the life of the process.
jobs = dict()
async def worker():
    """Process queued jobs forever, one at a time, in FIFO order.

    Each item taken from ``fifo_queue`` is ``(job_id, job, args, future)``.
    The matching ``jobs[job_id]`` entry is marked ``"processing"``, the job
    coroutine function is awaited with ``*args``, its return value is stored
    under ``"result"``, and the status becomes ``"completed"`` or ``"failed"``.
    ``future`` is then resolved with ``job_id``.

    A job that raises is recorded as ``"failed"`` instead of propagating.
    This coroutine is the queue's only consumer, so letting the exception
    escape would end the loop and leave every later job ``"queued"`` forever.
    """
    while True:
        job_id, job, args, future = await fifo_queue.get()
        # Log after dequeuing so the reported size really is what remains.
        print(f"Worker got a job: (size of remaining queue: {fifo_queue.qsize()})")
        jobs[job_id]["status"] = "processing"
        try:
            result = await job(*args)
            print(result)
            jobs[job_id]["result"] = result
            # For solve_puzzle jobs, position 1 is the first-pass evaluation,
            # which is None when solve_puzzle caught an exception.
            # NOTE(review): a falsy-but-valid evaluation (e.g. 0) is also
            # reported as "failed"; confirm what solver.evaluate1 returns.
            jobs[job_id]["status"] = "completed" if result[1] else "failed"
        except Exception as exc:
            print(f"Job {job_id} raised an exception: {exc}")
            jobs[job_id]["status"] = "failed"
        finally:
            # set_result on a cancelled/already-resolved future would raise.
            if not future.done():
                future.set_result(job_id)
            # Balance every get() so Queue.join() can complete if it is ever used.
            fifo_queue.task_done()
# Strong references to background tasks. The event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


async def on_start_up():
    """Start the background ``worker`` task that drains ``fifo_queue``.

    NOTE(review): no startup registration (e.g. ``@app.on_event("startup")``
    or a lifespan handler) is visible in this listing. Confirm this hook is
    wired up; otherwise no queued job is ever processed.
    """
    task = asyncio.create_task(worker())
    _background_tasks.add(task)
    # Drop the reference only once the task has actually finished.
    task.add_done_callback(_background_tasks.discard)
async def solve(request: Request):
    """Queue a puzzle for ``solve_puzzle`` and return its job id immediately.

    The request body is parsed as JSON, a ``"queued"`` entry is created in
    ``jobs``, and ``(job_id, solve_puzzle, [payload], future)`` is put on
    ``fifo_queue`` for ``worker``. Clients poll ``get_result`` with the id.

    Returns:
        ``{"job_id": <int>}``.
    """
    import secrets  # local import keeps this handler self-contained

    payload = await request.json()
    future = asyncio.get_running_loop().create_future()
    # Don't use id(future): CPython reuses ids once an object is freed, so a
    # later job could silently overwrite an earlier entry in `jobs`. Use a
    # random, collision-checked id instead. The value stays in [1, 2**52],
    # which JavaScript clients can represent exactly.
    job_id = 1 + secrets.randbelow(2**52)
    while job_id in jobs:
        job_id = 1 + secrets.randbelow(2**52)
    jobs[job_id] = {"status": "queued"}
    await fifo_queue.put((job_id, solve_puzzle, [payload], future))
    return {"job_id": job_id}
async def get_result(job_id: int):
    """Return a snapshot of a job's state, plus its queue position if queued.

    Known ids get a shallow copy of ``jobs[job_id]``. While the job is still
    ``"queued"``, the copy also carries ``queue_status``, which holds the job's
    1-based ``index`` in ``fifo_queue`` and the current queue ``length``.
    Unknown ids get an error dict. Nothing in this listing deletes finished
    jobs, so in practice the "or completed" case in the message does not occur.
    """
    if job_id not in jobs:
        return {"error": "Job not found or completed"}

    snapshot = dict(jobs[job_id])
    if snapshot["status"] == "queued":
        total = fifo_queue.qsize()
        # asyncio.Queue has no public peek, so read its internal deque. No
        # await happens inside the loop, so the deque cannot change underneath it.
        for position, (queued_id, _job, _args, _future) in enumerate(
            fifo_queue._queue, start=1
        ):
            if queued_id == job_id:
                snapshot["queue_status"] = {"index": position, "length": total}
    return snapshot
async def home():
    """Liveness probe: always reply with a fixed "Pong" payload."""
    return dict(Success="True", Message="Pong")