from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import (
    LogitsProcessorList,
    InfNanRemoveLogitsProcessor,
)

from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor


def safe_int_cast(value, default):
    """Parse value as an int, falling back to default on None or bad input."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Preload the model and tokenizer once at startup.
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        DTYPE = torch.float32

        self.device = torch.device(DEVICE)

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.to(dtype=DTYPE)
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model = torch.compile(self.model, mode="reduce-overhead", fullgraph=True)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        # Generation defaults; sampling is effectively unfiltered
        # (temperature 1.0, no top-p/top-k truncation, no repetition penalty).
        MAX_NEW_TOKENS = 512
        MAX_TIME = 30
        TEMPERATURE = 1.0
        REPETITION_PENALTY = 1.0
        TOP_P = 1.0
        TOP_K = 0

        inputs = data.get("inputs", data)
        grammar_str = data.get("grammar", "")
        max_new_tokens = safe_int_cast(data.get("max-new-tokens"), MAX_NEW_TOKENS)
        max_time = safe_int_cast(data.get("max-time"), MAX_TIME)

        if grammar_str is None or not grammar_str.strip():
            # No grammar supplied: fall back to unconstrained sampling.
            logits_processors = None
            gad_oracle_processor = None
        else:
            print("=== GOT GRAMMAR ===")
            print(grammar_str)
            print("===================")
            grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)

            # Initialize logits processors for the grammar: strip inf/nan
            # values first, then apply the grammar-aligned oracle constraint.
            gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
            inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
            logits_processors = LogitsProcessorList([
                inf_nan_remove_processor,
                gad_oracle_processor,
            ])

        # input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
        input_ids = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": inputs}],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        input_ids = input_ids.to(self.model.device)

        output = self.model.generate(
            input_ids,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_time=max_time,
            max_new_tokens=max_new_tokens,
            top_p=TOP_P,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            temperature=TEMPERATURE,
            logits_processor=logits_processors,
            num_return_sequences=1,
            return_dict_in_generate=True,
            output_scores=True,
        )

        if gad_oracle_processor is not None:
            # Reset the oracle's state so the next request starts fresh.
            gad_oracle_processor.reset()

        # Detokenize only the newly generated tokens, skipping the prompt.
        input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
        if hasattr(output, "sequences"):
            generated_tokens = output.sequences[:, input_length:]
        else:
            generated_tokens = output[:, input_length:]
        generations = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        return generations
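
# --- Local smoke test: a minimal sketch, not part of the endpoint contract. ---
# Shows the payload shape __call__ expects ("inputs", "grammar",
# "max-new-tokens", "max-time"). MODEL_PATH and the tiny EBNF grammar below
# are hypothetical placeholders for illustration only.
if __name__ == "__main__":
    MODEL_PATH = "path/to/model"  # assumption: point this at a real checkpoint
    handler = EndpointHandler(path=MODEL_PATH)

    # Hypothetical grammar constraining output to "yes" or "no"; the start
    # symbol must be named "root" to match the IncrementalGrammarConstraint
    # call above.
    example_grammar = 'root ::= "yes" | "no"'

    result = handler({
        "inputs": "Answer yes or no: is water wet?",
        "grammar": example_grammar,
        "max-new-tokens": "16",  # strings are fine; safe_int_cast parses them
        "max-time": "10",
    })
    print(result)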