Spaces:
Running
Running
from fastapi import FastAPI,Request | |
import os | |
from Crossword_inf import Crossword | |
from BPSolver_inf import BPSolver | |
from Strict_json import json_CA_json_converter | |
import asyncio | |
from fastapi.middleware.cors import CORSMiddleware | |
# Artifact locations for every model the service knows about.  All paths are
# relative ("./..."), i.e. resolved against the process working directory.
MODEL_CONFIG = {
    'bert': {
        'MODEL_PATH': "./Inference_components/dpr_biencoder_trained_EPOCH_2_COMPLETE.bin",
        'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
        'DENSE_EMBD_PATH': "./Inference_components/embeddings_BERT_EPOCH_2_COMPLETE0.pkl",
    },
    'distilbert': {
        'MODEL_PATH': "./Inference_components/distilbert_EPOCHs_7_COMPLETE.bin",
        'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
        'DENSE_EMBD_PATH': "./Inference_components/distilbert_7_epochs_embeddings.pkl",
    },
    't5_small': {
        'MODEL_PATH': './Inference_components/t5_small_new_dataset_2EPOCHS/',
    },
}

# First-pass model: the DistilBERT checkpoint plus its answer list and
# precomputed dense embeddings.
choosen_model_path, ans_list_path, dense_embedding_path = (
    MODEL_CONFIG['distilbert'][field]
    for field in ('MODEL_PATH', 'ANS_TSV_PATH', 'DENSE_EMBD_PATH')
)
# Second-pass (reranker) model: the T5-small checkpoint directory.
second_pass_model_path = MODEL_CONFIG['t5_small']['MODEL_PATH']
# The ASGI application.  CORS is wide open: every origin, method and header is
# accepted, and credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True may let
# any site make credentialed cross-origin calls; nothing visible in this file
# seems to need credentials — confirm before tightening.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
async def solve_puzzle(json):
    """Solve one crossword puzzle without blocking the event loop.

    The payload is converted with ``json_CA_json_converter``, wrapped in a
    ``Crossword`` and handed to a ``BPSolver`` configured with the module-level
    DistilBERT (first pass) and T5 (reranker) paths.  Solver construction,
    ``solve`` and both evaluations run in threads via ``asyncio.to_thread``.

    Args:
        json: Puzzle payload as decoded from the request body.  (The name
            shadows the stdlib ``json`` module; kept for compatibility.)

    Returns:
        A 4-tuple ``(first_pass_grid, first_pass_eval, second_pass_grid,
        second_pass_eval)``.  On any error, a 4-tuple of ``None`` — the same
        arity as the success case, so callers can index every slot safely.
    """
    try:
        puzzle = json_CA_json_converter(json, False)
        crossword = Crossword(puzzle)
        # Constructing the solver (presumably loading model weights) and
        # solving are blocking calls, so both go to a worker thread.
        solver = await asyncio.to_thread(
            BPSolver,
            crossword,
            model_path=choosen_model_path,
            ans_tsv_path=ans_list_path,
            dense_embd_path=dense_embedding_path,
            reranker_path=second_pass_model_path,
            max_candidates=40000,
            model_type='distilbert',
        )
        solution = await asyncio.to_thread(
            solver.solve, num_iters=60, iterative_improvement_steps=0
        )
        first_grid = solution['first pass model']['grid']
        second_grid = solution['second pass model']['final grid']
        evaluation1 = await asyncio.to_thread(solver.evaluate1, first_grid)
        evaluation2 = await asyncio.to_thread(solver.evaluate1, second_grid)
        return first_grid, evaluation1, second_grid, evaluation2
    except Exception as e:
        # Best-effort boundary: report the failure and hand back an all-None
        # result instead of raising into the caller.
        print(f"An error occurred: {e}")
        return None, None, None, None
