| | |
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import itertools |
| | import json |
| | import re |
| | import sys |
| | from typing import Any, List, Optional, Set, Tuple, Union |
| |
|
| | def _build_repetition(item_rule, min_items, max_items, separator_rule=None): |
| |
|
| | if max_items == 0: |
| | return "" |
| |
|
| | if min_items == 0 and max_items == 1: |
| | return f'{item_rule}?' |
| |
|
| | if not separator_rule: |
| | if min_items == 1 and max_items is None: |
| | return f'{item_rule}+' |
| | elif min_items == 0 and max_items is None: |
| | return f'{item_rule}*' |
| | else: |
| | return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}' |
| |
|
| | result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None) |
| | return f'({result})?' if min_items == 0 else result |
| |
|
| | def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True): |
| | has_min = min_value != None |
| | has_max = max_value != None |
| |
|
| | def digit_range(from_char: str, to_char: str): |
| | out.append("[") |
| | if from_char == to_char: |
| | out.append(from_char) |
| | else: |
| | out.append(from_char) |
| | out.append("-") |
| | out.append(to_char) |
| | out.append("]") |
| |
|
| | def more_digits(min_digits: int, max_digits: int): |
| | out.append("[0-9]") |
| | if min_digits == max_digits and min_digits == 1: |
| | return |
| | out.append("{") |
| | out.append(str(min_digits)) |
| | if max_digits != min_digits: |
| | out.append(",") |
| | if max_digits != sys.maxsize: |
| | out.append(str(max_digits)) |
| | out.append("}") |
| |
|
| | def uniform_range(from_str: str, to_str: str): |
| | i = 0 |
| | while i < len(from_str) and from_str[i] == to_str[i]: |
| | i += 1 |
| | if i > 0: |
| | out.append("\"") |
| | out.append(from_str[:i]) |
| | out.append("\"") |
| | if i < len(from_str): |
| | if i > 0: |
| | out.append(" ") |
| | sub_len = len(from_str) - i - 1 |
| | if sub_len > 0: |
| | from_sub = from_str[i+1:] |
| | to_sub = to_str[i+1:] |
| | sub_zeros = "0" * sub_len |
| | sub_nines = "9" * sub_len |
| |
|
| | to_reached = False |
| | out.append("(") |
| | if from_sub == sub_zeros: |
| | digit_range(from_str[i], chr(ord(to_str[i]) - 1)) |
| | out.append(" ") |
| | more_digits(sub_len, sub_len) |
| | else: |
| | out.append("[") |
| | out.append(from_str[i]) |
| | out.append("] ") |
| | out.append("(") |
| | uniform_range(from_sub, sub_nines) |
| | out.append(")") |
| | if ord(from_str[i]) < ord(to_str[i]) - 1: |
| | out.append(" | ") |
| | if to_sub == sub_nines: |
| | digit_range(chr(ord(from_str[i]) + 1), to_str[i]) |
| | to_reached = True |
| | else: |
| | digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1)) |
| | out.append(" ") |
| | more_digits(sub_len, sub_len) |
| | if not to_reached: |
| | out.append(" | ") |
| | digit_range(to_str[i], to_str[i]) |
| | out.append(" ") |
| | uniform_range(sub_zeros, to_sub) |
| | out.append(")") |
| | else: |
| | out.append("[") |
| | out.append(from_str[i]) |
| | out.append("-") |
| | out.append(to_str[i]) |
| | out.append("]") |
| |
|
| | if has_min and has_max: |
| | if min_value < 0 and max_value < 0: |
| | out.append("\"-\" (") |
| | _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True) |
| | out.append(")") |
| | return |
| |
|
| | if min_value < 0: |
| | out.append("\"-\" (") |
| | _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True) |
| | out.append(") | ") |
| | min_value = 0 |
| |
|
| | min_s = str(min_value) |
| | max_s = str(max_value) |
| | min_digits = len(min_s) |
| | max_digits = len(max_s) |
| |
|
| | for digits in range(min_digits, max_digits): |
| | uniform_range(min_s, "9" * digits) |
| | min_s = "1" + "0" * digits |
| | out.append(" | ") |
| | uniform_range(min_s, max_s) |
| | return |
| |
|
| | less_decimals = max(decimals_left - 1, 1) |
| |
|
| | if has_min: |
| | if min_value < 0: |
| | out.append("\"-\" (") |
| | _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) |
| | out.append(") | [0] | [1-9] ") |
| | more_digits(0, decimals_left - 1) |
| | elif min_value == 0: |
| | if top_level: |
| | out.append("[0] | [1-9] ") |
| | more_digits(0, less_decimals) |
| | else: |
| | more_digits(1, decimals_left) |
| | elif min_value <= 9: |
| | c = str(min_value) |
| | range_start = '1' if top_level else '0' |
| | if c > range_start: |
| | digit_range(range_start, chr(ord(c) - 1)) |
| | out.append(" ") |
| | more_digits(1, less_decimals) |
| | out.append(" | ") |
| | digit_range(c, "9") |
| | out.append(" ") |
| | more_digits(0, less_decimals) |
| | else: |
| | min_s = str(min_value) |
| | length = len(min_s) |
| | c = min_s[0] |
| |
|
| | if c > "1": |
| | digit_range("1" if top_level else "0", chr(ord(c) - 1)) |
| | out.append(" ") |
| | more_digits(length, less_decimals) |
| | out.append(" | ") |
| | digit_range(c, c) |
| | out.append(" (") |
| | _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False) |
| | out.append(")") |
| | if c < "9": |
| | out.append(" | ") |
| | digit_range(chr(ord(c) + 1), "9") |
| | out.append(" ") |
| | more_digits(length - 1, less_decimals) |
| | return |
| |
|
| | if has_max: |
| | if max_value >= 0: |
| | if top_level: |
| | out.append("\"-\" [1-9] ") |
| | more_digits(0, less_decimals) |
| | out.append(" | ") |
| | _generate_min_max_int(0, max_value, out, decimals_left, top_level=True) |
| | else: |
| | out.append("\"-\" (") |
| | _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False) |
| | out.append(")") |
| | return |
| |
|
| | raise RuntimeError("At least one of min_value or max_value must be set") |
| |
|
class BuiltinRule:
    '''A predefined GBNF rule body together with the builtin rules it references.'''

    def __init__(self, content: str, deps: list | None = None):
        # Right-hand side of the rule, emitted verbatim.
        self.content = content
        # Names of other builtin rules to emit alongside this one; a missing
        # or empty value becomes a fresh empty list.
        self.deps = deps if deps else []
