File size: 4,676 Bytes
8b7c501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import io
import os
import re
import sys
from itertools import chain


def key_value_pair(line):
  """Parse a ``KEY=VALUE`` token into a ``(key, value)`` tuple.

  Only the first ``=`` separates the key from the value, so the value may
  itself contain ``=``. The value is returned as an ``int`` when ``int()``
  accepts it, and as the original string otherwise.

  Raises:
    ValueError: if ``line`` contains no ``=`` at all.
  """
  name, raw = line.split("=", 1)
  try:
    return name, int(raw)
  except ValueError:
    # Not an integer literal: keep the textual value unchanged.
    return name, raw


# Command-line interface: a single template path, any number of repeatable
# "-D KEY=VALUE ..." groups (each group is appended as its own list), and an
# optional output path.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
  "input",
  metavar="FILE",
  nargs=1,
  help="Input file",
)
parser.add_argument(
  "-D",
  dest="defines",
  metavar="KEY=VALUE",
  nargs="*",
  type=key_value_pair,
  action="append",
  help="Predefined variables",
)
parser.add_argument(
  "-o", "--output",
  help="Output file",
)
parser.set_defaults(defines=[])


# Matches the (possibly empty) run of whitespace at the start of a line.
LEADING_WHITESPACE_REGEX = re.compile(r"^\s*", flags=0)


def extract_leading_whitespace(line):
  """Return the whitespace prefix of ``line`` (``""`` when there is none).

  The pattern can match the empty string, so a match object is always
  returned and no fallback branch is needed.
  """
  return LEADING_WHITESPACE_REGEX.match(line).group(0)


def escape(line):
  """Translate one template text line into a Python string expression.

  Literal text becomes Python string literals (via ``repr``, so quotes AND
  backslashes survive verbatim), each ``${expr}`` becomes ``str(expr)``, and
  the pieces are joined with ``+``. An empty ``line`` yields ``'""'`` so the
  result is always a valid expression to place inside ``print(...)``.

  Note: an ``${expr}`` ends at the first ``}`` after it; a ``${`` with no
  closing ``}`` raises ValueError from ``str.index``.
  """
  output_parts = []
  while "${" in line:
    start_pos = line.index("${")
    end_pos = line.index("}", start_pos + 2)
    if start_pos != 0:
      # repr() produces a literal that evaluates back to exactly this text;
      # hand-rolled quoting left backslashes (e.g. C "\n" or a trailing
      # line-continuation "\") to be reinterpreted by Python.
      output_parts.append(repr(line[:start_pos]))
    output_parts.append("str(" + line[start_pos+2:end_pos] + ")")
    line = line[end_pos+1:]
  if line:
    output_parts.append(repr(line))
  if not output_parts:
    # Avoid emitting "print(, file=...)" for an empty remainder.
    return '""'
  return " + ".join(output_parts)


