"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" | |
# flake8: noqa | |
from pathlib import Path | |
import sys | |
from ctypes import * # type: ignore | |
from enum import Enum | |
from itertools import islice, groupby | |
from typing import ( | |
Any, | |
Callable, | |
Dict, | |
Set, | |
Generic, | |
List, | |
Optional, | |
OrderedDict, | |
TextIO, | |
Tuple, | |
TypeVar, | |
Union, | |
overload, | |
) | |
import llama_cpp.llama_cpp as llama_cpp | |
# Type aliases | |
llama_grammar_element = llama_cpp.llama_grammar_element | |
llama_grammar_element_p = llama_cpp.llama_grammar_element_p | |
llama_grammar_p = llama_cpp.llama_grammar_p | |
# Type variables | |
Ptr = TypeVar("Ptr", bound="const_char_p") | |
T = TypeVar("T") | |
U = TypeVar("U") | |
V = TypeVar("V") | |
W = TypeVar("W") | |
class Sentinel: | |
"""Used to mark the end of a iterator of std::vector & std::map.""" | |
class LlamaGrammar: | |
"""Keeps reference counts of all the arguments, so that they are not | |
garbage collected by Python.""" | |
def __del__(self) -> None: | |
"""Free the grammar pointer when the object is deleted.""" | |
if self.grammar is not None: | |
llama_cpp.llama_grammar_free(self.grammar) | |
self.grammar = None | |
def __init__( | |
self, | |
parsed_grammar: "parse_state", | |
) -> None: | |
"""Initialize the grammar pointer from the parsed state.""" | |
self._grammar_rules = ( | |
parsed_grammar.c_rules() | |
) # type: std.vector[std.vector[LlamaGrammarElement]] | |
self._n_rules = self._grammar_rules.size() # type: int | |
self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int | |
self.init() | |
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": | |
"""Convert a GBNF grammar to a Llama grammar.""" | |
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state | |
if parsed_grammar.rules.empty(): | |
raise ValueError( | |
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" | |
) | |
if verbose: | |
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) | |
print_grammar(sys.stderr, parsed_grammar) | |
print(file=sys.stderr) | |
return cls(parsed_grammar) | |
def from_json_schema( | |
cls, | |
json_schema: str, | |
verbose: bool = True, | |
) -> "LlamaGrammar": | |
"""Convert a JSON schema to a Llama grammar.""" | |
return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose) | |
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": | |
try: | |
with open(file) as f: | |
grammar = f.read() | |
except Exception as err: | |
raise Exception( | |
f"{cls.from_file.__name__}: error reading grammar file: {err}" | |
) | |
if grammar: | |
return cls.from_string(grammar, verbose=verbose) | |
raise ValueError( | |
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" | |
) | |
def init(self) -> None: | |
# Step 1: Convert LlamaGrammarElement to llama_grammar_element | |
self._element_lists = [ | |
[ | |
llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value)) | |
for elem in subvector | |
] | |
for subvector in self._grammar_rules | |
] # type: List[List[llama_grammar_element]] | |
# Step 2: Convert each list to llama_grammar_element array and get pointer | |
self._element_arrays = [ | |
(llama_grammar_element * len(sublist))(*sublist) | |
for sublist in self._element_lists | |
] # type: List[Array[llama_grammar_element]] | |
# Step 3: Get pointer of each array | |
self._element_array_pointers = [ | |
cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays | |
] # type: List[llama_grammar_element_p] | |
# Step 4: Make array of these pointers and get its pointer | |
self._rules = (llama_grammar_element_p * len(self._element_array_pointers))( | |
*self._element_array_pointers | |
) | |
self.grammar = llama_cpp.llama_grammar_init( | |
self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index) | |
) | |
def reset(self) -> None: | |
if self.grammar is not None: | |
llama_cpp.llama_grammar_free(self.grammar) | |
self.init() | |
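
# Usage sketch (illustrative only, hand-written against the class above; the
# grammar string is an arbitrary example, not part of the module): build a
# grammar object from a GBNF string, then `reset` frees and re-creates the
# underlying C grammar between generations.
#
#     grammar = LlamaGrammar.from_string(r'root ::= "yes" | "no"', verbose=False)
#     grammar.reset()
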
class LlamaGrammarElement:
    def __init__(self, type: "llama_gretype", value: int):
        self.type = type
        self.value = value  # Unicode code point or rule ID


class const_char_p:
    """C++ implementation of const char *."""

    def __init__(self, value: Union[str, Ptr], move: Optional[int] = None):
        if isinstance(value, const_char_p):
            # We're copying an existing const_char_p
            self.value = value.value
            self.pos = value.pos + (move or 0)
            return

        # We're creating a new const_char_p
        self.value = value
        self.pos = move or 0

    def __str__(self) -> str:
        assert self.value is not None, "null pointer"
        return self.value[self.pos :]

    def __getitem__(self, index: int) -> str:
        value = str(self)
        return value[index] if index < len(value) else ""

    @overload
    def __add__(self: Ptr, other: int) -> Ptr: ...

    @overload
    def __add__(self: Ptr, other: Ptr) -> int: ...

    def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]:
        return (
            self.__class__(self.value, self.pos + other)
            if isinstance(other, int)
            else self.pos + other.pos
        )

    @overload
    def __sub__(self: Ptr, other: int) -> Ptr: ...

    @overload
    def __sub__(self: Ptr, other: Ptr) -> int: ...

    def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]:
        return (
            self.__class__(self.value, self.pos - other)
            if isinstance(other, int)
            else self.pos - other.pos
        )

    def __eq__(self: Ptr, other: Ptr) -> bool:  # type: ignore
        assert self.value == other.value, "comparing pointers from different strings"
        return self.pos == other.pos

    def __lt__(self: Ptr, other: Ptr) -> bool:
        assert self.value == other.value, "comparing pointers from different strings"
        return self.pos < other.pos

    def __gt__(self: Ptr, other: Ptr) -> bool:
        assert self.value == other.value, "comparing pointers from different strings"
        return self.pos > other.pos


class std:
    @staticmethod
    def string(ptr: const_char_p, length: Optional[int] = None) -> str:
        """C++ implementation of std::string constructor."""
        value = str(ptr)
        if length is not None:
            value = value[:length]
        return value

    class vector(Generic[T], List[T]):
        """C++ implementation of std::vector."""

        class iterator:
            def __init__(self, vector: "std.vector[T]", index: int):
                self._vector = vector
                self._index = index
                self._version = vector._version

            def _check_version(self):
                if self._version != self._vector._version:
                    raise RuntimeError("Iterator used after vector was modified.")

            def __iter__(self):
                return self

            def __next__(self) -> T:
                self._check_version()
                if self._index >= self._vector.size():
                    raise StopIteration
                value = self._vector[self._index]
                self._index += 1
                return value

            def __add__(self, value: int) -> "std.vector[T].iterator":
                return self.__class__(self._vector, self._index + value)

            def __sub__(self, value: int) -> "std.vector[T].iterator":
                return self.__class__(self._vector, self._index - value)

        def __init__(self):
            self._version = 0

        def modify(self):
            # This is a bit of a hack to make sure iterators are invalidated
            self._version += 1

        def push_back(self, value: T) -> None:
            self.modify()
            self.append(value)

        def pop_back(self) -> None:
            self.modify()
            if not self.empty():
                self.pop()

        def back(self) -> T:
            return self[-1]

        def size(self) -> int:
            return len(self)

        def clear(self) -> None:
            self.modify()
            super().clear()

        def empty(self) -> bool:
            return self.size() == 0

        def data(self) -> "std.vector[T]":
            return self

        def resize(
            self,
            new_size: int,
            fill_value_factory: Optional[Callable[[], T]] = None,
        ) -> None:
            if new_size > self.size():
                if fill_value_factory is None:
                    raise ValueError("A fill value factory function must be provided.")
                self.reserve(new_size, fill_value_factory)
            elif new_size < self.size():
                self[:] = self[:new_size]

        def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None:
            if capacity > self.size():
                fill_value = fill_value_factory()
                self.extend([fill_value] * (capacity - self.size()))

        def front(self) -> T:
            if not self.empty():
                return self[0]
            else:
                raise IndexError("Vector is empty.")

        def assign(self, count: int, value: T) -> None:
            self.clear()
            self.extend([value] * count)

        def insert(
            self,
            pos: "std.vector[T].iterator",
            first: "std.vector[T].iterator",
            last: "std.vector[T].iterator",
        ) -> None:
            self[pos._index : pos._index] = list(
                islice(first._vector, first._index, last._index)
            )

        def begin(self) -> "std.vector[T].iterator":
            return self.iterator(self, 0)

        def end(self) -> "std.vector[T].iterator":
            return self.iterator(self, self.size())

    class map(Generic[T, U], OrderedDict[T, U]):
        """C++ implementation of std::map."""

        class iterator(Generic[V, W]):
            def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]):
                self._map = _map
                self.iter = iter(_map)
                self.key = key
                self._advance()

            def _sanitize_key(self) -> T:
                if isinstance(self.key, Sentinel):
                    raise StopIteration
                return self.key

            def _advance(self) -> None:
                try:
                    while next(self.iter) != self.key:
                        pass
                except StopIteration:
                    self.key = Sentinel()

            def __next__(self) -> Tuple[T, U]:
                key = self._sanitize_key()
                if key in self._map:
                    value = self._map[key]
                    self._advance()
                    return key, value
                else:
                    raise StopIteration

            def get(self) -> Tuple[T, U]:
                key = self._sanitize_key()
                return key, self._map[key]

            def first(self) -> T:
                return self._sanitize_key()

            def second(self) -> U:
                return self._map[self._sanitize_key()]

        def insert(
            self, key: T, value: U
        ) -> Tuple["std.map[T, U].iterator[T, U]", bool]:
            if key in self:
                return self.iterator(self, key), False
            else:
                self[key] = value
                return self.iterator(self, key), True

        def find(self, key: T) -> "std.map[T, U].iterator[T, U]":
            if key in self:
                return self.iterator(self, key)
            else:
                return self.end()

        def at(self, key: T) -> U:
            if key in self:
                return self[key]
            else:
                raise KeyError("The provided key is not found in the map.")

        def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None:
            key = iterator.first()
            if key in self:
                del self[key]

        def size(self) -> int:
            return len(self)

        def empty(self) -> bool:
            return self.size() == 0

        def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]":
            try:
                keys = sorted(list(self.keys()))  # type: ignore
                for k in keys:
                    if k >= key:
                        return self.iterator(self, k)
                raise ValueError("No key found that is not less than the input key")
            except TypeError:
                raise TypeError("Keys of type T cannot be sorted.")

        def begin(self) -> "std.map[T, U].iterator[T, U]":
            return self.iterator(self, next(iter(self)))

        def end(self) -> "std.map[T, U].iterator[T, U]":
            return self.iterator(self, Sentinel())
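
# A minimal sketch of the container shims above (hand-traced, illustrative
# only): `std.vector` mimics C++ iterator invalidation via a version counter.
#
#     v = std.vector()      # type: std.vector[int]
#     v.push_back(1)
#     v.push_back(2)
#     it = v.begin()
#     next(it)              # -> 1
#     v.push_back(3)        # modify() bumps v._version
#     next(it)              # -> RuntimeError: iterator used after modification
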
# // grammar element type
# enum llama_gretype {
#     // end of rule definition
#     LLAMA_GRETYPE_END            = 0,
#     // start of alternate definition for rule
#     LLAMA_GRETYPE_ALT            = 1,
#     // non-terminal element: reference to rule
#     LLAMA_GRETYPE_RULE_REF       = 2,
#     // terminal element: character (code point)
#     LLAMA_GRETYPE_CHAR           = 3,
#     // inverse char(s) ([^a], [^a-b] [^abc])
#     LLAMA_GRETYPE_CHAR_NOT       = 4,
#     // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
#     // be an inclusive range ([a-z])
#     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
#     // modifies a preceding LLAMA_GRETYPE_CHAR or
#     // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
#     LLAMA_GRETYPE_CHAR_ALT       = 6,
# };
class llama_gretype(Enum):
    """grammar element type"""

    LLAMA_GRETYPE_END = 0  # end of rule definition
    LLAMA_GRETYPE_ALT = 1  # start of alternate definition for rule
    LLAMA_GRETYPE_RULE_REF = 2  # non-terminal element: reference to rule
    LLAMA_GRETYPE_CHAR = 3  # terminal element: character (code point)
    LLAMA_GRETYPE_CHAR_NOT = 4  # inverse char(s) ([^a], [^a-b] [^abc])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_ALT = 6  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
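
# Illustration (hand-derived from the enum semantics above, not taken from the
# C++ source): the GBNF rule `root ::= [a-z] rest` lowers to the element list
#
#     (LLAMA_GRETYPE_CHAR,           ord("a"))   # open char range at 'a'
#     (LLAMA_GRETYPE_CHAR_RNG_UPPER, ord("z"))   # inclusive upper bound 'z'
#     (LLAMA_GRETYPE_RULE_REF,       <rule id of `rest`>)
#     (LLAMA_GRETYPE_END,            0)          # end of rule definition
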
# struct parse_state {
#     std::map<std::string, uint32_t> symbol_ids;
#     std::vector<std::vector<llama_grammar_element>> rules;
#     std::vector<const llama_grammar_element *> c_rules();
# };
class parse_state:
    def __init__(self):
        self.symbol_ids: std.map[str, int] = std.map()
        self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector()

    # std::vector<const llama_grammar_element *> parse_state::c_rules() {
    #     std::vector<const llama_grammar_element *> ret;
    #     for (const auto & rule : rules) {
    #         ret.push_back(rule.data());
    #     }
    #     return ret;
    # }
    def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]:
        ret = std.vector()  # type: std.vector[std.vector[LlamaGrammarElement]]
        for rule in self.rules:
            ret.push_back(rule.data())
        return ret

    def __repr__(self) -> str:
        return (
            f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
        )


# struct llama_grammar {
#     const std::vector<std::vector<llama_grammar_element>> rules;
#     std::vector<std::vector<const llama_grammar_element *>> stacks;
# };
# class llama_grammar:
#     def __init__(
#         self,
#         rules: std.vector[std.vector[llama_grammar_element]],
#         stacks: std.vector[std.vector[llama_grammar_element]],
#     ):
#         self.rules = rules
#         self.stacks = stacks


# uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
#     auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
#     return result.first->second;
# }
def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int:
    next_id = state.symbol_ids.size()  # type: int
    result = state.symbol_ids.insert(std.string(src, len), next_id)
    return result[0].second()  # type: ignore


# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
#     state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
#     return next_id;
# }
def generate_symbol_id(state: parse_state, base_name: str) -> int:
    next_id = state.symbol_ids.size()  # type: int
    state.symbol_ids[base_name + "_" + str(next_id)] = next_id
    return next_id
# void add_rule(
#     parse_state & state,
#     uint32_t rule_id,
#     const std::vector<llama_grammar_element> & rule) {
#     if (state.rules.size() <= rule_id) {
#         state.rules.resize(rule_id + 1);
#     }
#     state.rules[rule_id] = rule;
# }
def add_rule(
    state: parse_state,
    rule_id: int,
    rule: std.vector[LlamaGrammarElement],
) -> None:
    if state.rules.size() <= rule_id:
        state.rules.resize(
            rule_id + 1,
            fill_value_factory=std.vector[LlamaGrammarElement],
        )
    state.rules[rule_id] = rule


# std::pair<uint32_t, const char *> decode_utf8(const char * src) {
#     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
#     uint8_t  first_byte = static_cast<uint8_t>(*src);
#     uint8_t  highbits = first_byte >> 4;
#     int      len      = lookup[highbits];
#     uint8_t  mask     = (1 << (8 - len)) - 1;
#     uint32_t value    = first_byte & mask;
#     const char * end  = src + len; // may overrun!
#     const char * pos  = src + 1;
#     for ( ; pos < end && *pos; pos++) {
#         value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
#     }
#     return std::make_pair(value, pos);
# }
def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
    """Decodes a UTF-8 character from the source string.

    Unlike the C++ original, Python strings are already decoded, so reading a
    single character and returning its code point is sufficient here."""
    # Get the codepoint of the first character
    value = ord(src[0])
    # Move the pointer ahead one character
    pos = src + 1

    return value, pos
# bool is_word_char(char c) {
#     return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
# }
def is_word_char(c: str) -> bool:
    return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")


# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
#     const char * pos   = src;
#     const char * end   = src + size;
#     uint32_t     value = 0;
#     for ( ; pos < end && *pos; pos++) {
#         value <<= 4;
#         char c = *pos;
#         if ('a' <= c && c <= 'f') {
#             value += c - 'a' + 10;
#         } else if ('A' <= c && c <= 'F') {
#             value += c - 'A' + 10;
#         } else if ('0' <= c && c <= '9') {
#             value += c - '0';
#         } else {
#             break;
#         }
#     }
#     if (pos != end) {
#         throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
#     }
#     return std::make_pair(value, pos);
# }
def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
    pos = const_char_p(src)  # type: const_char_p
    end = src + size  # type: const_char_p
    value = 0  # type: int
    while pos < end and pos[0]:
        value <<= 4
        c = pos[0]  # type: str
        if "a" <= c <= "f":
            value += ord(c) - ord("a") + 10
        elif "A" <= c <= "F":
            value += ord(c) - ord("A") + 10
        elif "0" <= c <= "9":
            value += ord(c) - ord("0")
        else:
            break
        pos += 1
    if pos != end:
        raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src))
    return (value, pos)
# std::pair<uint32_t, const char *> parse_char(const char * src) {
#     if (*src == '\\') {
#         switch (src[1]) {
#             case 'x': return parse_hex(src + 2, 2);
#             case 'u': return parse_hex(src + 2, 4);
#             case 'U': return parse_hex(src + 2, 8);
#             case 't': return std::make_pair('\t', src + 2);
#             case 'r': return std::make_pair('\r', src + 2);
#             case 'n': return std::make_pair('\n', src + 2);
#             case '\\':
#             case '"':
#             case '[':
#             case ']':
#                 return std::make_pair(src[1], src + 2);
#             default:
#                 throw std::runtime_error(std::string("unknown escape at ") + src);
#         }
#     } else if (*src) {
#         return decode_utf8(src);
#     }
#     throw std::runtime_error("unexpected end of input");
# }
def parse_char(src: const_char_p) -> Tuple[int, const_char_p]:
    if src[0] == "\\":
        case = src[1]  # type: str
        if case == "x":
            return parse_hex(src + 2, 2)
        elif case == "u":
            return parse_hex(src + 2, 4)
        elif case == "U":
            return parse_hex(src + 2, 8)
        elif case == "t":
            return (ord("\t"), src + 2)  # implicit cast
        elif case == "r":
            return (ord("\r"), src + 2)  # implicit cast
        elif case == "n":
            return (ord("\n"), src + 2)  # implicit cast
        elif case in ("\\", '"', "[", "]"):
            return (ord(case), src + 2)  # implicit cast
        else:
            raise RuntimeError("unknown escape at " + str(src))
    elif src[0]:
        return decode_utf8(src)
    else:
        raise RuntimeError("unexpected end of input")
# const char * parse_name(const char * src) {
#     const char * pos = src;
#     while (is_word_char(*pos)) {
#         pos++;
#     }
#     if (pos == src) {
#         throw std::runtime_error(std::string("expecting name at ") + src);
#     }
#     return pos;
# }
def parse_name(src: const_char_p) -> const_char_p:
    pos = const_char_p(src)  # type: const_char_p
    while is_word_char(pos[0]):
        pos += 1
    if pos == src:
        raise RuntimeError("expecting name at " + str(src))
    return pos
# const char * parse_space(const char * src, bool newline_ok) {
#     const char * pos = src;
#     while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
#             (newline_ok && (*pos == '\r' || *pos == '\n'))) {
#         if (*pos == '#') {
#             while (*pos && *pos != '\r' && *pos != '\n') {
#                 pos++;
#             }
#         } else {
#             pos++;
#         }
#     }
#     return pos;
# }
def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
    pos = const_char_p(src)  # type: const_char_p
    while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")):
        if pos[0] == "#":
            # NOTE: pos[0] yields "" (not None) at end of input, so test
            # truthiness here; testing `is not None` would loop forever on a
            # trailing comment.
            while pos[0] and pos[0] not in ("\r", "\n"):
                pos += 1
        else:
            pos += 1
    return pos
# const char * parse_sequence(
#     parse_state & state,
#     const char * src,
#     const std::string & rule_name,
#     std::vector<llama_grammar_element> & out_elements,
#     bool is_nested) {
def parse_sequence(
    state: parse_state,
    src: const_char_p,
    rule_name: str,
    out_elements: std.vector[LlamaGrammarElement],
    is_nested: bool,
) -> const_char_p:
    # size_t last_sym_start = out_elements.size();
    # const char * pos = src;
    last_sym_start = out_elements.size()  # type: int
    pos = const_char_p(src)  # type: const_char_p
    # while (*pos) {
    while pos[0]:
        # if (*pos == '"') { // literal string
        #     pos++;
        #     last_sym_start = out_elements.size();
        #     while (*pos != '"') {
        #         auto char_pair = parse_char(pos);
        #         pos = char_pair.second;
        #         out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
        #     }
        #     pos = parse_space(pos + 1, is_nested);
        if pos[0] == '"':  # literal string
            pos += 1
            last_sym_start = out_elements.size()
            while pos[0] != '"':
                char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
                pos = char_pair[1]
                out_elements.push_back(
                    LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0])
                )
            pos = parse_space(pos + 1, is_nested)
        # } else if (*pos == '[') { // char range(s)
        #     pos++;
        #     enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
        elif pos[0] == "[":  # char range(s)
            pos += 1
            start_type = llama_gretype.LLAMA_GRETYPE_CHAR  # type: llama_gretype
            # if (*pos == '^') {
            #     pos++;
            #     start_type = LLAMA_GRETYPE_CHAR_NOT;
            # }
            # last_sym_start = out_elements.size();
            if pos[0] == "^":
                pos += 1
                start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT
            last_sym_start = out_elements.size()
            # while (*pos != ']') {
            #     auto char_pair = parse_char(pos);
            #     pos = char_pair.second;
            #     enum llama_gretype type = last_sym_start < out_elements.size()
            #         ? LLAMA_GRETYPE_CHAR_ALT
            #         : start_type;
            #     out_elements.push_back({type, char_pair.first});
            while pos[0] != "]":
                char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
                pos = char_pair[1]
                type = (
                    llama_gretype.LLAMA_GRETYPE_CHAR_ALT
                    if last_sym_start < out_elements.size()
                    else start_type
                )  # type: llama_gretype
                out_elements.push_back(LlamaGrammarElement(type, char_pair[0]))
                # if (pos[0] == '-' && pos[1] != ']') {
                #     auto endchar_pair = parse_char(pos + 1);
                #     pos = endchar_pair.second;
                #     out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
                # }
                # }
                if pos[0] == "-" and pos[1] != "]":
                    endchar_pair = parse_char(pos + 1)  # type: Tuple[int, const_char_p]
                    pos = endchar_pair[1]
                    out_elements.push_back(
                        LlamaGrammarElement(
                            llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
                            endchar_pair[0],
                        )
                    )
            # pos = parse_space(pos + 1, is_nested);
            pos = parse_space(pos + 1, is_nested)
        # } else if (is_word_char(*pos)) { // rule reference
        #     const char * name_end = parse_name(pos);
        #     uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
        #     pos = parse_space(name_end, is_nested);
        #     last_sym_start = out_elements.size();
        #     out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
        elif is_word_char(pos[0]):  # rule reference
            name_end = parse_name(pos)  # type: const_char_p
            ref_rule_id = get_symbol_id(state, pos, name_end - pos)  # type: int
            pos = parse_space(name_end, is_nested)
            last_sym_start = out_elements.size()
            out_elements.push_back(
                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id)
            )
        # } else if (*pos == '(') { // grouping
        #     // parse nested alternates into synthesized rule
        #     pos = parse_space(pos + 1, true);
        #     uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
        #     pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
        #     last_sym_start = out_elements.size();
        #     // output reference to synthesized rule
        #     out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
        #     if (*pos != ')') {
        #         throw std::runtime_error(std::string("expecting ')' at ") + pos);
        #     }
        #     pos = parse_space(pos + 1, is_nested);
        elif pos[0] == "(":  # grouping
            # parse nested alternates into synthesized rule
            pos = parse_space(pos + 1, True)
            sub_rule_id = generate_symbol_id(state, rule_name)  # type: int
            pos = parse_alternates(state, pos, rule_name, sub_rule_id, True)
            last_sym_start = out_elements.size()
            # output reference to synthesized rule
            out_elements.push_back(
                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
            )
            if pos[0] != ")":
                raise RuntimeError("expecting ')' at " + str(pos))
            pos = parse_space(pos + 1, is_nested)
        # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
        #     if (last_sym_start == out_elements.size()) {
        #         throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
        #     }
        elif pos[0] in ("*", "+", "?"):  # repetition operator
            if last_sym_start == out_elements.size():
                raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
            # // apply transformation to previous symbol (last_sym_start to end) according to
            # // rewrite rules:
            # // S* --> S' ::= S S' |
            # // S+ --> S' ::= S S' | S
            # // S? --> S' ::= S |
            # uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
            # std::vector<llama_grammar_element> sub_rule;
            # // add preceding symbol to generated rule
            # sub_rule.insert(
            #     sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
            sub_rule_id = generate_symbol_id(state, rule_name)  # type: int
            sub_rule = std.vector[
                LlamaGrammarElement
            ]()  # type: std.vector[LlamaGrammarElement]
            sub_rule.insert(
                sub_rule.end(),
                out_elements.begin() + last_sym_start,
                out_elements.end(),
            )
            # if (*pos == '*' || *pos == '+') {
            #     // cause generated rule to recurse
            #     sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
            # }
            # // mark start of alternate def
            # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
            if pos[0] in ("*", "+"):
                # cause generated rule to recurse
                sub_rule.push_back(
                    LlamaGrammarElement(
                        llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id
                    )
                )
            # mark start of alternate def
            sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
            # if (*pos == '+') {
            #     // add preceding symbol as alternate only for '+' (otherwise empty)
            #     sub_rule.insert(
            #         sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
            # }
            # sub_rule.push_back({LLAMA_GRETYPE_END, 0});
            # add_rule(state, sub_rule_id, sub_rule);
            # // in original rule, replace previous symbol with reference to generated rule
            # out_elements.resize(last_sym_start);
            # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
            # pos = parse_space(pos + 1, is_nested);
            if pos[0] == "+":
                # add preceding symbol as alternate only for '+' (otherwise empty)
                sub_rule.insert(
                    sub_rule.end(),
                    out_elements.begin() + last_sym_start,
                    out_elements.end(),
                )
            sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
            add_rule(state, sub_rule_id, sub_rule)
            # in original rule, replace previous symbol with reference to generated rule
            out_elements.resize(last_sym_start)
            out_elements.push_back(
                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
            )
            pos = parse_space(pos + 1, is_nested)
        # } else {
        #     break;
        # }
        else:
            break
    # }
    # return pos;
    # }
    return pos
# const char * parse_alternates(
#     parse_state & state,
#     const char * src,
#     const std::string & rule_name,
#     uint32_t rule_id,
#     bool is_nested) {
#     std::vector<llama_grammar_element> rule;
#     const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
#     while (*pos == '|') {
#         rule.push_back({LLAMA_GRETYPE_ALT, 0});
#         pos = parse_space(pos + 1, true);
#         pos = parse_sequence(state, pos, rule_name, rule, is_nested);
#     }
#     rule.push_back({LLAMA_GRETYPE_END, 0});
#     add_rule(state, rule_id, rule);
#     return pos;
# }
def parse_alternates(
    state: parse_state,
    src: const_char_p,
    rule_name: str,
    rule_id: int,
    is_nested: bool,
) -> const_char_p:
    rule = std.vector()  # type: std.vector[LlamaGrammarElement]
    pos = parse_sequence(state, src, rule_name, rule, is_nested)  # type: const_char_p
    while pos[0] == "|":
        rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
        pos = parse_space(pos + 1, True)
        pos = parse_sequence(state, pos, rule_name, rule, is_nested)
    rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
    add_rule(state, rule_id, rule)
    return pos
# const char * parse_rule(parse_state & state, const char * src) {
#     const char * name_end = parse_name(src);
#     const char * pos      = parse_space(name_end, false);
#     size_t       name_len = name_end - src;
#     uint32_t     rule_id  = get_symbol_id(state, src, name_len);
#     const std::string name(src, name_len);
#
#     if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
#         throw std::runtime_error(std::string("expecting ::= at ") + pos);
#     }
#     pos = parse_space(pos + 3, true);
#
#     pos = parse_alternates(state, pos, name, rule_id, false);
#
#     if (*pos == '\r') {
#         pos += pos[1] == '\n' ? 2 : 1;
#     } else if (*pos == '\n') {
#         pos++;
#     } else if (*pos) {
#         throw std::runtime_error(std::string("expecting newline or end at ") + pos);
#     }
#     return parse_space(pos, true);
# }
def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
    name_end = parse_name(src)  # type: const_char_p
    pos = parse_space(name_end, False)  # type: const_char_p
    name_len = name_end - src  # type: int
    rule_id = get_symbol_id(state, src, name_len)  # type: int
    name = std.string(src, name_len)  # type: str

    if not (pos[0] == ":" and pos[1] == ":" and pos[2] == "="):
        raise RuntimeError("expecting ::= at " + str(pos))

    pos = parse_space(pos + 3, True)  # type: const_char_p
    pos = parse_alternates(state, pos, name, rule_id, False)  # type: const_char_p

    if pos[0] == "\r":
        pos += 2 if pos[1] == "\n" else 1
    elif pos[0] == "\n":
        pos += 1
    elif pos[0]:
        raise RuntimeError("expecting newline or end at " + str(pos))
    return parse_space(pos, True)
# parse_state parse(const char * src) {
#     try {
#         parse_state state;
#         const char * pos = parse_space(src, true);
#         while (*pos) {
#             pos = parse_rule(state, pos);
#         }
#         return state;
#     } catch (const std::exception & err) {
#         fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
#         return parse_state();
#     }
# }
def parse(src: const_char_p) -> parse_state:
    try:
        state = parse_state()  # type: parse_state
        pos = parse_space(src, True)  # type: const_char_p
        while pos[0]:
            pos = parse_rule(state, pos)
        return state
    except Exception as err:
        # like the C++ fprintf(stderr, ...), report to stderr
        print(f"{parse.__name__}: error parsing grammar: {err}", file=sys.stderr)
        return parse_state()
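
# Quick sanity sketch of the parser (hand-traced against the code above; the
# grammar string is an arbitrary example):
#
#     state = parse(const_char_p("root ::= [a-z]"))
#     state.symbol_ids.at("root")   # -> 0
#     repr(state)                   # -> "parse_state(symbol_ids=1, rules=1)"
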
# void print_grammar_char(FILE * file, uint32_t c) {
#     if (0x20 <= c && c <= 0x7f) {
#         fprintf(file, "%c", static_cast<char>(c));
#     } else {
#         // cop out of encoding UTF-8
#         fprintf(file, "<U+%04X>", c);
#     }
# }
def print_grammar_char(file: TextIO, c: int) -> None:
    if 0x20 <= c <= 0x7F:
        file.write(chr(c))
    else:
        # cop out of encoding UTF-8
        file.write(f"<U+{c:04X}>")


# bool is_char_element(llama_grammar_element elem) {
#     switch (elem.type) {
#         case LLAMA_GRETYPE_CHAR:           return true;
#         case LLAMA_GRETYPE_CHAR_NOT:       return true;
#         case LLAMA_GRETYPE_CHAR_ALT:       return true;
#         case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
#         default:                           return false;
#     }
# }
def is_char_element(elem: LlamaGrammarElement) -> bool:
    return elem.type in (
        llama_gretype.LLAMA_GRETYPE_CHAR,
        llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
        llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
        llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
    )


# void print_rule(
#     FILE * file,
#     uint32_t rule_id,
#     const std::vector<llama_grammar_element> & rule,
#     const std::map<uint32_t, std::string> & symbol_id_names) {
def print_rule(
    file: TextIO,
    rule_id: int,
    rule: std.vector[LlamaGrammarElement],
    symbol_id_names: std.map[int, str],
) -> None:
    # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
    #     throw std::runtime_error(
    #         "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
    # }
    # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
    if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
        raise RuntimeError(
            "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
        )
    print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
    # for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
    #     llama_grammar_element elem = rule[i];
    #     switch (elem.type) {
    #         case LLAMA_GRETYPE_END:
    #             throw std::runtime_error(
    #                 "unexpected end of rule: " + std::to_string(rule_id) + "," +
    #                 std::to_string(i));
    #         case LLAMA_GRETYPE_ALT:
    #             fprintf(file, "| ");
    #             break;
    #         case LLAMA_GRETYPE_RULE_REF:
    #             fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
    #             break;
    #         case LLAMA_GRETYPE_CHAR:
    #             fprintf(file, "[");
    #             print_grammar_char(file, elem.value);
    #             break;
    #         case LLAMA_GRETYPE_CHAR_NOT:
    #             fprintf(file, "[^");
    #             print_grammar_char(file, elem.value);
    #             break;
    #         case LLAMA_GRETYPE_CHAR_RNG_UPPER:
    #             if (i == 0 || !is_char_element(rule[i - 1])) {
    #                 throw std::runtime_error(
    #                     "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
    #                     std::to_string(rule_id) + "," + std::to_string(i));
    #             }
    #             fprintf(file, "-");
    #             print_grammar_char(file, elem.value);
    #             break;
    #         case LLAMA_GRETYPE_CHAR_ALT:
    #             if (i == 0 || !is_char_element(rule[i - 1])) {
    #                 throw std::runtime_error(
    #                     "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
    #                     std::to_string(rule_id) + "," + std::to_string(i));
    #             }
    #             print_grammar_char(file, elem.value);
    #             break;
    #     }
    for i, elem in enumerate(rule[:-1]):
        case = elem.type  # type: llama_gretype
        if case is llama_gretype.LLAMA_GRETYPE_END:
            raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
        elif case is llama_gretype.LLAMA_GRETYPE_ALT:
            print("| ", file=file, end="")
        elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
            print(f"{symbol_id_names.at(elem.value)} ", file=file, end="")
        elif case is llama_gretype.LLAMA_GRETYPE_CHAR:
            print("[", file=file, end="")
            print_grammar_char(file, elem.value)
        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT:
            print("[^", file=file, end="")
            print_grammar_char(file, elem.value)
        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER:
            if i == 0 or not is_char_element(rule[i - 1]):
                raise RuntimeError(
                    "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: "
                    + str(rule_id)
                    + ","
                    + str(i)
                )
            print("-", file=file, end="")
            print_grammar_char(file, elem.value)
        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT:
            if i == 0 or not is_char_element(rule[i - 1]):
                raise RuntimeError(
                    "LLAMA_GRETYPE_CHAR_ALT without preceding char: "
                    + str(rule_id)
                    + ","
                    + str(i)
                )
            print_grammar_char(file, elem.value)
        # if (is_char_element(elem)) {
        #     switch (rule[i + 1].type) {
        #         case LLAMA_GRETYPE_CHAR_ALT:
        #         case LLAMA_GRETYPE_CHAR_RNG_UPPER:
        #             break;
        #         default:
        #             fprintf(file, "] ");
        if is_char_element(elem):
            if rule[i + 1].type in (
                llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
                llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
            ):
                pass
            else:
                print("] ", file=file, end="")
    #         }
    #     }
    # }
    # fprintf(file, "\n");
    # }
    print(file=file)


# void print_grammar(FILE * file, const parse_state & state) {
#     try {
#         std::map<uint32_t, std::string> symbol_id_names;
#         for (auto kv : state.symbol_ids) {
#             symbol_id_names[kv.second] = kv.first;
#         }
#         for (size_t i = 0, end = state.rules.size(); i < end; i++) {
#             // fprintf(file, "%zu: ", i);
#             // print_rule_binary(file, state.rules[i]);
#             print_rule(file, i, state.rules[i], symbol_id_names);
#             // fprintf(file, "\n");
#         }
#     } catch (const std::exception & err) {
#         fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
#     }
# }
def print_grammar(file: TextIO, state: parse_state) -> None:
    try:
        symbol_id_names = std.map()  # type: std.map[int, str]
        for kv in state.symbol_ids.items():
            symbol_id_names[kv[1]] = kv[0]
        for i, rule in enumerate(state.rules):
            print_rule(file, i, rule, symbol_id_names)
    except Exception as err:
        print(
            f"{print_grammar.__name__}: error printing grammar: {err}",
            file=sys.stderr,
        )
"""llama.cpp gbnf rules from vendor/llama.cpp/grammars""" | |
ARITHMETIC_GBNF = r""" | |
root ::= (expr "=" ws term "\n")+ | |
expr ::= term ([-+*/] term)* | |
term ::= ident | num | "(" ws expr ")" ws | |
ident ::= [a-z] [a-z0-9_]* ws | |
num ::= [0-9]+ ws | |
ws ::= [ \t\n]* | |
""" | |
C_GBNF = r""" | |
root ::= (declaration)* | |
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" | |
dataType ::= "int" ws | "float" ws | "char" ws | |
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* | |
parameter ::= dataType identifier | |
statement ::= | |
( dataType identifier ws "=" ws expression ";" ) | | |
( identifier ws "=" ws expression ";" ) | | |
( identifier ws "(" argList? ")" ";" ) | | |
( "return" ws expression ";" ) | | |
( "while" "(" condition ")" "{" statement* "}" ) | | |
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | | |
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | | |
( singleLineComment ) | | |
( multiLineComment ) | |
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression | |
forUpdate ::= identifier ws "=" ws expression | |
condition ::= expression relationOperator expression | |
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") | |
expression ::= term (("+" | "-") term)* | |
term ::= factor(("*" | "/") factor)* | |
factor ::= identifier | number | unaryTerm | funcCall | parenExpression | |
unaryTerm ::= "-" factor | |
funcCall ::= identifier "(" argList? ")" | |
parenExpression ::= "(" ws expression ws ")" | |
argList ::= expression ("," ws expression)* | |
number ::= [0-9]+ | |
singleLineComment ::= "//" [^\n]* "\n" | |
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" | |
ws ::= ([ \t\n]+) | |
""" | |
CHESS_GBNF = r""" | |
root ::= object | |
value ::= object | array | string | number | ("true" | "false" | "null") ws | |
object ::= | |
"{" ws ( | |
string ":" ws value | |
("," ws string ":" ws value)* | |
)? "}" ws | |
array ::= | |
"[" ws ( | |
value | |
("," ws value)* | |
)? "]" ws | |
string ::= | |
"\"" ( | |
[^"\\] | | |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes | |
)* "\"" ws | |
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws | |
# Optional space: by convention, applied in this grammar after literal chars when allowed | |
ws ::= ([ \t\n] ws)? | |
""" | |
JAPANESE_GBNF = r""" | |
root ::= object | |
value ::= object | array | string | number | ("true" | "false" | "null") ws | |
object ::= | |
"{" ws ( | |
string ":" ws value | |
("," ws string ":" ws value)* | |
)? "}" ws | |
array ::= | |
"[" ws ( | |
value | |
("," ws value)* | |
)? "]" ws | |
string ::= | |
"\"" ( | |
[^"\\] | | |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes | |
)* "\"" ws | |
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws | |
# Optional space: by convention, applied in this grammar after literal chars when allowed | |
ws ::= ([ \t\n] ws)? | |
""" | |
JSON_ARR_GBNF = r""" | |
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array | |
# Useful for generating JSON arrays | |
root ::= arr | |
value ::= object | array | string | number | ("true" | "false" | "null") ws | |
arr ::= | |
"[\n" ws ( | |
value | |
(",\n" ws value)* | |
)? "]" | |
object ::= | |
"{" ws ( | |
string ":" ws value | |
("," ws string ":" ws value)* | |
)? "}" ws | |
array ::= | |
"[" ws ( | |
value | |
("," ws value)* | |
)? "]" ws | |
string ::= | |
"\"" ( | |
[^"\\\x7F\x00-\x1F] | | |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes | |
)* "\"" ws | |
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws | |
# Optional space: by convention, applied in this grammar after literal chars when allowed | |
ws ::= ([ \t\n] ws)? | |
""" | |
JSON_GBNF = r""" | |
root ::= object | |
value ::= object | array | string | number | ("true" | "false" | "null") ws | |
object ::= | |
"{" ws ( | |
string ":" ws value | |
("," ws string ":" ws value)* | |
)? "}" ws | |
array ::= | |
"[" ws ( | |
value | |
("," ws value)* | |
)? "]" ws | |
string ::= | |
"\"" ( | |
[^"\\\x7F\x00-\x1F] | | |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes | |
)* "\"" ws | |
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws | |
ws ::= ([ \t\n] ws)? | |
""" | |
LIST_GBNF = r""" | |
root ::= item+ | |
# Excludes various line break characters | |
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" | |
""" | |
"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py""" | |
import json | |
import re | |
from typing import List, Optional | |
# whitespace is constrained to a single space char to prevent model "running away" in | |
# whitespace. Also maybe improves generation quality? | |
SPACE_RULE = '" "?' | |
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") | |
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') | |
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} | |
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
    if not separator_rule:
        if min_items == 0 and max_items == 1:
            return f'{item_rule}?'
        elif min_items == 1 and max_items is None:
            return f'{item_rule}+'

    result = ''

    if min_items > 0:
        if item_rule_is_literal and separator_rule is None:
            result = '"' + (item_rule[1:-1] * min_items) + '"'
        else:
            result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)

    def opt_repetitions(up_to_n, prefix_with_sep=False):
        '''
            - n=4, no sep:             '(a (a (a (a)?)?)?)?'
            - n=4, sep=',', prefix:    '("," a ("," a ("," a ("," a)?)?)?)?'
            - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
        '''
        content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
        if up_to_n == 0:
            return ''
        elif up_to_n == 1:
            return f'({content})?'
        elif separator_rule and not prefix_with_sep:
            return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
        else:
            return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)

    if min_items > 0 and max_items != min_items:
        result += ' '

    if max_items is not None:
        result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
    else:
        item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
        if min_items == 0 and separator_rule:
            result = f'({item_rule} {item_operator}*)?'
        else:
            result += f'{item_operator}*'

    return result
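
# Worked examples, hand-evaluated against _build_repetition above:
#
#     _build_repetition('[0-9]', 1, None)
#     # -> '[0-9]+'                    (fast path for the + quantifier)
#     _build_repetition('[0-9]', 1, 3)
#     # -> '[0-9] ([0-9] ([0-9])?)?'   (one digit, then up to two more)
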
class BuiltinRule:
    def __init__(self, content: str, deps: Optional[list] = None):
        self.content = content
        self.deps = deps or []
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)

PRIMITIVE_RULES = {
    'boolean'      : BuiltinRule('("true" | "false") space', []),
    'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
    'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
    'number'       : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
    'integer'      : BuiltinRule('("-"? integral-part) space', ['integral-part']),
    'value'        : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
    'object'       : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
    'array'        : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
    'uuid'         : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
    'char'         : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
    'string'       : BuiltinRule(r'"\"" char* "\"" space', ['char']),
    'null'         : BuiltinRule('"null" space', []),
}

# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
    'date'            : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
    'time'            : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
    'date-time'       : BuiltinRule('date "T" time', ['date', 'time']),
    'date-string'     : BuiltinRule('"\\"" date "\\"" space', ['date']),
    'time-string'     : BuiltinRule('"\\"" time "\\"" space', ['time']),
    'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}

DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'

RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])

NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
class SchemaConverter:
    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
        self._prop_order = prop_order
        self._allow_fetch = allow_fetch
        self._dotall = dotall
        self._raw_pattern = raw_pattern
        self._rules = {
            'space': SPACE_RULE,
        }
        self._refs = {}
        self._refs_being_resolved = set()

    def _format_literal(self, literal):
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
        )
        return f'"{escaped}"'

    def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
        '''
            not_literal('a') -> '[^a]'
            not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
        '''
        assert len(literal) > 0, 'Empty literal not supported'
        def recurse(i: int):
            c = literal[i]
            if maybe_escaped_underscores and c == '_':
                yield f'[^{c}\\\\]'
                yield ' | '
                yield f'"\\\\"? "{c}"'
            else:
                yield f'[^{c}]'
            if i < len(literal) - 1:
                yield ' | '
                yield self._format_literal(c)
                yield ' ('
                yield from recurse(i + 1)
                yield ')?'

        return ''.join(('(', *recurse(0), ')'))

    def _add_rule(self, name, rule):
        esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
                i += 1
            key = f'{esc_name}{i}'
        self._rules[key] = rule
        return key
    def resolve_refs(self, schema: dict, url: str):
        '''
            Resolves all $ref fields in the given schema, fetching any remote schemas,
            replacing $ref with absolute reference URL and populating self._refs with the
            respective referenced (sub)schema dictionaries.
        '''
        def visit(n: dict):
            if isinstance(n, list):
                return [visit(x) for x in n]
            elif isinstance(n, dict):
                ref = n.get('$ref')
                if ref is not None and ref not in self._refs:
                    if ref.startswith('https://'):
                        assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch to force)'
                        import requests

                        frag_split = ref.split('#')
                        base_url = frag_split[0]

                        target = self._refs.get(base_url)
                        if target is None:
                            target = self.resolve_refs(requests.get(ref).json(), base_url)
                            self._refs[base_url] = target

                        if len(frag_split) == 1 or frag_split[-1] == '':
                            return target
                    elif ref.startswith('#/'):
                        target = schema
                        ref = f'{url}{ref}'
                        n['$ref'] = ref
                    else:
                        raise ValueError(f'Unsupported ref {ref}')

                    for sel in ref.split('#')[-1].split('/')[1:]:
                        assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
                        target = target[sel]

                    self._refs[ref] = target
                else:
                    for v in n.values():
                        visit(v)

            return n

        return visit(schema)
    def _generate_union_rule(self, name, alt_schemas):
        return ' | '.join((
            self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
            for i, alt_schema in enumerate(alt_schemas)
        ))
    def _visit_pattern(self, pattern, name):
        '''
            Transforms a regular expression pattern into a GBNF rule.

            Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
            Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

            Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

            Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
            we define sub-rules to keep the output lean.
        '''
        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        sub_rule_ids = {}

        i = 0
        length = len(pattern)

        def to_rule(s: Tuple[str, bool]) -> str:
            (txt, is_literal) = s
            return "\"" + txt + "\"" if is_literal else txt

        def transform() -> Tuple[str, bool]:
            '''
                Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            '''
            nonlocal i
            nonlocal pattern
            nonlocal sub_rule_ids

            start = i
            # For each component of this sequence, store its string representation and whether it's a literal.
            # We only need a flat structure here to apply repetition operators to the last item, and
            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
            # (GBNF's syntax is luckily very close to regular expressions!)
            seq: list[Tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
                    rule = DOT
                return self._add_rule('dot', rule)

            def join_seq():
                nonlocal seq
                ret = []
                for is_literal, g in groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append((''.join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                return (' '.join(to_rule(x) for x in seq), False)

            while i < length:
                c = pattern[i]
                if c == '.':
                    seq.append((get_dot(), False))
                    i += 1
                elif c == '(':
                    i += 1
                    if i < length:
                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    seq.append((f'({to_rule(transform())})', False))
                elif c == ')':
                    i += 1
                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
                    return join_seq()
                elif c == '[':
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != ']':
                        if pattern[i] == '\\':
                            square_brackets += pattern[i:i+2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
                    square_brackets += ']'
                    i += 1
                    seq.append((square_brackets, False))
                elif c == '|':
                    seq.append(('|', False))
                    i += 1
                elif c in ('*', '+', '?'):
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == '{':
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != '}':
                        curly_brackets += pattern[i]
                        i += 1
                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
                    curly_brackets += '}'
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')

                    (sub, sub_is_literal) = seq[-1]

                    if not sub_is_literal:
                        id = sub_rule_ids.get(sub)
                        if id is None:
                            id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
                            sub_rule_ids[sub] = id
                        sub = id

                    seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
                else:
                    literal = ''
                    while i < length:
                        if pattern[i] == '\\' and i < length - 1:
                            next = pattern[i + 1]
                            if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i:i+2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and \
                                (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))

            return join_seq()

        return self._add_rule(
            name,
            to_rule(transform()) if self._raw_pattern \
                else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
    def _resolve_ref(self, ref):
        ref_name = ref.split('/')[-1]
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
            self._refs_being_resolved.add(ref)
            resolved = self._refs[ref]
            ref_name = self.visit(resolved, ref_name)
            self._refs_being_resolved.remove(ref)
        return ref_name

    def _generate_constant_rule(self, value):
        return self._format_literal(json.dumps(value))
    def visit(self, schema, name):
        schema_type = schema.get('type')
        schema_format = schema.get('format')
        rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'

        if (ref := schema.get('$ref')) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif 'oneOf' in schema or 'anyOf' in schema:
            return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))

        elif isinstance(schema_type, list):
            return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))

        elif 'const' in schema:
            return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))

        elif 'enum' in schema:
            rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, 'object') and \
                ('properties' in schema or \
                 ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
            required = set(schema.get('required', []))
            properties = list(schema.get('properties', {}).items())
            return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))

        elif schema_type in (None, 'object') and 'allOf' in schema:
            required = set()
            properties = []
            hybrid_name = name

            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get('$ref')) is not None:
                    comp_schema = self._refs[ref]

                if 'properties' in comp_schema:
                    for prop_name, prop_schema in comp_schema['properties'].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            for t in schema['allOf']:
                if 'anyOf' in t:
                    for tt in t['anyOf']:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))

        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
            items = schema.get('items') or schema['prefixItems']
            if isinstance(items, list):
                return self._add_rule(
                    rule_name,
                    '"[" space ' +
                    ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)) +
                    ' "]" space')
            else:
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')

        elif schema_type in (None, 'string') and 'pattern' in schema:
            return self._visit_pattern(schema['pattern'], rule_name)

        elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
            return self._add_primitive(
                'root' if rule_name == 'root' else schema_format,
                PRIMITIVE_RULES['uuid']
            )

        elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
            prim_name = f'{schema_format}-string'
            return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))

        elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
            char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
            min_len = schema.get('minLength', 0)
            max_len = schema.get('maxLength')
            return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')

        elif (schema_type == 'object') or (len(schema) == 0):
            return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

        else:
            assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
            return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
    def _add_primitive(self, name: str, rule: BuiltinRule):
        n = self._add_rule(name, rule.content)
        for dep in rule.deps:
            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
            assert dep_rule, f'Rule {dep} not known'
            if dep not in self._rules:
                self._add_primitive(dep, dep_rule)
        return n
    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
        prop_order = self._prop_order
        # sort by position in prop_order (if specified) then by original order
        sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]

        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        if additional_properties == True or isinstance(additional_properties, dict):
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
            prop_kv_rule_names["*"] = self._add_rule(
                f'{sub_name}-kv',
                self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
        if optional_props:
            rule += ' ('
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                if k == '*':
                    res = self._add_rule(
                        f'{name}{"-" if name else ""}additional-kvs',
                        f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
                    )
                elif first_is_optional:
                    res = f'( "," space {kv_rule_name} )?'
                else:
                    res = kv_rule_name
                if len(rest) > 0:
                    res += ' ' + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True)
                    )
                return res

            rule += ' | '.join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += ' )'
            rule += ' )?'

        rule += ' "}" space'

        return rule
    def format_grammar(self):
        return '\n'.join(
            f'{name} ::= {rule}'
            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
        )
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    prop_order = prop_order or []
    schema = json.loads(schema)
    prop_order = {name: idx for idx, name in enumerate(prop_order)}
    converter = SchemaConverter(prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False)
    schema = converter.resolve_refs(schema, "stdin")
    converter.visit(schema, "")
    return converter.format_grammar()
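
# End-to-end sketch (the schema below is an illustrative example, not a
# fixture from the project): convert a JSON schema string to GBNF, or build a
# grammar object from it directly.
#
#     schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
#     print(json_schema_to_gbnf(schema))
#     grammar = LlamaGrammar.from_json_schema(schema, verbose=False)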