|
from torch_grammar import GrammarSampler |
|
from transformers.generation.logits_process import LogitsProcessor |
|
|
|
from modules import shared |
|
|
|
# Module-level cache shared by every GrammarLogitsProcessor instance.
# Keeping it global lets a new processor reuse the sampler that was built
# for the previous one when the grammar text has not changed.

# Raw grammar text the cached sampler was built from ('' = nothing cached).
grammar_string = ''

# Sampler built from `grammar_string`, or None when no grammar is active.
sampler = None

# Callable logits processor obtained from `sampler`, or None when inactive.
grammar = None
|
|
|
|
|
class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that filters scores through a grammar sampler.

    The sampler is cached in the module-level globals ``sampler`` and
    ``grammar_string`` and is rebuilt only when the grammar text passed to
    the constructor differs from the cached text. The active processor is
    kept in the module-level global ``grammar``, so the most recently
    constructed instance decides which grammar every instance applies.
    """

    def __init__(self, string):
        """Select the grammar that subsequent calls will apply.

        Args:
            string: Grammar source text. Text that is empty or consists
                only of whitespace disables grammar filtering.

        Raises:
            Any exception raised while constructing ``GrammarSampler`` is
            propagated. In that case the cached sampler and its cache key
            are left exactly as they were before the call.
        """
        global sampler, grammar, grammar_string

        if string != grammar_string:
            source = string.strip()

            # Build the replacement BEFORE touching the cache key. If the
            # key were updated first and construction raised, the key would
            # name the new text while `sampler` still held the sampler of
            # the previous grammar; a later instantiation with the same
            # text would then skip the rebuild and silently apply the
            # wrong grammar.
            if source != '':
                # The grammar text is passed with exactly one trailing
                # newline and 'root' as the start rule name.
                rebuilt = GrammarSampler(source + '\n', 'root', shared.tokenizer)
            else:
                rebuilt = None

            # Commit the sampler and its key together so they always agree.
            sampler = rebuilt
            grammar_string = string

        # A new processor is requested on every instantiation, even when
        # the sampler itself is reused from the cache.
        # NOTE(review): presumably this resets per-generation parse state;
        # confirm against torch_grammar.
        grammar = sampler.logits_processor() if sampler is not None else None

    def __call__(self, input_ids, scores):
        """Return ``scores`` filtered by the active grammar, if any.

        Args:
            input_ids: Forwarded unchanged to the grammar processor.
            scores: Scores to filter.

        Returns:
            The grammar processor's result, or ``scores`` unchanged when no
            grammar is active.
        """
        if grammar is None:
            return scores

        return grammar(input_ids, scores)
|
|