| |
|
| | |
# Body of the `space` rule (always registered by SchemaConverter): nothing, a
# single space, or one or two newlines followed by up to 20 spaces/tabs.
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
| |
|
# Builtin GBNF rules for JSON primitives, keyed by rule name. Each entry lists
# the other builtin rules it references so they are emitted with it.
PRIMITIVE_RULES = {
    'boolean'      : BuiltinRule('("true" | "false") space', []),
    # Up to 16 fractional digits.
    'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
    # "0", or up to 16 digits without a leading zero.
    'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
    'number'       : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
    'integer'      : BuiltinRule('("-"? integral-part) space', ['integral-part']),
    'value'        : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
    'object'       : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
    'array'        : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
    # Quoted 8-4-4-4-12 hex groups.
    'uuid'         : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
    # One string character: anything except '"', '\', DEL and 0x00-0x1F,
    # or a backslash escape (\" \\ \b \f \n \r \t \uXXXX).
    'char'         : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
    'string'       : BuiltinRule(r'"\"" char* "\"" space', ['char']),
    'null'         : BuiltinRule('"null" space', []),
}
| |
|
| | |
# Builtin rules for JSON-schema string `format`s; `<format>-string` entries
# wrap the bare format rule in JSON quotes.
STRING_FORMAT_RULES = {
    # YYYY-MM-DD with month 01-12 and day 01-31 (no per-month day check).
    'date'            : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
    # HH:MM:SS, optional ".mmm", then "Z" or a +HH:MM / -HH:MM offset.
    'time'            : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
    'date-time'       : BuiltinRule('date "T" time', ['date', 'time']),
    'date-string'     : BuiltinRule('"\\"" date "\\"" space', ['date']),
    'time-string'     : BuiltinRule('"\\"" time "\\"" space', ['time']),
    'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}
| |
|
# Rule body for the regex "." when dotall is enabled: any code point.
DOTALL = '[\\U00000000-\\U0010FFFF]'
# Rule body for the regex "." otherwise: anything except LF (0x0A) and CR (0x0D).
DOT = '[^\\x0A\\x0D]'
| |
|
# Rule names owned by builtin rules; SchemaConverter.visit appends "-" to any
# schema-derived rule name that collides with one of these.
RESERVED_NAMES = {"root", "dot", *PRIMITIVE_RULES, *STRING_FORMAT_RULES}
| |
|
# Runs of characters not allowed in GBNF rule names (replaced by "-").
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
# Characters that must be escaped inside a double-quoted GBNF literal. The
# backslash must be included: literals are built from e.g. json.dumps()
# output, whose own escapes (\" \\ \n \uXXXX) must survive GBNF unescaping.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
# Characters that must be escaped inside a GBNF [...] character class.
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}

