# EZ-Crossword / main.py
# Author: Ujjwal123
# Commit c8ed164: "Removed second pass and solved for british crossword"
import asyncio
import itertools
import os
import traceback

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware

from Crossword_inf import Crossword
from BPSolver_inf import BPSolver
from Strict_json import json_CA_json_converter
# File locations for every available model, keyed by model name.
# Only the 'distilbert' entry and the 't5_small' entry are read below.
MODEL_CONFIG = {
    'bert': {
        'MODEL_PATH': "./Inference_components/dpr_biencoder_trained_EPOCH_2_COMPLETE.bin",
        'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
        'DENSE_EMBD_PATH': "./Inference_components/embeddings_BERT_EPOCH_2_COMPLETE0.pkl",
    },
    'distilbert': {
        'MODEL_PATH': "./Inference_components/distilbert_EPOCHs_7_COMPLETE.bin",
        'ANS_TSV_PATH': "./Inference_components/all_answer_list.tsv",
        'DENSE_EMBD_PATH': "./Inference_components/distilbert_7_epochs_embeddings.pkl",
    },
    't5_small': {
        'MODEL_PATH': './Inference_components/t5_small_new_dataset_2EPOCHS/',
    },
}

# Resolve the paths of the active retriever (distilbert) once, at import time.
# The misspelled name `choosen_model_path` is kept because other code refers to it.
_distilbert_paths = MODEL_CONFIG['distilbert']
choosen_model_path = _distilbert_paths['MODEL_PATH']
ans_list_path = _distilbert_paths['ANS_TSV_PATH']
dense_embedding_path = _distilbert_paths['DENSE_EMBD_PATH']
# The t5_small checkpoint directory; the solver receives it as `reranker_path`.
second_pass_model_path = MODEL_CONFIG['t5_small']['MODEL_PATH']
# The ASGI application. CORS is left fully open: any origin, method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site send credentialed requests to this API. That is harmless while the API
# has no auth or cookies, but confirm it is intended.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
async def solve_puzzle(json):
    """Solve one crossword puzzle, running the blocking work in worker threads.

    Args:
        json: The decoded request body. It is passed unchanged to
            ``json_CA_json_converter``, which defines its expected schema.

    Returns:
        A 4-tuple ``(first_grid, evaluation1, second_grid, evaluation2)``.
        If anything goes wrong, a 4-tuple of ``None`` is returned instead, so
        the result has the same shape on both paths.
    """
    try:
        puzzle = json_CA_json_converter(json, False)
        crossword = Crossword(puzzle)
        # Building the solver and solving are both blocking, so they run via
        # asyncio.to_thread to keep the event loop free to serve other requests.
        solver = await asyncio.to_thread(
            BPSolver,
            crossword,
            model_path=choosen_model_path,
            ans_tsv_path=ans_list_path,
            dense_embd_path=dense_embedding_path,
            reranker_path=second_pass_model_path,
            max_candidates=40000,
            model_type='distilbert',
        )
        solution = await asyncio.to_thread(
            solver.solve, num_iters=60, iterative_improvement_steps=0
        )
        first_grid = solution['first pass model']['grid']
        evaluation1 = await asyncio.to_thread(solver.evaluate1, first_grid)
        # NOTE(review): the latest commit message says the second pass was
        # removed. If solve() no longer emits 'second pass model', this lookup
        # raises KeyError and every job ends up reported as failed. Confirm.
        second_grid = solution['second pass model']['final grid']
        evaluation2 = await asyncio.to_thread(solver.evaluate1, second_grid)
        return first_grid, evaluation1, second_grid, evaluation2
    except Exception as e:
        # Top-level boundary for a single job: log the error with its
        # traceback and report failure instead of propagating.
        print(f"An error occurred: {e}")
        traceback.print_exc()
        return None, None, None, None
# Pending jobs, consumed in FIFO order by worker(). Each item is a tuple
# (job_id, coroutine_function, args, future).
fifo_queue = asyncio.Queue()
# Per-job bookkeeping keyed by job id: {"status": ..., "result": ...}.
# NOTE(review): entries are never removed, so this grows for the life of the process.
jobs = {}


async def worker():
    """Run queued jobs one at a time, forever, recording each outcome in ``jobs``.

    For each item, the job's status is set to "processing", the job is awaited,
    and its result is stored. The status then becomes "completed" when the
    result's second element is truthy, and "failed" otherwise.

    A job that raises is recorded as "failed" with result ``None`` instead of
    ending this task. Otherwise every job queued behind it would be stranded.
    The item's future is resolved with the job id, and ``task_done()`` is
    always called, so ``fifo_queue.join()`` works.
    """
    while True:
        job_id, job, args, future = await fifo_queue.get()
        print(f"Worker got a job: (size of remaining queue: {fifo_queue.qsize()})")
        try:
            jobs[job_id]["status"] = "processing"
            try:
                result = await job(*args)
            except Exception:
                traceback.print_exc()
                result = None
            print(result)
            jobs[job_id]["result"] = result
            # Guards against None or short results, which would otherwise
            # raise here and kill the worker.
            try:
                succeeded = bool(result[1])
            except (TypeError, IndexError):
                succeeded = False
            jobs[job_id]["status"] = "completed" if succeeded else "failed"
            if not future.done():
                future.set_result(job_id)
        finally:
            fifo_queue.task_done()
@app.on_event("startup")
async def on_start_up():
    """Start the background queue worker when the application starts.

    The task is stored on ``app.state``: the event loop keeps only weak
    references to tasks, so an unreferenced task may be garbage-collected
    while it is still running.
    """
    # NOTE(review): newer FastAPI versions deprecate on_event in favour of
    # lifespan handlers. Consider migrating.
    app.state.worker_task = asyncio.create_task(worker())
# Job ids come from a process-wide counter rather than id(): CPython may reuse
# an object's id once it is garbage-collected, which would let a new job
# overwrite the entry of an already-finished job in `jobs`.
_job_id_counter = itertools.count(1)


@app.post("/solve")
async def solve(request: Request):
    """Enqueue the posted puzzle for solving and return its job id immediately.

    The JSON body is handed to ``solve_puzzle`` by the background worker.
    Poll ``GET /result/{job_id}`` to get the outcome.
    """
    payload = await request.json()
    future = asyncio.get_running_loop().create_future()
    job_id = next(_job_id_counter)
    jobs[job_id] = {"status": "queued"}
    await fifo_queue.put((job_id, solve_puzzle, [payload], future))
    return {"job_id": job_id}
@app.get("/result/{job_id}")
async def get_result(job_id: int):
    """Return a copy of the bookkeeping entry for ``job_id``.

    For a job that is still queued, a ``queue_status`` entry is added with its
    1-based position (``index``) and the current queue length (``length``).
    An unknown id yields ``{"error": ...}``.
    """
    job = jobs.get(job_id)
    if job is None:
        # NOTE(review): finished jobs are never removed from `jobs`, so
        # "or completed" in this message does not currently happen.
        return {"error": "Job not found or completed"}
    response = dict(job)
    if job["status"] == "queued":
        queue_size = fifo_queue.qsize()
        # asyncio.Queue has no public way to inspect pending items, so this
        # reads the private `_queue` deque. Nothing is awaited during the
        # scan, so the queue cannot change underneath it.
        for position, (queued_job_id, _, _, _) in enumerate(fifo_queue._queue, start=1):
            if queued_job_id == job_id:
                response["queue_status"] = {"index": position, "length": queue_size}
                break
    return response
@app.get("/")
async def home():
    """Health check: always reports success with the message "Pong"."""
    payload = dict(Success="True", Message="Pong")
    return payload