# Unbounded FIFO of pending work.  Each item is a
# ``(job_id, coroutine_function, args, future)`` tuple (see ``worker``/``solve``).
fifo_queue: asyncio.Queue = asyncio.Queue(maxsize=0)
# Job registry: job_id -> {"status": "queued" | "processing" | "completed" |
# "failed", plus "result" once the job has run}.
jobs: dict = dict()
async def worker():
    """Process queued jobs forever, one at a time, in FIFO order.

    For each ``(job_id, job, args, future)`` taken from ``fifo_queue``:
    the job's status becomes ``"processing"``, the coroutine ``job(*args)`` is
    awaited, its return value is stored under ``"result"``, and the status
    becomes ``"completed"`` when element 1 of the result is truthy (for
    ``solve_puzzle`` that is the first-pass evaluation), else ``"failed"``.
    ``future`` is then resolved with ``job_id``.

    A job that raises, or returns something without an element 1, is recorded
    as ``"failed"`` instead of terminating the loop — otherwise one bad job
    would leave every job behind it stuck in the queue.
    """
    while True:
        job_id, job, args, future = await fifo_queue.get()
        # Logged after ``get`` so the message reflects a job actually taken.
        print(f"Worker got a job: (size of remaining queue: {fifo_queue.qsize()})")
        # setdefault: a missing registry entry must not crash the worker.
        entry = jobs.setdefault(job_id, {})
        entry["status"] = "processing"
        try:
            result = await job(*args)
            print(result)
            entry["result"] = result
            succeeded = bool(result) and len(result) > 1 and bool(result[1])
            entry["status"] = "completed" if succeeded else "failed"
        except Exception as e:
            print(f"Job {job_id} raised: {e}")
            entry["result"] = None
            entry["status"] = "failed"
        finally:
            # Guard against a future that was cancelled/resolved elsewhere;
            # set_result on a done future raises InvalidStateError.
            if not future.done():
                future.set_result(job_id)
            # Keep the queue's unfinished-task count accurate for join().
            fifo_queue.task_done()
# Strong reference to the background worker task.  The event loop keeps only
# weak references to tasks, so a task nobody references can be garbage-
# collected mid-execution.
_worker_task = None


async def on_start_up():
    """Start the background ``worker`` task that drains ``fifo_queue``.

    Only one worker is kept alive; calling this again while it is running is a
    no-op, and a finished/crashed worker is replaced.

    NOTE(review): no startup registration (decorator or lifespan) is visible in
    this file; confirm this hook is actually wired up, otherwise nothing ever
    consumes the queue.
    """
    global _worker_task
    if _worker_task is None or _worker_task.done():
        _worker_task = asyncio.create_task(worker())
# Last job id handed out.  ``id(future)`` is not a safe key: ids may be reused
# once the Future is freed (objects with non-overlapping lifetimes can share an
# id), so a new job could overwrite an earlier, already-completed job's entry.
_last_job_id = 0


async def solve(request: Request):
    """Enqueue the puzzle in the request body and return its job id.

    The decoded JSON body is queued for ``solve_puzzle``; use ``get_result``
    with the returned id to follow progress.

    NOTE(review): no route decorator is visible in this file; confirm how this
    handler is registered.

    Returns:
        ``{"job_id": <int>}``.
    """
    global _last_job_id
    payload = await request.json()
    future = asyncio.get_running_loop().create_future()
    # No ``await`` between increment and read, so concurrent requests on the
    # event loop can never receive the same id.
    _last_job_id += 1
    job_id = _last_job_id
    jobs[job_id] = {"status": "queued"}
    await fifo_queue.put((job_id, solve_puzzle, [payload], future))
    return {"job_id": job_id}
async def get_result(job_id: int):
    """Return the current state of a job.

    The response is a shallow copy of ``jobs[job_id]`` (``status`` and, once
    the job has run, ``result``).  While the job is still queued,
    ``queue_status`` adds its 1-based position and the current queue length.

    NOTE(review): no route decorator is visible in this file; confirm how this
    handler is registered.

    Returns:
        The job dict described above, or ``{"error": ...}`` for an unknown id.
    """
    job = jobs.get(job_id)
    if job is None:
        return {"error": "Job not found or completed"}
    response = {**job}
    if job["status"] == "queued":
        # asyncio.Queue offers no public way to inspect pending items, so this
        # reads the private ``_queue`` attribute; re-check on Python upgrades.
        queue_length = fifo_queue.qsize()
        for position, (queued_job_id, _, _, _) in enumerate(fifo_queue._queue, start=1):
            if queued_job_id == job_id:
                response["queue_status"] = {"index": position, "length": queue_length}
                break
    return response
async def home():
    """Ping-style endpoint that always returns the same fixed payload.

    NOTE(review): no route decorator is visible in this file; confirm how this
    handler is registered.
    """
    payload = dict(Success="True", Message="Pong")
    return payload