import logging
from typing import Any, Dict, List

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import (
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
)
from transformers_gad.generation.logits_process import (
    GrammarAlignedOracleLogitsProcessor,
)
from transformers_gad.grammar_utils import IncrementalGrammarConstraint

logger = logging.getLogger(__name__)


class EndpointHandler():
    """Callable request handler that runs grammar-constrained sampling.

    The tokenizer and causal language model are loaded once at construction
    time; each call builds a fresh grammar constraint from the request and
    samples a completion with it.
    """

    # Upper bound on the number of tokens sampled per request.
    MAX_NEW_TOKENS = 512
    # Value forwarded to ``generate(max_time=...)`` for each request.
    MAX_TIME = 30

    def __init__(self, path=""):
        """Load the tokenizer and model stored at ``path``."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """Generate text for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt string and
                ``data["grammar"]`` is the grammar source whose start rule is
                named ``root``.

        Returns:
            The decoded continuations (prompt tokens stripped), one string per
            generated sequence.

        Raises:
            ValueError: If the prompt or the grammar is missing, empty, or not
                a string.
        """
        # Fall back to the payload itself when "inputs" is absent (kept from
        # the original behaviour), then make sure we really have a prompt.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, str) or not inputs:
            raise ValueError('"inputs" must be a non-empty string prompt')

        grammar_str = data.get("grammar", "")
        if not isinstance(grammar_str, str) or not grammar_str.strip():
            raise ValueError('"grammar" must be a non-empty grammar string')
        logger.debug("grammar: %s", grammar_str)

        grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)

        # Inf/NaN scrubbing runs first so the grammar processor sees clean logits.
        gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
        logits_processors = LogitsProcessorList([
            InfNanRemoveLogitsProcessor(),
            gad_oracle_processor,
        ])

        encoded = self.tokenizer(
            [inputs], add_special_tokens=False, return_tensors="pt"
        )
        # Keep the prompt on the same device as the model weights; this is a
        # no-op when the model lives on CPU.
        input_ids = encoded["input_ids"].to(self.model.device)

        output = self.model.generate(
            input_ids,
            do_sample=True,
            max_time=self.MAX_TIME,
            max_new_tokens=self.MAX_NEW_TOKENS,
            logits_processor=logits_processors,
        )
        # NOTE(review): presumably clears per-generation state held by the
        # processor; confirm against transformers_gad.
        gad_oracle_processor.reset()

        # Decoder-only output echoes the prompt, so drop its tokens; for an
        # encoder-decoder model only the first output token is dropped.
        if self.model.config.is_encoder_decoder:
            input_length = 1
        else:
            input_length = input_ids.shape[1]

        # generate() may return a plain tensor or an object carrying the
        # sequences in a ``sequences`` attribute.
        sequences = output.sequences if hasattr(output, "sequences") else output
        generated_tokens = sequences[:, input_length:]

        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)