# Regex metacharacters that end a run of literal pattern characters.
NON_LITERAL_SET = set('|.()[]{}*+?')
# Characters whose backslash escape in a regex denotes the plain character.
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
| |
|
| |
|
class SchemaConverter:
    '''
    Converts a (subset of) JSON schema into a GBNF grammar.

    Rules are accumulated in ``self._rules`` (rule name -> rule body) while
    ``visit`` walks the schema; ``format_grammar`` renders them sorted by name.
    '''

    # Characters that cannot appear verbatim inside a GBNF ``[...]`` class:
    # a backslash starts an escape, ``]`` closes the class, and a leading ``^``
    # would negate it (so it is written as the hex escape ``\x5E``). CR/LF are
    # spelled out. ``-`` is handled by ``_format_range_chars``.
    _RANGE_CHAR_ESCAPES = {'\\': '\\\\', ']': '\\]', '^': '\\x5E', '\r': '\\r', '\n': '\\n'}

    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
        '''
        Args:
            prop_order: mapping of property name -> precedence (lower sorts first).
            allow_fetch: whether ``$ref`` URLs starting with https:// may be fetched.
            dotall: whether "." in patterns also matches line breaks.
            raw_pattern: emit pattern rules without surrounding JSON quotes.
        '''
        self._prop_order = prop_order
        self._allow_fetch = allow_fetch
        self._dotall = dotall
        self._raw_pattern = raw_pattern
        self._rules = {
            'space': SPACE_RULE,
        }
        self._refs = {}
        self._refs_being_resolved = set()

    def _format_literal(self, literal):
        '''Wrap ``literal`` in double quotes, escaping it for a GBNF literal.'''
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
        )
        return f'"{escaped}"'

    @classmethod
    def _format_range_char(cls, c):
        '''Escape a single character for use inside a GBNF ``[...]`` class.'''
        return cls._RANGE_CHAR_ESCAPES.get(c, c)

    @classmethod
    def _format_range_chars(cls, chars):
        '''
        Escape characters for a GBNF ``[...]`` class. Any ``-`` is moved to
        the end, where it cannot be read as a range between its neighbours.
        '''
        body = ''.join(cls._format_range_char(c) for c in chars if c != '-')
        return body + ('-' if '-' in chars else '')

    def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
        '''
        not_literal('a') -> '([^a])'
        not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)'
        '''
        assert len(literal) > 0, 'Empty literal not supported'
        def recurse(i: int):
            c = literal[i]
            if maybe_escaped_underscores and c == '_':
                yield f'[^{c}\\\\]'
                yield ' | '
                yield f'"\\\\"? "{c}"'
            else:
                yield f'[^{c}]'
            if i < len(literal) - 1:
                yield ' | '
                yield self._format_literal(c)
                yield ' ('
                yield from recurse(i + 1)
                yield ')?'

        return ''.join(('(', *recurse(0), ')'))

    def _not_strings(self, strings):
        '''
        Build a rule body matching a quoted JSON string (plus ``space``) whose
        characters spell none of ``strings``. It is used for the keys of
        additionalProperties.
        '''
        class TrieNode:
            def __init__(self):
                self.children = {}
                self.is_end_of_string = False

            def insert(self, string):
                node = self
                for c in string:
                    node = node.children.setdefault(c, TrieNode())
                node.is_end_of_string = True

        trie = TrieNode()
        for s in strings:
            trie.insert(s)

        char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
        out = ['["] ( ']

        def visit(node):
            rejects = []
            first = True
            for c in sorted(node.children.keys()):
                child = node.children[c]
                rejects.append(c)
                if first:
                    first = False
                else:
                    out.append(' | ')
                out.append(f'[{self._format_range_char(c)}]')
                if child.children:
                    out.append(' (')
                    visit(child)
                    # Stopping here is allowed unless the prefix so far is
                    # itself one of the excluded strings.
                    out.append(')' if child.is_end_of_string else ')?')
                elif child.is_end_of_string:
                    out.append(f' {char_rule}+')
            if node.children:
                if not first:
                    out.append(' | ')
                out.append(f'[^"{self._format_range_chars(rejects)}] {char_rule}*')
        visit(trie)

        out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
        return ''.join(out)

    def _add_rule(self, name, rule):
        '''
        Register ``rule`` under a sanitized version of ``name`` and return the
        key used. An identical existing rule is reused; a different rule with
        the same name gets a numeric suffix.
        '''
        esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
                i += 1
            key = f'{esc_name}{i}'
        self._rules[key] = rule
        return key

    def resolve_refs(self, schema: dict, url: str):
        '''
        Resolves all $ref fields in the given schema, fetching any remote schemas,
        replacing $ref with absolute reference URL and populating self._refs with the
        respective referenced (sub)schema dictionaries.
        '''
        def visit(n: dict):
            if isinstance(n, list):
                return [visit(x) for x in n]
            elif isinstance(n, dict):
                ref = n.get('$ref')
                if ref is not None and ref not in self._refs:
                    if ref.startswith('https://'):
                        assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
                        import requests

                        frag_split = ref.split('#')
                        base_url = frag_split[0]

                        target = self._refs.get(base_url)
                        if target is None:
                            # Fetch the document itself; the fragment is resolved below.
                            target = self.resolve_refs(requests.get(base_url).json(), base_url)
                            self._refs[base_url] = target

                        if len(frag_split) == 1 or frag_split[-1] == '':
                            return target
                    elif ref.startswith('#/'):
                        target = schema
                        ref = f'{url}{ref}'
                        n['$ref'] = ref
                    else:
                        raise ValueError(f'Unsupported ref {ref}')

                    # Walk the JSON-pointer path after '#'.
                    for sel in ref.split('#')[-1].split('/')[1:]:
                        assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
                        target = target[sel]

                    self._refs[ref] = target
                else:
                    for v in n.values():
                        visit(v)

            return n
        return visit(schema)

    def _generate_union_rule(self, name, alt_schemas):
        '''Join the rules of ``alt_schemas`` as GBNF alternatives.'''
        return ' | '.join((
            self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
            for i, alt_schema in enumerate(alt_schemas)
        ))

    def _visit_pattern(self, pattern, name):
        '''
        Transforms a regular expression pattern into a GBNF rule.

        Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
        Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

        Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

        Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
        we define sub-rules to keep the output lean.
        '''

        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        sub_rule_ids = {}

        i = 0
        length = len(pattern)

        def to_rule(s: tuple[str, bool]) -> str:
            (txt, is_literal) = s
            return "\"" + txt + "\"" if is_literal else txt

        def transform() -> tuple[str, bool]:
            '''
            Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            '''
            nonlocal i
            nonlocal pattern
            nonlocal sub_rule_ids

            start = i
            # Flat list of (text, is_literal) items: enough to apply a
            # quantifier to the last item and to merge adjacent literals.
            seq: list[tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Any character except the \n and \r line breaks.
                    rule = DOT
                return self._add_rule('dot', rule)

            def join_seq():
                nonlocal seq
                # Merge runs of adjacent literals into a single literal.
                ret = []
                for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append((''.join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                return (' '.join(to_rule(x) for x in ret), False)

            while i < length:
                c = pattern[i]
                if c == '.':
                    seq.append((get_dot(), False))
                    i += 1
                elif c == '(':
                    i += 1
                    if i < length:
                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    seq.append((f'({to_rule(transform())})', False))
                elif c == ')':
                    i += 1
                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
                    return join_seq()
                elif c == '[':
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != ']':
                        if pattern[i] == '\\':
                            square_brackets += pattern[i:i+2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
                    square_brackets += ']'
                    i += 1
                    seq.append((square_brackets, False))
                elif c == '|':
                    seq.append(('|', False))
                    i += 1
                elif c in ('*', '+', '?'):
                    assert seq, f'Quantifier "{c}" without a preceding element at index {i} of /{pattern}/'
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == '{':
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != '}':
                        curly_brackets += pattern[i]
                        i += 1
                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
                    curly_brackets += '}'
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')

                    assert seq, f'Quantifier {curly_brackets} without a preceding element in /{pattern}/'
                    (sub, sub_is_literal) = seq[-1]

                    if not sub_is_literal:
                        # Give the repeated expression its own (shared) sub-rule.
                        rule_id = sub_rule_ids.get(sub)
                        if rule_id is None:
                            rule_id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
                            sub_rule_ids[sub] = rule_id
                        sub = rule_id

                    seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
                else:
                    literal = ''
                    while i < length:
                        if pattern[i] == '\\' and i < length - 1:
                            next_char = pattern[i + 1]
                            if next_char in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i:i+2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and \
                                (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
                            # A char followed by a quantifier is left for the
                            # next iteration so the quantifier binds to it alone.
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))

            return join_seq()

        return self._add_rule(
            name,
            to_rule(transform()) if self._raw_pattern \
                else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")

    def _resolve_ref(self, ref):
        '''
        Return the rule name for ``ref``, visiting the referenced schema on
        first use; refs currently being resolved are not re-entered.
        '''
        ref_name = ref.split('/')[-1]
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
            self._refs_being_resolved.add(ref)
            resolved = self._refs[ref]
            ref_name = self.visit(resolved, ref_name)
            self._refs_being_resolved.remove(ref)
        return ref_name

    def _generate_constant_rule(self, value):
        '''GBNF literal matching the JSON serialization of ``value``.'''
        return self._format_literal(json.dumps(value))

    def visit(self, schema, name):
        '''
        Generate (and register) the rule for ``schema`` and return its name.
        ``name`` is the rule-name prefix; an empty name yields "root".
        '''
        schema_type = schema.get('type')
        schema_format = schema.get('format')
        rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'

        if (ref := schema.get('$ref')) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif 'oneOf' in schema or 'anyOf' in schema:
            return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))

        elif isinstance(schema_type, list):
            return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))

        elif 'const' in schema:
            return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')

        elif 'enum' in schema:
            rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, 'object') and \
             ('properties' in schema or \
              ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
            required = set(schema.get('required', []))
            properties = list(schema.get('properties', {}).items())
            return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))

        elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
            required = set()
            properties = []
            enum_sets = []
            hybrid_name = name
            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get('$ref')) is not None:
                    comp_schema = self._refs[ref]

                if 'properties' in comp_schema:
                    for prop_name, prop_schema in comp_schema['properties'].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

                if 'enum' in comp_schema:
                    enum_sets.append(set(comp_schema['enum']))

            for t in schema['allOf']:
                if 'anyOf' in t:
                    for tt in t['anyOf']:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            if enum_sets:
                enum_intersection = enum_sets[0]
                for s in enum_sets[1:]:
                    enum_intersection &= s

                if enum_intersection:
                    rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
                    return self._add_rule(rule_name, rule)

            return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))

        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
            items = schema.get('items') or schema['prefixItems']
            if isinstance(items, list):
                return self._add_rule(
                    rule_name,
                    '"[" space ' +
                    ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)) +
                    ' "]" space')
            else:
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')

        elif schema_type in (None, 'string') and 'pattern' in schema:
            return self._visit_pattern(schema['pattern'], rule_name)

        elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
            return self._add_primitive(
                'root' if rule_name == 'root' else schema_format,
                PRIMITIVE_RULES['uuid']
            )

        elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
            prim_name = f'{schema_format}-string'
            return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))

        elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
            char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
            min_len = schema.get('minLength', 0)
            max_len = schema.get('maxLength')

            return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')

        elif schema_type in (None, 'integer') and \
                ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
            min_value = None
            max_value = None
            if 'minimum' in schema:
                min_value = schema['minimum']
            elif 'exclusiveMinimum' in schema:
                min_value = schema['exclusiveMinimum'] + 1
            if 'maximum' in schema:
                max_value = schema['maximum']
            elif 'exclusiveMaximum' in schema:
                max_value = schema['exclusiveMaximum'] - 1

            out = ["("]
            _generate_min_max_int(min_value, max_value, out)
            out.append(") space")
            return self._add_rule(rule_name, ''.join(out))

        elif (schema_type == 'object') or (len(schema) == 0):
            return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

        else:
            assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
            # Primitive types reuse the builtin rule name (or "root" at the top).
            return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])

    def _add_primitive(self, name: str, rule: BuiltinRule):
        '''Register a builtin rule under ``name`` plus any missing dependencies.'''
        n = self._add_rule(name, rule.content)

        for dep in rule.deps:
            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
            assert dep_rule, f'Rule {dep} not known'
            if dep not in self._rules:
                self._add_primitive(dep, dep_rule)
        return n

    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
        '''
        Build the rule body for an object: required properties first (in
        prop_order, then schema order), followed by an optional tail of
        optional properties and, if allowed, additional properties.
        '''
        prop_order = self._prop_order
        # Sort by prop_order precedence, then by original position (stable).
        sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]

        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        if additional_properties is not None and additional_properties != False:
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
                self._add_primitive('value', PRIMITIVE_RULES['value'])
            key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
                else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))

            prop_kv_rule_names["*"] = self._add_rule(
                f'{sub_name}-kv',
                f'{key_rule} ":" space {value_rule}'
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)

        if optional_props:
            rule += ' ('
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                # Rule for "some subset of ks, in order"; "*" may repeat.
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                comma_ref = f'( "," space {kv_rule_name} )'
                if first_is_optional:
                    res = comma_ref + ('*' if k == '*' else '?')
                else:
                    res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
                if len(rest) > 0:
                    res += ' ' + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True)
                    )
                return res

            rule += ' | '.join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += ' )'
            rule += ' )?'

        rule += ' "}" space'

        return rule

    def format_grammar(self):
        '''Render all rules as "name ::= body" lines sorted by rule name.'''
        return '\n'.join(
            f'{name} ::= {rule}'
            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
        )
