"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" # flake8: noqa from pathlib import Path import sys from ctypes import * # type: ignore from enum import Enum from itertools import islice, groupby from typing import ( Any, Callable, Dict, Set, Generic, List, Optional, OrderedDict, TextIO, Tuple, TypeVar, Union, overload, ) import llama_cpp.llama_cpp as llama_cpp # Type aliases llama_grammar_element = llama_cpp.llama_grammar_element llama_grammar_element_p = llama_cpp.llama_grammar_element_p llama_grammar_p = llama_cpp.llama_grammar_p # Type variables Ptr = TypeVar("Ptr", bound="const_char_p") T = TypeVar("T") U = TypeVar("U") V = TypeVar("V") W = TypeVar("W") class Sentinel: """Used to mark the end of a iterator of std::vector & std::map.""" class LlamaGrammar: """Keeps reference counts of all the arguments, so that they are not garbage collected by Python.""" def __del__(self) -> None: """Free the grammar pointer when the object is deleted.""" if self.grammar is not None: llama_cpp.llama_grammar_free(self.grammar) self.grammar = None def __init__( self, parsed_grammar: "parse_state", ) -> None: """Initialize the grammar pointer from the parsed state.""" self._grammar_rules = ( parsed_grammar.c_rules() ) # type: std.vector[std.vector[LlamaGrammarElement]] self._n_rules = self._grammar_rules.size() # type: int self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int self.init() @classmethod def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": """Convert a GBNF grammar to a Llama grammar.""" parsed_grammar = parse(const_char_p(grammar)) # type: parse_state if parsed_grammar.rules.empty(): raise ValueError( f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" ) if verbose: print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) print_grammar(sys.stderr, parsed_grammar) print(file=sys.stderr) return cls(parsed_grammar) @classmethod def from_json_schema( cls, json_schema: str, verbose: bool = True, ) -> "LlamaGrammar": """Convert a JSON schema to a Llama grammar.""" return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose) @classmethod def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": try: with open(file) as f: grammar = f.read() except Exception as err: raise Exception( f"{cls.from_file.__name__}: error reading grammar file: {err}" ) if grammar: return cls.from_string(grammar, verbose=verbose) raise ValueError( f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) def init(self) -> None: # Step 1: Convert LlamaGrammarElement to llama_grammar_element self._element_lists = [ [ llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value)) for elem in subvector ] for subvector in self._grammar_rules ] # type: List[List[llama_grammar_element]] # Step 2: Convert each list to llama_grammar_element array and get pointer self._element_arrays = [ (llama_grammar_element * len(sublist))(*sublist) for sublist in self._element_lists ] # type: List[Array[llama_grammar_element]] # Step 3: Get pointer of each array self._element_array_pointers = [ cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays ] # type: List[llama_grammar_element_p] # Step 4: Make array of these pointers and get its pointer self._rules = (llama_grammar_element_p * len(self._element_array_pointers))( *self._element_array_pointers ) self.grammar = llama_cpp.llama_grammar_init( self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index) ) def reset(self) -> None: if self.grammar is not None: llama_cpp.llama_grammar_free(self.grammar) self.init() class LlamaGrammarElement: def __init__(self, type: "llama_gretype", value: int): self.type = type self.value = value # Unicode code point or rule ID class const_char_p: """C++ implementation of const char *.""" def __init__(self, value: Union[str, Ptr], move: Optional[int] = None): if isinstance(value, const_char_p): # We're copying an existing const_char_p self.value = value.value self.pos = value.pos + (move or 0) return # We're creating a new const_char_p self.value = value self.pos = move or 0 def __str__(self) -> str: assert self.value is not None, "null pointer" return self.value[self.pos :] def __getitem__(self, index: int) -> str: value = str(self) return value[index] if index < len(value) else "" @overload def __add__(self: Ptr, other: int) -> Ptr: ... @overload def __add__(self: Ptr, other: Ptr) -> int: ... def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: return ( self.__class__(self.value, self.pos + other) if isinstance(other, int) else self.pos + other.pos ) @overload def __sub__(self: Ptr, other: int) -> Ptr: ... @overload def __sub__(self: Ptr, other: Ptr) -> int: ... def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: return ( self.__class__(self.value, self.pos - other) if isinstance(other, int) else self.pos - other.pos ) def __eq__(self: Ptr, other: Ptr) -> bool: assert self.value == other.value, "comparing pointers from different strings" return self.pos == other.pos def __lt__(self: Ptr, other: Ptr) -> bool: assert self.value == other.value, "comparing pointers from different strings" return self.pos < other.pos def __gt__(self: Ptr, other: Ptr) -> bool: assert self.value == other.value, "comparing pointers from different strings" return self.pos > other.pos class std: @staticmethod def string(ptr: const_char_p, length: Optional[int] = None) -> str: """C++ implementation of std::string constructor.""" value = str(ptr) if length is not None: value = value[:length] return value class vector(Generic[T], List[T]): """C++ implementation of std::vector.""" class iterator: def __init__(self, vector: "std.vector[T]", index: int): self._vector = vector self._index = index self._version = vector._version def _check_version(self): if self._version != self._vector._version: raise RuntimeError("Iterator used after vector was modified.") def __iter__(self): return self def __next__(self) -> T: self._check_version() if self._index >= self._vector.size(): raise StopIteration value = self._vector[self._index] self._index += 1 return value def __add__(self, value: int) -> "std.vector[T].iterator": return self.__class__(self._vector, self._index + value) def __sub__(self, value: int) -> "std.vector[T].iterator": return self.__class__(self._vector, self._index - value) def __init__(self): self._version = 0 def modify(self): # This is a bit of a hack to make sure iterators are invalidated self._version += 1 def push_back(self, value: T) -> None: self.modify() self.append(value) def pop_back(self) -> None: self.modify() if not self.empty(): self.pop() def back(self) -> T: return self[-1] def size(self) -> int: return len(self) def clear(self) -> None: self.modify() super().clear() def empty(self) -> bool: return self.size() == 0 def data(self) -> "std.vector[T]": return self def resize( self, new_size: int, fill_value_factory: Optional[Callable[[], T]] = None, ) -> None: if new_size > self.size(): if fill_value_factory is None: raise ValueError("A fill value factory function must be provided.") self.reserve(new_size, fill_value_factory) elif new_size < self.size(): self[:] = self[:new_size] def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None: if capacity > self.size(): fill_value = fill_value_factory() self.extend([fill_value] * (capacity - self.size())) def front(self) -> T: if not self.empty(): return self[0] else: raise IndexError("Vector is empty.") def assign(self, count: int, value: T) -> None: self.clear() self.extend([value] * count) def insert( self, pos: "std.vector[T].iterator", first: "std.vector[T].iterator", last: "std.vector[T].iterator", ) -> None: self[pos._index : pos._index] = list( islice(first._vector, first._index, last._index) ) def begin(self) -> "std.vector[T].iterator": return self.iterator(self, 0) def end(self) -> "std.vector[T].iterator": return self.iterator(self, self.size()) class map(Generic[T, U], OrderedDict[T, U]): """C++ implementation of std::map.""" class iterator(Generic[V, W]): def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]): self._map = _map self.iter = iter(_map) self.key = key self._advance() def _sanitize_key(self) -> T: if isinstance(self.key, Sentinel): raise StopIteration return self.key def _advance(self) -> None: try: while next(self.iter) != self.key: pass except StopIteration: self.key = Sentinel() def __next__(self) -> Tuple[T, U]: key = self._sanitize_key() if key in self._map: value = self._map[key] self._advance() return key, value else: raise StopIteration def get(self) -> Tuple[T, U]: key = self._sanitize_key() return key, self._map[key] @property def first(self) -> T: return self._sanitize_key() @property def second(self) -> U: return self._map[self._sanitize_key()] def insert( self, key: T, value: U ) -> Tuple["std.map[T, U].iterator[T, U]", bool]: if key in self: return self.iterator(self, key), False else: self[key] = value return self.iterator(self, key), True def find(self, key: T) -> "std.map[T, U].iterator[T, U]": if key in self: return self.iterator(self, key) else: return self.end() def at(self, key: T) -> U: if key in self: return self[key] else: raise KeyError("The provided key is not found in the map.") def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None: key = iterator.first if key in self: del self[key] def size(self) -> int: return len(self) def empty(self) -> bool: return self.size() == 0 def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]": try: keys = sorted(list(self.keys())) # type: ignore for k in keys: if k >= key: return self.iterator(self, k) raise ValueError("No key found that is not less than the input key") except TypeError: raise TypeError("Keys of type T cannot be sorted.") def begin(self) -> "std.map[T, U].iterator[T, U]": return self.iterator(self, next(iter(self))) def end(self) -> "std.map[T, U].iterator[T, U]": return self.iterator(self, Sentinel()) # // grammar element type # enum llama_gretype { # // end of rule definition # LLAMA_GRETYPE_END = 0, # // start of alternate definition for rule # LLAMA_GRETYPE_ALT = 1, # // non-terminal element: reference to rule # LLAMA_GRETYPE_RULE_REF = 2, # // terminal element: character (code point) # LLAMA_GRETYPE_CHAR = 3, # // inverse char(s) ([^a], [^a-b] [^abc]) # LLAMA_GRETYPE_CHAR_NOT = 4, # // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to # // be an inclusive range ([a-z]) # LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, # // modifies a preceding LLAMA_GRETYPE_CHAR or # // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) # LLAMA_GRETYPE_CHAR_ALT = 6, # }; class llama_gretype(Enum): """grammar element type""" LLAMA_GRETYPE_END = 0 # end of rule definition LLAMA_GRETYPE_ALT = 1 # start of alternate definition for rule LLAMA_GRETYPE_RULE_REF = 2 # non-terminal element: reference to rule LLAMA_GRETYPE_CHAR = 3 # terminal element: character (code point) LLAMA_GRETYPE_CHAR_NOT = 4 # inverse char(s) ([^a], [^a-b] [^abc]) LLAMA_GRETYPE_CHAR_RNG_UPPER = 5 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z]) LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) # struct parse_state { # std::map symbol_ids; # std::vector> rules; # std::vector c_rules(); # }; class parse_state: def __init__(self): self.symbol_ids: std.map[str, int] = std.map() self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector() # std::vector parse_state::c_rules() { # std::vector ret; # for (const auto & rule : rules) { # ret.push_back(rule.data()); # } # return ret; # } def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]: ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]] for rule in self.rules: ret.push_back(rule.data()) return ret def __repr__(self) -> str: return ( f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" ) # struct llama_grammar { # const std::vector> rules; # std::vector> stacks; # }; # class llama_grammar: # def __init__( # self, # rules: std.vector[std.vector[llama_grammar_element]], # stacks: std.vector[std.vector[llama_grammar_element]], # ): # self.rules = rules # self.stacks = stacks # uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { # uint32_t next_id = static_cast(state.symbol_ids.size()); # auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); # return result.first->second; # } def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int: next_id = state.symbol_ids.size() # type: int result = state.symbol_ids.insert(std.string(src, len), next_id) return result[0].second # type: ignore # uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { # uint32_t next_id = static_cast(state.symbol_ids.size()); # state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; # return next_id; # } def generate_symbol_id(state: parse_state, base_name: str) -> int: next_id = state.symbol_ids.size() # type: int state.symbol_ids[base_name + "_" + str(next_id)] = next_id return next_id # void add_rule( # parse_state & state, # uint32_t rule_id, # const std::vector & rule) { # if (state.rules.size() <= rule_id) { # state.rules.resize(rule_id + 1); # } # state.rules[rule_id] = rule; # } def add_rule( state: parse_state, rule_id: int, rule: std.vector[LlamaGrammarElement], ) -> None: if state.rules.size() <= rule_id: state.rules.resize( rule_id + 1, fill_value_factory=std.vector[LlamaGrammarElement], ) state.rules[rule_id] = rule # std::pair decode_utf8(const char * src) { # static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; # uint8_t first_byte = static_cast(*src); # uint8_t highbits = first_byte >> 4; # int len = lookup[highbits]; # uint8_t mask = (1 << (8 - len)) - 1; # uint32_t value = first_byte & mask; # const char * end = src + len; // may overrun! # const char * pos = src + 1; # for ( ; pos < end && *pos; pos++) { # value = (value << 6) + (static_cast(*pos) & 0x3F); # } # return std::make_pair(value, pos); # } def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: """Decodes a UTF-8 character from the source string.""" # Get the codepoint of the first character value = ord(src[0]) # Move the pointer ahead one character pos = src + 1 return value, pos # bool is_word_char(char c) { # return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); # } def is_word_char(c: str) -> bool: return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") # std::pair parse_hex(const char * src, int size) { # const char * pos = src; # const char * end = src + size; # uint32_t value = 0; # for ( ; pos < end && *pos; pos++) { # value <<= 4; # char c = *pos; # if ('a' <= c && c <= 'f') { # value += c - 'a' + 10; # } else if ('A' <= c && c <= 'F') { # value += c - 'A' + 10; # } else if ('0' <= c && c <= '9') { # value += c - '0'; # } else { # break; # } # } # if (pos != end) { # throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); # } # return std::make_pair(value, pos); # } def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: pos = const_char_p(src) # type: const_char_p end = src + size # type: const_char_p value = 0 # type: int while pos < end and pos[0]: value <<= 4 c = pos[0] # type: str if "a" <= c <= "f": value += ord(c) - ord("a") + 10 elif "A" <= c <= "F": value += ord(c) - ord("A") + 10 elif "0" <= c <= "9": value += ord(c) - ord("0") else: break pos += 1 if pos != end: raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src)) return (value, pos) # std::pair parse_char(const char * src) { # if (*src == '\\') { # switch (src[1]) { # case 'x': return parse_hex(src + 2, 2); # case 'u': return parse_hex(src + 2, 4); # case 'U': return parse_hex(src + 2, 8); # case 't': return std::make_pair('\t', src + 2); # case 'r': return std::make_pair('\r', src + 2); # case 'n': return std::make_pair('\n', src + 2); # case '\\': # case '"': # case '[': # case ']': # return std::make_pair(src[1], src + 2); # default: # throw std::runtime_error(std::string("unknown escape at ") + src); # } # } else if (*src) { # return decode_utf8(src); # } # throw std::runtime_error("unexpected end of input"); # } def parse_char(src: const_char_p) -> Tuple[int, const_char_p]: if src[0] == "\\": case = src[1] # type: str if case == "x": return parse_hex(src + 2, 2) elif case == "u": return parse_hex(src + 2, 4) elif case == "U": return parse_hex(src + 2, 8) elif case == "t": return (ord("\t"), src + 2) # implicit cast elif case == "r": return (ord("\r"), src + 2) # implicit cast elif case == "n": return (ord("\n"), src + 2) # implicit cast elif case in ("\\", '"', "[", "]"): return (ord(case), src + 2) # implicit cast else: raise RuntimeError("unknown escape at " + str(src)) elif src[0]: return decode_utf8(src) else: raise RuntimeError("unexpected end of input") # const char * parse_name(const char * src) { # const char * pos = src; # while (is_word_char(*pos)) { # pos++; # } # if (pos == src) { # throw std::runtime_error(std::string("expecting name at ") + src); # } # return pos; # } def parse_name(src: const_char_p) -> const_char_p: pos = const_char_p(src) # type: const_char_p while is_word_char(pos[0]): pos += 1 if pos == src: raise RuntimeError("expecting name at " + str(src)) return pos # const char * parse_space(const char * src, bool newline_ok) { # const char * pos = src; # while (*pos == ' ' || *pos == '\t' || *pos == '#' || # (newline_ok && (*pos == '\r' || *pos == '\n'))) { # if (*pos == '#') { # while (*pos && *pos != '\r' && *pos != '\n') { # pos++; # } # } else { # pos++; # } # } # return pos; # } def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: pos = const_char_p(src) # type: const_char_p while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")): if pos[0] == "#": while pos[0] is not None and pos[0] not in ("\r", "\n"): pos += 1 else: pos += 1 return pos # const char * parse_sequence( # parse_state & state, # const char * src, # const std::string & rule_name, # std::vector & out_elements, # bool is_nested) { def parse_sequence( state: parse_state, src: const_char_p, rule_name: str, out_elements: std.vector[LlamaGrammarElement], is_nested: bool, ) -> const_char_p: # size_t last_sym_start = out_elements.size(); # const char * pos = src; last_sym_start = out_elements.size() # type: int pos = const_char_p(src) # type: const_char_p # while (*pos) { while pos[0]: # if (*pos == '"') { // literal string # pos++; # last_sym_start = out_elements.size(); # while (*pos != '"') { # auto char_pair = parse_char(pos); # pos = char_pair.second; # out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); # } # pos = parse_space(pos + 1, is_nested); if pos[0] == '"': # literal string pos += 1 last_sym_start = out_elements.size() while pos[0] != '"': char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] out_elements.push_back( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0]) ) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '[') { // char range(s) # pos++; # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; elif pos[0] == "[": # char range(s) pos += 1 start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype # if (*pos == '^') { # pos++; # start_type = LLAMA_GRETYPE_CHAR_NOT; # } # last_sym_start = out_elements.size(); if pos[0] == "^": pos += 1 start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT last_sym_start = out_elements.size() # while (*pos != ']') { # auto char_pair = parse_char(pos); # pos = char_pair.second; # enum llama_gretype type = last_sym_start < out_elements.size() # ? LLAMA_GRETYPE_CHAR_ALT # : start_type; # out_elements.push_back({type, char_pair.first}); while pos[0] != "]": char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] type = ( llama_gretype.LLAMA_GRETYPE_CHAR_ALT if last_sym_start < out_elements.size() else start_type ) # type: llama_gretype out_elements.push_back(LlamaGrammarElement(type, char_pair[0])) # if (pos[0] == '-' && pos[1] != ']') { # auto endchar_pair = parse_char(pos + 1); # pos = endchar_pair.second; # out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); # } # } if pos[0] == "-" and pos[1] != "]": endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p] pos = endchar_pair[1] out_elements.push_back( LlamaGrammarElement( llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair[0], ) ) # pos = parse_space(pos + 1, is_nested); pos = parse_space(pos + 1, is_nested) # } else if (is_word_char(*pos)) { // rule reference # const char * name_end = parse_name(pos); # uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); # pos = parse_space(name_end, is_nested); # last_sym_start = out_elements.size(); # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); elif is_word_char(pos[0]): # rule reference name_end = parse_name(pos) # type: const_char_p ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int pos = parse_space(name_end, is_nested) last_sym_start = out_elements.size() out_elements.push_back( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id) ) # } else if (*pos == '(') { // grouping # // parse nested alternates into synthesized rule # pos = parse_space(pos + 1, true); # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); # pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); # last_sym_start = out_elements.size(); # // output reference to synthesized rule # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); # if (*pos != ')') { # throw std::runtime_error(std::string("expecting ')' at ") + pos); # } # pos = parse_space(pos + 1, is_nested); elif pos[0] == "(": # grouping # parse nested alternates into synthesized rule pos = parse_space(pos + 1, True) sub_rule_id = generate_symbol_id(state, rule_name) # type: int pos = parse_alternates(state, pos, rule_name, sub_rule_id, True) last_sym_start = out_elements.size() # output reference to synthesized rule out_elements.push_back( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) if pos[0] != ")": raise RuntimeError("expecting ')' at " + str(pos)) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator # if (last_sym_start == out_elements.size()) { # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); # } elif pos[0] in ("*", "+", "?"): # repetition operator if last_sym_start == out_elements.size(): raise RuntimeError("expecting preceding item to */+/? at " + str(pos)) # // apply transformation to previous symbol (last_sym_start to end) according to # // rewrite rules: # // S* --> S' ::= S S' | # // S+ --> S' ::= S S' | S # // S? --> S' ::= S | # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); # std::vector sub_rule; # // add preceding symbol to generated rule # sub_rule.insert( # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); sub_rule_id = generate_symbol_id(state, rule_name) # type: int sub_rule = std.vector[ LlamaGrammarElement ]() # type: std.vector[LlamaGrammarElement] sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end(), ) # if (*pos == '*' || *pos == '+') { # // cause generated rule to recurse # sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); # } # // mark start of alternate def # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); if pos[0] in ("*", "+"): sub_rule.push_back( LlamaGrammarElement( llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id ) ) sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) # if (*pos == '+') { # // add preceding symbol as alternate only for '+' (otherwise empty) # sub_rule.insert( # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); # } # sub_rule.push_back({LLAMA_GRETYPE_END, 0}); # add_rule(state, sub_rule_id, sub_rule); # // in original rule, replace previous symbol with reference to generated rule # out_elements.resize(last_sym_start); # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); # pos = parse_space(pos + 1, is_nested); if pos[0] == "+": # add preceding symbol as alternate only for '+' (otherwise empty) sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end(), ) sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, sub_rule_id, sub_rule) # in original rule, replace previous symbol with reference to generated rule out_elements.resize(last_sym_start) out_elements.push_back( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) pos = parse_space(pos + 1, is_nested) # } else { # break; # } else: break # } # return pos; # } return pos # const char * parse_alternates( # parse_state & state, # const char * src, # const std::string & rule_name, # uint32_t rule_id, # bool is_nested) { # std::vector rule; # const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); # while (*pos == '|') { # rule.push_back({LLAMA_GRETYPE_ALT, 0}); # pos = parse_space(pos + 1, true); # pos = parse_sequence(state, pos, rule_name, rule, is_nested); # } # rule.push_back({LLAMA_GRETYPE_END, 0}); # add_rule(state, rule_id, rule); # return pos; # } def parse_alternates( state: parse_state, src: const_char_p, rule_name: str, rule_id: int, is_nested: bool, ) -> const_char_p: rule = std.vector() # type: std.vector[LlamaGrammarElement] pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p while pos[0] == "|": rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) pos = parse_space(pos + 1, True) pos = parse_sequence(state, pos, rule_name, rule, is_nested) rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, rule_id, rule) return pos # const char * parse_rule(parse_state & state, const char * src) { # const char * name_end = parse_name(src); # const char * pos = parse_space(name_end, false); # size_t name_len = name_end - src; # uint32_t rule_id = get_symbol_id(state, src, name_len); # const std::string name(src, name_len); # if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { # throw std::runtime_error(std::string("expecting ::= at ") + pos); # } # pos = parse_space(pos + 3, true); # pos = parse_alternates(state, pos, name, rule_id, false); # if (*pos == '\r') { # pos += pos[1] == '\n' ? 2 : 1; # } else if (*pos == '\n') { # pos++; # } else if (*pos) { # throw std::runtime_error(std::string("expecting newline or end at ") + pos); # } # return parse_space(pos, true); # } def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: name_end = parse_name(src) # type: const_char_p pos = parse_space(name_end, False) # type: const_char_p name_len = name_end - src # type: int rule_id = get_symbol_id(state, src, name_len) # type: int name = std.string(src, name_len) # type: str if not (pos[0] == ":" and pos[1] == ":" and pos[2] == "="): raise RuntimeError("expecting ::= at " + str(pos)) pos = parse_space(pos + 3, True) # type: const_char_p pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p if pos[0] == "\r": pos += 2 if pos[1] == "\n" else 1 elif pos[0] == "\n": pos += 1 elif pos[0]: raise RuntimeError("expecting newline or end at " + str(pos)) return parse_space(pos, True) # parse_state parse(const char * src) { # try { # parse_state state; # const char * pos = parse_space(src, true); # while (*pos) { # pos = parse_rule(state, pos); # } # return state; # } catch (const std::exception & err) { # fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); # return parse_state(); # } # } def parse(src: const_char_p) -> parse_state: try: state = parse_state() # type: parse_state pos = parse_space(src, True) # type: const_char_p while pos[0]: pos = parse_rule(state, pos) return state except Exception as err: print(f"{parse.__name__}: error parsing grammar: {err}") return parse_state() # void print_grammar_char(FILE * file, uint32_t c) { # if (0x20 <= c && c <= 0x7f) { # fprintf(file, "%c", static_cast(c)); # } else { # // cop out of encoding UTF-8 # fprintf(file, "", c); # } # } def print_grammar_char(file: TextIO, c: int) -> None: if 0x20 <= c and c <= 0x7F: file.write(chr(c)) else: # cop out of encoding UTF-8 file.write(f"") # bool is_char_element(llama_grammar_element elem) { # switch (elem.type) { # case LLAMA_GRETYPE_CHAR: return true; # case LLAMA_GRETYPE_CHAR_NOT: return true; # case LLAMA_GRETYPE_CHAR_ALT: return true; # case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; # default: return false; # } # } def is_char_element(elem: LlamaGrammarElement) -> bool: return elem.type in ( llama_gretype.LLAMA_GRETYPE_CHAR, llama_gretype.LLAMA_GRETYPE_CHAR_NOT, llama_gretype.LLAMA_GRETYPE_CHAR_ALT, llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, ) # void print_rule( # FILE * file, # uint32_t rule_id, # const std::vector & rule, # const std::map & symbol_id_names) { def print_rule( file: TextIO, rule_id: int, rule: std.vector[LlamaGrammarElement], symbol_id_names: std.map[int, str], ) -> None: # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { # throw std::runtime_error( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError( "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { # llama_grammar_element elem = rule[i]; # switch (elem.type) { # case LLAMA_GRETYPE_END: # throw std::runtime_error( # "unexpected end of rule: " + std::to_string(rule_id) + "," + # std::to_string(i)); # case LLAMA_GRETYPE_ALT: # fprintf(file, "| "); # break; # case LLAMA_GRETYPE_RULE_REF: # fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); # break; # case LLAMA_GRETYPE_CHAR: # fprintf(file, "["); # print_grammar_char(file, elem.value); # break; # case LLAMA_GRETYPE_CHAR_NOT: # fprintf(file, "[^"); # print_grammar_char(file, elem.value); # break; # case LLAMA_GRETYPE_CHAR_RNG_UPPER: # if (i == 0 || !is_char_element(rule[i - 1])) { # throw std::runtime_error( # "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + # std::to_string(rule_id) + "," + std::to_string(i)); # } # fprintf(file, "-"); # print_grammar_char(file, elem.value); # break; # case LLAMA_GRETYPE_CHAR_ALT: # if (i == 0 || !is_char_element(rule[i - 1])) { # throw std::runtime_error( # "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + # std::to_string(rule_id) + "," + std::to_string(i)); # } # print_grammar_char(file, elem.value); # break; # } for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_CHAR: print("[", file=file, end="") print_grammar_char(file, elem.value) elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT: print("[^", file=file, end="") print_grammar_char(file, elem.value) elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + str(rule_id) + "," + str(i) ) print("-", file=file, end="") print_grammar_char(file, elem.value) elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + str(rule_id) + "," + str(i) ) print_grammar_char(file, elem.value) # if (is_char_element(elem)) { # switch (rule[i + 1].type) { # case LLAMA_GRETYPE_CHAR_ALT: # case LLAMA_GRETYPE_CHAR_RNG_UPPER: # break; # default: # fprintf(file, "] "); if is_char_element(elem): if rule[i + 1].type in ( llama_gretype.LLAMA_GRETYPE_CHAR_ALT, llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, ): pass else: print("] ", file=file, end="") # } # } # } # fprintf(file, "\n"); # } print(file=file) # void print_grammar(FILE * file, const parse_state & state) { # try { # std::map symbol_id_names; # for (auto kv : state.symbol_ids) { # symbol_id_names[kv.second] = kv.first; # } # for (size_t i = 0, end = state.rules.size(); i < end; i++) { # // fprintf(file, "%zu: ", i); # // print_rule_binary(file, state.rules[i]); # print_rule(file, i, state.rules[i], symbol_id_names); # // fprintf(file, "\n"); # } # } catch (const std::exception & err) { # fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); # } # } def print_grammar(file: TextIO, state: parse_state) -> None: try: symbol_id_names = std.map() # type: std.map[int, str] for kv in state.symbol_ids.items(): symbol_id_names[kv[1]] = kv[0] for i, rule in enumerate(state.rules): print_rule(file, i, rule, symbol_id_names) except Exception as err: print( f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, ) """llama.cpp gbnf rules from vendor/llama.cpp/grammars""" ARITHMETIC_GBNF = r""" root ::= (expr "=" ws term "\n")+ expr ::= term ([-+*/] term)* term ::= ident | num | "(" ws expr ")" ws ident ::= [a-z] [a-z0-9_]* ws num ::= [0-9]+ ws ws ::= [ \t\n]* """ C_GBNF = r""" root ::= (declaration)* declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" dataType ::= "int" ws | "float" ws | "char" ws identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* parameter ::= dataType identifier statement ::= ( dataType identifier ws "=" ws expression ";" ) | ( identifier ws "=" ws expression ";" ) | ( identifier ws "(" argList? ")" ";" ) | ( "return" ws expression ";" ) | ( "while" "(" condition ")" "{" statement* "}" ) | ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | ( singleLineComment ) | ( multiLineComment ) forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression forUpdate ::= identifier ws "=" ws expression condition ::= expression relationOperator expression relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") expression ::= term (("+" | "-") term)* term ::= factor(("*" | "/") factor)* factor ::= identifier | number | unaryTerm | funcCall | parenExpression unaryTerm ::= "-" factor funcCall ::= identifier "(" argList? ")" parenExpression ::= "(" ws expression ws ")" argList ::= expression ("," ws expression)* number ::= [0-9]+ singleLineComment ::= "//" [^\n]* "\n" multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" ws ::= ([ \t\n]+) """ CHESS_GBNF = r""" root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws # Optional space: by convention, applied in this grammar after literal chars when allowed ws ::= ([ \t\n] ws)? """ JAPANESE_GBNF = r""" root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws # Optional space: by convention, applied in this grammar after literal chars when allowed ws ::= ([ \t\n] ws)? """ JSON_ARR_GBNF = r""" # This is the same as json.gbnf but we restrict whitespaces at the end of the root array # Useful for generating JSON arrays root ::= arr value ::= object | array | string | number | ("true" | "false" | "null") ws arr ::= "[\n" ws ( value (",\n" ws value)* )? "]" object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\\x7F\x00-\x1F] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws # Optional space: by convention, applied in this grammar after literal chars when allowed ws ::= ([ \t\n] ws)? """ JSON_GBNF = r""" root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws array ::= "[" ws ( value ("," ws value)* )? "]" ws string ::= "\"" ( [^"\\\x7F\x00-\x1F] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws ws ::= ([ \t\n] ws)? """ LIST_GBNF = r""" root ::= item+ # Excludes various line break characters item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" """ """llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py""" import json import re from typing import List, Optional # whitespace is constrained to a single space char to prevent model "running away" in # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} # whitespace is constrained to a single space char to prevent model "running away" in # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): if not separator_rule: if min_items == 0 and max_items == 1: return f'{item_rule}?' elif min_items == 1 and max_items is None: return f'{item_rule}+' result = '' if min_items > 0: if item_rule_is_literal and separator_rule is None: result = '"' + (item_rule[1:-1] * min_items) + '"' else: result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) def opt_repetitions(up_to_n, prefix_with_sep=False): ''' - n=4, no sep: '(a (a (a (a)?)?)?)?' - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' ''' content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule if up_to_n == 0: return '' elif up_to_n == 1: return f'({content})?' elif separator_rule and not prefix_with_sep: return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?' else: return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n) if min_items > 0 and max_items != min_items: result += ' ' if max_items is not None: result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) else: item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' if min_items == 0 and separator_rule: result = f'({item_rule} {item_operator}*)?' else: result += f'{item_operator}*' return result class BuiltinRule: def __init__(self, content: str, deps: list = None): self.content = content self.deps = deps or [] _up_to_15_digits = _build_repetition('[0-9]', 0, 15) PRIMITIVE_RULES = { 'boolean' : BuiltinRule('("true" | "false") space', []), 'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), 'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), 'null' : BuiltinRule('"null" space', []), } # TODO: support "uri", "email" string formats STRING_FORMAT_RULES = { 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), } DOTALL = '[\\U00000000-\\U0010FFFF]' DOT = '[^\\x0A\\x0D]' RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) NON_LITERAL_SET = set('|.()[]{}*+?') ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') class SchemaConverter: def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): self._prop_order = prop_order self._allow_fetch = allow_fetch self._dotall = dotall self._raw_pattern = raw_pattern self._rules = { 'space': SPACE_RULE, } self._refs = {} self._refs_being_resolved = set() def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal ) return f'"{escaped}"' def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: ''' not_literal('a') -> '[^a]' not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' ''' assert len(literal) > 0, 'Empty literal not supported' def recurse(i: int): c = literal[i] if maybe_escaped_underscores and c == '_': yield f'[^{c}\\\\]' yield ' | ' yield f'"\\\\"? "{c}"' else: yield f'[^{c}]' if i < len(literal) - 1: yield ' | ' yield self._format_literal(c) yield ' (' yield from recurse(i + 1) yield ')?' return ''.join(('(', *recurse(0), ')')) def _add_rule(self, name, rule): esc_name = INVALID_RULE_CHARS_RE.sub('-', name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name else: i = 0 while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: i += 1 key = f'{esc_name}{i}' self._rules[key] = rule return key def resolve_refs(self, schema: dict, url: str): ''' Resolves all $ref fields in the given schema, fetching any remote schemas, replacing $ref with absolute reference URL and populating self._refs with the respective referenced (sub)schema dictionaries. ''' def visit(n: dict): if isinstance(n, list): return [visit(x) for x in n] elif isinstance(n, dict): ref = n.get('$ref') if ref is not None and ref not in self._refs: if ref.startswith('https://'): assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' import requests frag_split = ref.split('#') base_url = frag_split[0] target = self._refs.get(base_url) if target is None: target = self.resolve_refs(requests.get(ref).json(), base_url) self._refs[base_url] = target if len(frag_split) == 1 or frag_split[-1] == '': return target elif ref.startswith('#/'): target = schema ref = f'{url}{ref}' n['$ref'] = ref else: raise ValueError(f'Unsupported ref {ref}') for sel in ref.split('#')[-1].split('/')[1:]: assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' target = target[sel] self._refs[ref] = target else: for v in n.values(): visit(v) return n return visit(schema) def _generate_union_rule(self, name, alt_schemas): return ' | '.join(( self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') for i, alt_schema in enumerate(alt_schemas) )) def _visit_pattern(self, pattern, name): ''' Transforms a regular expression pattern into a GBNF rule. Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which we define sub-rules to keep the output lean. ''' assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' pattern = pattern[1:-1] sub_rule_ids = {} i = 0 length = len(pattern) def to_rule(s: Tuple[str, bool]) -> str: (txt, is_literal) = s return "\"" + txt + "\"" if is_literal else txt def transform() -> Tuple[str, bool]: ''' Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. ''' nonlocal i nonlocal pattern nonlocal sub_rule_ids start = i # For each component of this sequence, store its string representation and whether it's a literal. # We only need a flat structure here to apply repetition operators to the last item, and # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially # (GBNF's syntax is luckily very close to regular expressions!) seq: list[Tuple[str, bool]] = [] def get_dot(): if self._dotall: rule = DOTALL else: # Accept any character... except \n and \r line break chars (\x0A and \xOD) rule = DOT return self._add_rule(f'dot', rule) def join_seq(): nonlocal seq ret = [] for is_literal, g in groupby(seq, lambda x: x[1]): if is_literal: ret.append((''.join(x[0] for x in g), True)) else: ret.extend(g) if len(ret) == 1: return ret[0] return (' '.join(to_rule(x) for x in seq), False) while i < length: c = pattern[i] if c == '.': seq.append((get_dot(), False)) i += 1 elif c == '(': i += 1 if i < length: assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' seq.append((f'({to_rule(transform())})', False)) elif c == ')': i += 1 assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' return join_seq() elif c == '[': square_brackets = c i += 1 while i < length and pattern[i] != ']': if pattern[i] == '\\': square_brackets += pattern[i:i+2] i += 2 else: square_brackets += pattern[i] i += 1 assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' square_brackets += ']' i += 1 seq.append((square_brackets, False)) elif c == '|': seq.append(('|', False)) i += 1 elif c in ('*', '+', '?'): seq[-1] = (to_rule(seq[-1]) + c, False) i += 1 elif c == '{': curly_brackets = c i += 1 while i < length and pattern[i] != '}': curly_brackets += pattern[i] i += 1 assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' curly_brackets += '}' i += 1 nums = [s.strip() for s in curly_brackets[1:-1].split(',')] min_times = 0 max_times = None try: if len(nums) == 1: min_times = int(nums[0]) max_times = min_times else: assert len(nums) == 2 min_times = int(nums[0]) if nums[0] else 0 max_times = int(nums[1]) if nums[1] else None except ValueError: raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/') (sub, sub_is_literal) = seq[-1] if not sub_is_literal: id = sub_rule_ids.get(sub) if id is None: id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) sub_rule_ids[sub] = id sub = id seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False) else: literal = '' while i < length: if pattern[i] == '\\' and i < length - 1: next = pattern[i + 1] if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: i += 1 literal += pattern[i] i += 1 else: literal += pattern[i:i+2] i += 2 elif pattern[i] == '"' and not self._raw_pattern: literal += '\\"' i += 1 elif pattern[i] not in NON_LITERAL_SET and \ (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET): literal += pattern[i] i += 1 else: break if literal: seq.append((literal, True)) return join_seq() return self._add_rule( name, to_rule(transform()) if self._raw_pattern \ else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") def _resolve_ref(self, ref): ref_name = ref.split('/')[-1] if ref_name not in self._rules and ref not in self._refs_being_resolved: self._refs_being_resolved.add(ref) resolved = self._refs[ref] ref_name = self.visit(resolved, ref_name) self._refs_being_resolved.remove(ref) return ref_name def _generate_constant_rule(self, value): return self._format_literal(json.dumps(value)) def visit(self, schema, name): schema_type = schema.get('type') schema_format = schema.get('format') rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' if (ref := schema.get('$ref')) is not None: return self._add_rule(rule_name, self._resolve_ref(ref)) elif 'oneOf' in schema or 'anyOf' in schema: return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) elif isinstance(schema_type, list): return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) elif 'const' in schema: return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) elif 'enum' in schema: rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) return self._add_rule(rule_name, rule) elif schema_type in (None, 'object') and \ ('properties' in schema or \ ('additionalProperties' in schema and schema['additionalProperties'] is not True)): required = set(schema.get('required', [])) properties = list(schema.get('properties', {}).items()) return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) elif schema_type in (None, 'object') and 'allOf' in schema: required = set() properties = [] hybrid_name = name def add_component(comp_schema, is_required): if (ref := comp_schema.get('$ref')) is not None: comp_schema = self._refs[ref] if 'properties' in comp_schema: for prop_name, prop_schema in comp_schema['properties'].items(): properties.append((prop_name, prop_schema)) if is_required: required.add(prop_name) for t in schema['allOf']: if 'anyOf' in t: for tt in t['anyOf']: add_component(tt, is_required=False) else: add_component(t, is_required=True) return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): items = schema.get('items') or schema['prefixItems'] if isinstance(items, list): return self._add_rule( rule_name, '"[" space ' + ' "," space '.join( self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') for i, item in enumerate(items)) + ' "]" space') else: item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') min_items = schema.get("minItems", 0) max_items = schema.get("maxItems") return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') elif schema_type in (None, 'string') and 'pattern' in schema: return self._visit_pattern(schema['pattern'], rule_name) elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): return self._add_primitive( 'root' if rule_name == 'root' else schema_format, PRIMITIVE_RULES['uuid'] ) elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: prim_name = f'{schema_format}-string' return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) min_len = schema.get('minLength', 0) max_len = schema.get('maxLength') return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') elif (schema_type == 'object') or (len(schema) == 0): return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) else: assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) def _add_primitive(self, name: str, rule: BuiltinRule): n = self._add_rule(name, rule.content) for dep in rule.deps: dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) assert dep_rule, f'Rule {dep} not known' if dep not in self._rules: self._add_primitive(dep, dep_rule) return n def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): prop_order = self._prop_order # sort by position in prop_order (if specified) then by original order sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] prop_kv_rule_names = {} for prop_name, prop_schema in properties: prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') prop_kv_rule_names[prop_name] = self._add_rule( f'{name}{"-" if name else ""}{prop_name}-kv', fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' ) required_props = [k for k in sorted_props if k in required] optional_props = [k for k in sorted_props if k not in required] if additional_properties == True or isinstance(additional_properties, dict): sub_name = f'{name}{"-" if name else ""}additional' value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') prop_kv_rule_names["*"] = self._add_rule( f'{sub_name}-kv', self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' ) optional_props.append("*") rule = '"{" space ' rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) if optional_props: rule += ' (' if required_props: rule += ' "," space ( ' def get_recursive_refs(ks, first_is_optional): [k, *rest] = ks kv_rule_name = prop_kv_rule_names[k] if k == '*': res = self._add_rule( f'{name}{"-" if name else ""}additional-kvs', f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' ) elif first_is_optional: res = f'( "," space {kv_rule_name} )?' else: res = kv_rule_name if len(rest) > 0: res += ' ' + self._add_rule( f'{name}{"-" if name else ""}{k}-rest', get_recursive_refs(rest, first_is_optional=True) ) return res rule += ' | '.join( get_recursive_refs(optional_props[i:], first_is_optional=False) for i in range(len(optional_props)) ) if required_props: rule += ' )' rule += ' )?' rule += ' "}" space' return rule def format_grammar(self): return '\n'.join( f'{name} ::= {rule}' for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) ) def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): prop_order = prop_order or [] schema = json.loads(schema) prop_order = {name: idx for idx, name in enumerate(prop_order)} converter = SchemaConverter(prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False) schema = converter.resolve_refs(schema, "stdin") converter.visit(schema, "") return converter.format_grammar()