def preprocess(input_text, input_globals, input_path="codegen"):
  """Expand a template by translating it to Python and executing it.

  Template syntax, as handled below:
    * A line whose stripped text starts with ``$`` (but not ``${``) is a
      Python statement; every ``$`` in it is removed. A statement ending in
      ``:`` opens a block whose body indentation is taken from the next
      non-blank line.
    * Any other line is literal text printed to the output; ``${expr}`` is
      substituted with ``str(expr)``. The indentation consumed by enclosing
      Python blocks is stripped from the printed text.
    * Empty lines are printed as empty lines; lines containing ``LINT`` are
      dropped entirely.

  Args:
    input_text: the template source.
    input_globals: names visible to template code; copied, not modified.
    input_path: filename given to ``compile`` for the generated code.

  Returns:
    The generated text.

  Raises:
    ValueError: if a line's indentation is inconsistent with the enclosing
      template blocks.
    Any exception from compiling or executing the generated Python code is
    propagated unchanged.
  """
  python_lines = []

  # Empty lines seen but not yet emitted; they are flushed at the indentation
  # of the next non-blank line so they land inside the right Python block.
  blank_lines = 0

  last_indent = ""
  # Defined up front so trailing blank lines are handled even when the
  # template has no non-blank lines at all.
  python_indent = ""

  # Stack of (input_indent, python_indent) pairs, one per open block.
  indent_stack = [("", "")]

  # Indicates whether this is the first line inside Python
  # code block (i.e. for, while, if, elif, else)
  python_block_start = True
  for line_number, input_line in enumerate(input_text.splitlines(), 1):
    if input_line == "":
      blank_lines += 1
      continue
    # Skip lint markers.
    if 'LINT' in input_line:
      continue

    input_indent = extract_leading_whitespace(input_line)
    if python_block_start:
      # The first body line fixes the block's indentation; it must be nested
      # inside (or level with) the line that opened the block.
      if not input_indent.startswith(last_indent):
        raise ValueError("%s:%d: indentation does not continue the enclosing "
                         "block" % (input_path, line_number))
      extra_python_indent = input_indent[len(last_indent):]
      indent_stack.append(
          (input_indent, indent_stack[-1][1] + extra_python_indent))
    else:
      # Dedent: close every block whose body indentation this line leaves.
      # The bottom entry ("", "") matches everything, so the stack never
      # empties.
      while not input_indent.startswith(indent_stack[-1][0]):
        del indent_stack[-1]
    python_block_start = False

    python_indent = indent_stack[-1][1]
    stripped_input_line = input_line.strip()
    is_statement = (stripped_input_line.startswith("$") and
                    not stripped_input_line.startswith("${"))
    if is_statement:
      python_block_start = stripped_input_line.endswith(":")
    elif not input_line.startswith(python_indent):
      raise ValueError("%s:%d: text line is indented less than its enclosing "
                       "block" % (input_path, line_number))

    python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)
    blank_lines = 0

    if is_statement:
      python_lines.append(python_indent + stripped_input_line.replace("$", ""))
    else:
      python_lines.append(python_indent + "print(%s, file=OUT_STREAM)" %
                          escape(input_line[len(python_indent):]))
    last_indent = input_indent

  python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)

  exec_globals = dict(input_globals)
  if sys.version_info > (3, 0):
    output_stream = io.StringIO()
  else:
    output_stream = io.BytesIO()
  exec_globals["OUT_STREAM"] = output_stream
  # NOTE: the template is executed as arbitrary Python code; only run
  # trusted templates through this function.
  python_bytecode = compile("\n".join(python_lines), input_path, 'exec')
  exec(python_bytecode, exec_globals)

  return output_stream.getvalue()


# Header placed at the top of every generated file; the {template} and
# {generator} fields are filled in with str.format by main().
PREAMBLE = (
  "// Auto-generated file. Do not edit!\n"
  "//   Template: {template}\n"
  "//   Generator: {generator}\n"
  "//\n"
)


def main(args):
  """Render the template named on the command line.

  Args:
    args: command-line arguments, excluding the program name.

  The result is the preamble followed by the expanded template. With ``-o``,
  the output file is rewritten only when its contents would change. Without
  ``-o``, the result is written to standard output instead of crashing on a
  missing path.
  """
  options = parser.parse_args(args)
  template_path = options.input[0]

  # Use a context manager so the template file handle is closed promptly.
  with codecs.open(template_path, "r", encoding="utf-8") as input_file:
    input_text = input_file.read()

  # Each -D occurrence contributes a list of (key, value) pairs; flatten them
  # into one mapping (a later definition of the same key wins).
  python_globals = dict(chain.from_iterable(options.defines))
  output_text = (
      PREAMBLE.format(template=template_path, generator=sys.argv[0]) +
      preprocess(input_text, python_globals, template_path))

  if options.output is None:
    sys.stdout.write(output_text)
    return

  # Skip the write when nothing changed so the existing file is untouched.
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != output_text

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(output_text)


if __name__ == "__main__":
  main(sys.argv[1:])