File size: 4,004 Bytes
c68e701
a04b340
 
 
 
 
c68e701
22f665a
a04b340
076da67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a04b340
acb82e9
 
22f665a
 
 
 
 
 
 
1f5168b
c68e701
076da67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8ed164
1f5168b
076da67
 
228f684
 
076da67
23e5d06
c68e701
076da67
 
 
1f5168b
 
c68e701
1f5168b
c68e701
1f5168b
c68e701
 
 
 
 
 
076da67
c68e701
076da67
c68e701
1f5168b
c68e701
 
 
1f5168b
c68e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9eee843
c68e701
 
1f5168b
c68e701
 
 
 
 
 
1f5168b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from fastapi import FastAPI,Request
import os
from Crossword_inf import Crossword
from BPSolver_inf import BPSolver
from Strict_json import json_CA_json_converter

import asyncio
from fastapi.middleware.cors import CORSMiddleware

# File locations of the inference artefacts, keyed by model family name.
MODEL_CONFIG = dict(
    bert=dict(
        MODEL_PATH="./Inference_components/dpr_biencoder_trained_EPOCH_2_COMPLETE.bin",
        ANS_TSV_PATH="./Inference_components/all_answer_list.tsv",
        DENSE_EMBD_PATH="./Inference_components/embeddings_BERT_EPOCH_2_COMPLETE0.pkl",
    ),
    distilbert=dict(
        MODEL_PATH="./Inference_components/distilbert_EPOCHs_7_COMPLETE.bin",
        ANS_TSV_PATH="./Inference_components/all_answer_list.tsv",
        DENSE_EMBD_PATH="./Inference_components/distilbert_7_epochs_embeddings.pkl",
    ),
    t5_small=dict(
        MODEL_PATH='./Inference_components/t5_small_new_dataset_2EPOCHS/',
    ),
)

# The 'distilbert' entry supplies the model, answer-list and embedding paths.
# (The spelling of `choosen_model_path` is kept: other code refers to it.)
choosen_model_path, ans_list_path, dense_embedding_path = (
    MODEL_CONFIG['distilbert'][key]
    for key in ('MODEL_PATH', 'ANS_TSV_PATH', 'DENSE_EMBD_PATH')
)
# The 't5_small' entry supplies the second-pass model path.
second_pass_model_path = MODEL_CONFIG['t5_small']['MODEL_PATH']


# The application object that the route and event decorators in this file
# attach to.
app = FastAPI()

# CORS settings: every origin, method and header is accepted, and credentials
# are allowed.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard values with allow_credentials=True -- confirm this is intended.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_cors_settings)
   
async def solve_puzzle(json):
    """Solve one crossword puzzle and evaluate both solver passes.

    Building the solver, running ``solve`` and running ``evaluate1`` are each
    dispatched through ``asyncio.to_thread``; converting the payload and
    constructing the ``Crossword`` happen inline.

    Args:
        json: Puzzle payload handed to ``json_CA_json_converter``.  (The
            parameter name is kept for compatibility with existing callers.)

    Returns:
        A 4-tuple ``(first_pass_grid, first_pass_evaluation,
        second_pass_grid, second_pass_evaluation)``.  If any step raises, the
        error is printed and all four elements are ``None`` -- the arity is
        the same on success and failure so callers can always unpack four
        values.
    """
    try:
        puzzle = json_CA_json_converter(json, False)
        crossword = Crossword(puzzle)

        # Constructing the solver is done off the event loop thread.
        solver = await asyncio.to_thread(
            BPSolver,
            crossword,
            model_path=choosen_model_path,
            ans_tsv_path=ans_list_path,
            dense_embd_path=dense_embedding_path,
            reranker_path=second_pass_model_path,
            max_candidates=40000,
            model_type='distilbert',
        )

        solution = await asyncio.to_thread(
            solver.solve, num_iters=60, iterative_improvement_steps=0
        )

        first_pass_grid = solution['first pass model']['grid']
        evaluation1 = await asyncio.to_thread(solver.evaluate1, first_pass_grid)
        second_pass_grid = solution['second pass model']['final grid']
        evaluation2 = await asyncio.to_thread(solver.evaluate1, second_pass_grid)
    except Exception as e:
        # Best-effort boundary: report the problem and signal failure with an
        # all-None tuple of the same length as the success value.
        print(f"An error occurred: {e}")
        return None, None, None, None

    return first_pass_grid, evaluation1, second_pass_grid, evaluation2


