import json
import time

import torch
from greenery import parse
from greenery.parse import NoMatch
from transformers import pipeline

from listener import Listener, ListenerOutput


class EndpointHandler:
    """Endpoint handler that synthesizes programs from labelled examples.

    A request carries a ``spec`` (labelled example strings) and a
    ``true_program``.  The handler asks the ``Listener`` for candidate
    programs, ranks the ones the listener reports as consistent, and checks
    the best candidates for equivalence with ``true_program`` using
    ``greenery``.
    """

    def __init__(self, path=""):
        """Create the underlying ``Listener``.

        Args:
            path: Passed through as the first argument to ``Listener``
                (presumably the model location -- confirm against
                ``listener.Listener``).
        """
        self.listener = Listener(
            path,
            {
                "do_sample": True,
                "max_new_tokens": 128,
                "top_p": 0.9,
                "num_return_sequences": 500,
                "num_beams": 1,
            },
            # Use the GPU when one is visible, otherwise fall back to CPU.
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

    @staticmethod
    def _is_equivalent(guess, target):
        """Return whether ``guess`` parses to a pattern equivalent to ``target``.

        Args:
            guess: Candidate program as a string.
            target: Already-parsed greenery pattern for the true program.

        Returns:
            The result of ``equivalent`` on the parsed guess, or ``False``
            when greenery cannot parse the guess at all.
        """
        try:
            # Backslashes are removed before parsing, exactly as is done for
            # the true program, so both sides are compared in the same form.
            candidate = parse(guess.replace('\\', ''))
        except NoMatch:
            # A guess greenery cannot parse cannot be shown equivalent;
            # count it as a miss instead of failing the whole request.
            return False
        return candidate.equivalent(target)

    def __call__(self, data):
        """Synthesize programs for one request and score them.

        Args:
            data: Request payload.  The inputs are read from
                ``data["inputs"]`` (the key is popped); when that key is
                absent the payload itself is used as the inputs.  The inputs
                must provide ``"spec"`` -- an iterable of mappings with
                ``"string"`` and ``"label"`` keys -- and ``"true_program"``.

        Returns:
            A dict with keys ``"guess"``, ``"top_1_success"``,
            ``"top_1_score"``, ``"top_10_guesses"``, ``"top_10_scores"``,
            ``"top_10_success"`` and ``"time"`` (seconds spent in synthesis
            and ranking).  When no consistent program is found the
            guess/score entries are ``None`` and the success flags ``False``.

        Raises:
            KeyError: If ``"spec"`` or ``"true_program"`` is missing.
            NoMatch: If ``true_program`` itself cannot be parsed by greenery
                (only attempted when at least one guess exists).
        """
        # Accept both {"inputs": {...}} and a bare {...} payload; the old
        # ``None`` default crashed with a TypeError on the next line.
        inp = data.pop("inputs", data)
        spec = inp["spec"]
        true_program = inp["true_program"]

        start = time.time()
        outputs = self.listener.synthesize(
            [[(s["string"], s["label"]) for s in spec]],
            return_scores=True,
        )
        # Keep only the candidates whose indices are listed in outputs.idx
        # for this (single) spec, de-duplicated as (score, program) pairs.
        consistent = {
            (outputs.decoded_scores[0][i], outputs.decoded[0][i])
            for i in outputs.idx[0]
        }
        sorted_programs = sorted(consistent, reverse=True, key=lambda x: x[0])
        end = time.time()

        top_guess = None
        top_score = None
        top_success = False
        top_10_guesses = None
        top_10_scores = None
        top_10_success = False

        if sorted_programs:
            # Parse the reference once instead of once per comparison.
            target = parse(true_program.replace('\\', ''))

            best = sorted_programs[:10]
            top_10_scores = [score for score, _ in best]
            top_10_guesses = [program for _, program in best]
            top_score, top_guess = best[0]

            top_success = self._is_equivalent(top_guess, target)
            # The top guess is already checked; only look at the remaining
            # ones, and stop at the first match.
            top_10_success = top_success or any(
                self._is_equivalent(program, target)
                for program in top_10_guesses[1:]
            )

        return {
            "guess": top_guess,
            "top_1_success": top_success,
            "top_1_score": top_score,
            "top_10_guesses": top_10_guesses,
            "top_10_scores": top_10_scores,
            "top_10_success": top_10_success,
            "time": end - start,
        }