| |
|
| |
|
def main(args_in = None):
    '''
    Command-line entry point: load a JSON schema (a file path, "-" for stdin,
    or an https:// URL), convert it and print the resulting GBNF grammar.

    Args:
        args_in: argument list to parse instead of sys.argv[1:].
    '''
    parser = argparse.ArgumentParser(
        description='''
            Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
            given JSON schema. Only a subset of JSON schema features are supported; more may be
            added in the future.
        ''',
    )
    parser.add_argument(
        '--prop-order',
        default=[],
        type=lambda s: s.split(','),
        help='''
            comma-separated property names defining the order of precedence for object properties;
            properties not specified here are given lower precedence than those that are, and
            are kept in their original order from the schema. Required properties are always
            given precedence over optional properties.
        '''
    )
    parser.add_argument(
        '--allow-fetch',
        action='store_true',
        default=False,
        help='Whether to allow fetching referenced schemas over HTTPS')
    parser.add_argument(
        '--dotall',
        action='store_true',
        default=False,
        help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
    parser.add_argument(
        '--raw-pattern',
        action='store_true',
        default=False,
        help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')

    parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
    args = parser.parse_args(args_in)

    if args.schema.startswith('https://'):
        url = args.schema
        import requests
        schema = requests.get(url).json()
    elif args.schema == '-':
        url = 'stdin'
        schema = json.load(sys.stdin)
    else:
        url = f'file://{args.schema}'
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(args.schema, encoding='utf-8') as f:
            schema = json.load(f)
    converter = SchemaConverter(
        prop_order={name: idx for idx, name in enumerate(args.prop_order)},
        allow_fetch=args.allow_fetch,
        dotall=args.dotall,
        raw_pattern=args.raw_pattern)
    schema = converter.resolve_refs(schema, url)
    converter.visit(schema, '')
    print(converter.format_grammar())


if __name__ == '__main__':
    main()
| |
|