# Pending jobs in submission order; each entry is a
# (job_id, coroutine function, argument list, future) tuple.
fifo_queue = asyncio.Queue()

# job_id -> record dict holding "status" and, once finished, "result".
jobs = {}


async def worker():
    """Consume queued jobs one at a time, forever.

    For every entry taken from ``fifo_queue`` the job's record in ``jobs`` is
    marked "processing", the job coroutine is awaited, and the record is then
    updated with the result and a final status: "completed" when ``result[1]``
    is truthy, otherwise "failed".

    A job that raises (or returns something that cannot be indexed) is marked
    "failed" instead of terminating the worker, so later jobs are still
    processed.  The job's future is always resolved with the job id.
    """
    while True:
        job_id, job, args, future = await fifo_queue.get()
        # Logged only after a job has actually been dequeued.
        print(f"Worker got a job: (size of remaining queue: {fifo_queue.qsize()})")
        try:
            jobs[job_id]["status"] = "processing"
            result = await job(*args)
            print(result)
            jobs[job_id]["result"] = result
            jobs[job_id]["status"] = "completed" if result[1] else "failed"
        except Exception as exc:
            # Cancellation is not an Exception subclass, so shutting the
            # worker down still works; everything else must not kill the loop.
            print(f"Job {job_id} failed: {exc}")
            jobs.setdefault(job_id, {})["status"] = "failed"
        finally:
            if not future.done():
                future.set_result(job_id)

@app.on_event("startup")
async def on_start_up():
    """Start the background queue worker when the application starts."""
    # Keep a strong reference to the task: the event loop itself only holds a
    # weak one, so an unreferenced task may be garbage-collected while running.
    app.state.worker_task = asyncio.create_task(worker())

@app.post("/solve")
async def solve(request: Request):
    """Queue a puzzle-solving job and return its id without waiting for it.

    The request body is parsed as JSON and queued together with
    ``solve_puzzle``; the job's record in ``jobs`` starts as "queued".

    Returns:
        ``{"job_id": <int>}`` -- the id to poll via ``/result/{job_id}``.
    """
    payload = await request.json()
    future = asyncio.get_running_loop().create_future()
    # id(future) is not a safe job id: once a finished job's future is freed
    # its id can be handed to a new future, overwriting the old record that is
    # still kept in `jobs`.  A monotonically increasing integer never collides.
    # There is no await between choosing the id and registering it.
    job_id = max(jobs, default=0) + 1
    jobs[job_id] = {"status": "queued"}
    await fifo_queue.put((job_id, solve_puzzle, [payload], future))
    return {"job_id": job_id}

@app.get("/result/{job_id}")
async def get_result(job_id: int):
    """Report the stored record for a job.

    Returns a copy of the job's record from ``jobs``.  While the job is still
    "queued" and present in ``fifo_queue``, a "queue_status" entry is added
    with its 1-based position and the current queue length.  Unknown ids get
    an ``{"error": ...}`` payload instead.
    """
    if job_id not in jobs:
        return {"error": "Job not found or completed"}

    record = jobs[job_id]
    response = dict(record)
    if record["status"] == "queued":
        pending_total = fifo_queue.qsize()
        # asyncio.Queue exposes no public way to inspect its contents, hence
        # the read of the private `_queue` attribute.
        for position, (queued_id, _job, _args, _future) in enumerate(
            fifo_queue._queue, start=1
        ):
            if queued_id == job_id:
                response["queue_status"] = {"index": position, "length": pending_total}
    return response

@app.get("/")
async def home():
    """Root endpoint: always answers with a fixed "Pong" payload."""
    pong = dict(Success="True", Message="Pong")
    return pong