# pragmatic-ft-listener / handler.py
# Last commit 1b4093a by fried-nlp: "Fix escaping" (2.11 kB)
from transformers import pipeline
from greenery import parse
from greenery.parse import NoMatch
from listener import Listener, ListenerOutput
import time
import json
import torch
class EndpointHandler:
    """Request handler that synthesizes greenery regexes from labelled examples.

    Presumably a Hugging Face Inference Endpoints custom handler (constructed
    with a model ``path``, then called with the request payload) -- confirm
    against the deployment. Each request supplies a specification of
    ``(string, label)`` examples plus the ground-truth program. The handler
    samples candidate programs from a ``Listener``, ranks the consistent ones by
    score, and reports whether the best guess (and any of the top 10) is
    equivalent to the ground truth according to greenery.
    """

    def __init__(self, path=""):
        """Build the listener used for synthesis.

        Args:
            path: Model location, forwarded unchanged to ``Listener``.
        """
        # Sampling settings handed to the Listener. These look like Hugging Face
        # ``generate`` kwargs (nucleus sampling, 500 samples per spec) -- TODO
        # confirm how Listener consumes them.
        self.listener = Listener(
            path,
            {
                "do_sample": True,
                "max_new_tokens": 128,
                "top_p": 0.9,
                "num_return_sequences": 500,
                "num_beams": 1,
            },
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

    @staticmethod
    def _parse_regex(program):
        """Parse *program* with greenery after removing every backslash."""
        # Both the guesses and the true program are stripped the same way before
        # parsing; presumably they carry escaping greenery's parser does not
        # accept -- NOTE(review): confirm stripping cannot change the language.
        return parse(program.replace('\\', ''))

    @classmethod
    def _is_equivalent(cls, program, true_pattern):
        """Return True if *program* parses and is equivalent to *true_pattern*.

        A candidate that greenery cannot parse counts as a miss instead of
        aborting the whole request.
        """
        try:
            guess = cls._parse_regex(program)
        except NoMatch:
            return False
        return guess.equivalent(true_pattern)

    def __call__(self, data):
        """Synthesize programs for one request and score them against the truth.

        Args:
            data: Request payload. ``data["inputs"]`` must be a dict with
                ``"spec"`` (a list of dicts with ``"string"`` and ``"label"``
                keys) and ``"true_program"`` (a regex string). The
                ``"inputs"`` key is popped from ``data``.

        Returns:
            A dict with:
            ``guess`` / ``top_1_score``: the highest-scoring consistent program
                and its score (None when there are no candidates);
            ``top_1_success``: whether that guess is equivalent to the truth;
            ``top_10_guesses`` / ``top_10_scores``: up to 10 best candidates and
                their scores (None when there are no candidates);
            ``top_10_success``: whether any of them is equivalent to the truth;
            ``time``: seconds spent in synthesis and ranking.

        Raises:
            ValueError: if ``data`` has no ``"inputs"`` entry.
            Whatever greenery's ``parse`` raises (e.g. ``NoMatch``) if
            ``true_program`` cannot be parsed while there are candidates to
            compare against.
        """
        inp = data.pop("inputs", None)
        if inp is None:
            raise ValueError("request payload must contain an 'inputs' field")
        spec = inp["spec"]
        true_program = inp["true_program"]

        start = time.time()
        # A batch containing a single specification of (string, label) pairs.
        outputs = self.listener.synthesize(
            [[(s["string"], s["label"]) for s in spec]], return_scores=True
        )
        # ``idx[0]`` presumably indexes the sampled programs that are consistent
        # with the spec -- TODO confirm against Listener.
        scored_programs = [
            (outputs.decoded_scores[0][i], outputs.decoded[0][i])
            for i in outputs.idx[0]
        ]
        # Drop duplicate (score, program) pairs, then order best score first.
        sorted_programs = sorted(
            set(scored_programs), reverse=True, key=lambda x: x[0]
        )
        end = time.time()

        top_guess = None
        top_score = None
        top_success = False
        top_10_guesses = None
        top_10_scores = None
        top_10_success = False
        if sorted_programs:
            top_10 = sorted_programs[:10]
            top_10_scores = [score for score, _ in top_10]
            top_10_guesses = [program for _, program in top_10]
            top_score, top_guess = top_10[0]
            # Parse the reference once instead of once per candidate.
            true_pattern = self._parse_regex(true_program)
            top_success = self._is_equivalent(top_guess, true_pattern)
            # The top guess is the first of the top 10, so reuse its result.
            top_10_success = top_success or any(
                self._is_equivalent(p, true_pattern) for p in top_10_guesses[1:]
            )

        return {
            "guess": top_guess,
            "top_1_success": top_success,
            "top_1_score": top_score,
            "top_10_guesses": top_10_guesses,
            "top_10_scores": top_10_scores,
            "top_10_success": top_10_success,
            "time": end - start,
        }