Spaces:
Sleeping
Sleeping
Upload tfcompile_main.cc
Browse files- tfcompile_main.cc +143 -0
tfcompile_main.cc
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include <memory>
|
17 |
+
#include <string>
|
18 |
+
#include <utility>
|
19 |
+
#include <vector>
|
20 |
+
|
21 |
+
#include "tensorflow/compiler/aot/codegen.h"
|
22 |
+
#include "tensorflow/compiler/aot/compile.h"
|
23 |
+
#include "tensorflow/compiler/aot/flags.h"
|
24 |
+
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
25 |
+
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
26 |
+
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
27 |
+
#include "tensorflow/compiler/xla/service/compiler.h"
|
28 |
+
#include "tensorflow/core/framework/function.h"
|
29 |
+
#include "tensorflow/core/framework/graph.pb.h"
|
30 |
+
#include "tensorflow/core/framework/tensor_shape.h"
|
31 |
+
#include "tensorflow/core/framework/types.h"
|
32 |
+
#include "tensorflow/core/graph/graph.h"
|
33 |
+
#include "tensorflow/core/graph/tensor_id.h"
|
34 |
+
#include "tensorflow/core/lib/core/errors.h"
|
35 |
+
#include "tensorflow/core/lib/core/stringpiece.h"
|
36 |
+
#include "tensorflow/core/lib/strings/numbers.h"
|
37 |
+
#include "tensorflow/core/lib/strings/str_util.h"
|
38 |
+
#include "tensorflow/core/platform/env.h"
|
39 |
+
#include "tensorflow/core/platform/init_main.h"
|
40 |
+
#include "tensorflow/core/platform/logging.h"
|
41 |
+
#include "tensorflow/core/platform/protobuf.h"
|
42 |
+
#include "tensorflow/core/util/command_line_flags.h"
|
43 |
+
|
44 |
+
namespace tensorflow {
|
45 |
+
namespace tfcompile {
|
46 |
+
|
47 |
+
const char kUsageHeader[] =
|
48 |
+
"tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n"
|
49 |
+
"resulting in an object file compiled for your target architecture, and a\n"
|
50 |
+
"header file that gives access to the functionality in the object file.\n"
|
51 |
+
"A typical invocation looks like this:\n"
|
52 |
+
"\n"
|
53 |
+
" $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt "
|
54 |
+
"--cpp_class=\"mynamespace::MyComputation\"\n"
|
55 |
+
"\n";
|
56 |
+
|
57 |
+
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
|
58 |
+
if (StringPiece(fname).ends_with(".pbtxt")) {
|
59 |
+
return ReadTextProto(Env::Default(), fname, proto);
|
60 |
+
} else {
|
61 |
+
return ReadBinaryProto(Env::Default(), fname, proto);
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
Status Main(const MainFlags& flags) {
|
66 |
+
// Process config.
|
67 |
+
tf2xla::Config config;
|
68 |
+
if (flags.config.empty()) {
|
69 |
+
return errors::InvalidArgument("Must specify --config");
|
70 |
+
}
|
71 |
+
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
|
72 |
+
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
73 |
+
if (flags.dump_fetch_nodes) {
|
74 |
+
std::set<string> nodes;
|
75 |
+
for (const tf2xla::Fetch& fetch : config.fetch()) {
|
76 |
+
nodes.insert(fetch.id().node_name());
|
77 |
+
}
|
78 |
+
std::cout << str_util::Join(nodes, ",");
|
79 |
+
return Status::OK();
|
80 |
+
}
|
81 |
+
|
82 |
+
// Read and initialize the graph.
|
83 |
+
if (flags.graph.empty()) {
|
84 |
+
return errors::InvalidArgument("Must specify --graph");
|
85 |
+
}
|
86 |
+
GraphDef graph_def;
|
87 |
+
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
88 |
+
CompileResult compile_result;
|
89 |
+
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
|
90 |
+
|
91 |
+
// Write output files.
|
92 |
+
Env* env = Env::Default();
|
93 |
+
const std::vector<char>& obj = compile_result.aot->object_file_data();
|
94 |
+
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
|
95 |
+
StringPiece(obj.data(), obj.size())));
|
96 |
+
HeaderOpts header_opts;
|
97 |
+
header_opts.gen_name_to_index = flags.gen_name_to_index;
|
98 |
+
header_opts.gen_program_shape = flags.gen_program_shape;
|
99 |
+
if (flags.cpp_class.empty()) {
|
100 |
+
return errors::InvalidArgument("Must specify --cpp_class");
|
101 |
+
}
|
102 |
+
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
|
103 |
+
&header_opts.namespaces));
|
104 |
+
string header;
|
105 |
+
TF_RETURN_IF_ERROR(
|
106 |
+
GenerateHeader(header_opts, config, compile_result, &header));
|
107 |
+
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
|
108 |
+
return Status::OK();
|
109 |
+
}
|
110 |
+
|
111 |
+
} // end namespace tfcompile
|
112 |
+
} // end namespace tensorflow
|
113 |
+
|
114 |
+
int main(int argc, char** argv) {
|
115 |
+
tensorflow::tfcompile::MainFlags flags;
|
116 |
+
flags.target_triple = "x86_64-pc-linux";
|
117 |
+
flags.out_object = "out.o";
|
118 |
+
flags.out_header = "out.h";
|
119 |
+
flags.entry_point = "entry";
|
120 |
+
|
121 |
+
std::vector<tensorflow::Flag> flag_list;
|
122 |
+
AppendMainFlags(&flag_list, &flags);
|
123 |
+
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
|
124 |
+
|
125 |
+
tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
|
126 |
+
usage += tensorflow::Flags::Usage(argv[0], flag_list);
|
127 |
+
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
128 |
+
QCHECK(parsed_flags_ok) << "\n" << usage;
|
129 |
+
|
130 |
+
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
131 |
+
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
|
132 |
+
"other than flags\n\n"
|
133 |
+
<< usage;
|
134 |
+
tensorflow::Status status = tensorflow::tfcompile::Main(flags);
|
135 |
+
if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
136 |
+
std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
|
137 |
+
<< usage;
|
138 |
+
return 1;
|
139 |
+
} else {
|
140 |
+
TF_QCHECK_OK(status);
|
141 |
+
}
|
142 |
+
return 0;
|
143 |
+
}
|