sssdtgvg commited on
Commit
f385b17
1 Parent(s): 5178306

Upload tfcompile_main.cc

Browse files
Files changed (1) hide show
  1. tfcompile_main.cc +143 -0
tfcompile_main.cc ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #include <memory>
17
+ #include <string>
18
+ #include <utility>
19
+ #include <vector>
20
+
21
+ #include "tensorflow/compiler/aot/codegen.h"
22
+ #include "tensorflow/compiler/aot/compile.h"
23
+ #include "tensorflow/compiler/aot/flags.h"
24
+ #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
25
+ #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
26
+ #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
27
+ #include "tensorflow/compiler/xla/service/compiler.h"
28
+ #include "tensorflow/core/framework/function.h"
29
+ #include "tensorflow/core/framework/graph.pb.h"
30
+ #include "tensorflow/core/framework/tensor_shape.h"
31
+ #include "tensorflow/core/framework/types.h"
32
+ #include "tensorflow/core/graph/graph.h"
33
+ #include "tensorflow/core/graph/tensor_id.h"
34
+ #include "tensorflow/core/lib/core/errors.h"
35
+ #include "tensorflow/core/lib/core/stringpiece.h"
36
+ #include "tensorflow/core/lib/strings/numbers.h"
37
+ #include "tensorflow/core/lib/strings/str_util.h"
38
+ #include "tensorflow/core/platform/env.h"
39
+ #include "tensorflow/core/platform/init_main.h"
40
+ #include "tensorflow/core/platform/logging.h"
41
+ #include "tensorflow/core/platform/protobuf.h"
42
+ #include "tensorflow/core/util/command_line_flags.h"
43
+
44
+ namespace tensorflow {
45
+ namespace tfcompile {
46
+
47
+ const char kUsageHeader[] =
48
+ "tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n"
49
+ "resulting in an object file compiled for your target architecture, and a\n"
50
+ "header file that gives access to the functionality in the object file.\n"
51
+ "A typical invocation looks like this:\n"
52
+ "\n"
53
+ " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt "
54
+ "--cpp_class=\"mynamespace::MyComputation\"\n"
55
+ "\n";
56
+
57
+ Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
58
+ if (StringPiece(fname).ends_with(".pbtxt")) {
59
+ return ReadTextProto(Env::Default(), fname, proto);
60
+ } else {
61
+ return ReadBinaryProto(Env::Default(), fname, proto);
62
+ }
63
+ }
64
+
65
+ Status Main(const MainFlags& flags) {
66
+ // Process config.
67
+ tf2xla::Config config;
68
+ if (flags.config.empty()) {
69
+ return errors::InvalidArgument("Must specify --config");
70
+ }
71
+ TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
72
+ TF_RETURN_IF_ERROR(ValidateConfig(config));
73
+ if (flags.dump_fetch_nodes) {
74
+ std::set<string> nodes;
75
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
76
+ nodes.insert(fetch.id().node_name());
77
+ }
78
+ std::cout << str_util::Join(nodes, ",");
79
+ return Status::OK();
80
+ }
81
+
82
+ // Read and initialize the graph.
83
+ if (flags.graph.empty()) {
84
+ return errors::InvalidArgument("Must specify --graph");
85
+ }
86
+ GraphDef graph_def;
87
+ TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
88
+ CompileResult compile_result;
89
+ TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
90
+
91
+ // Write output files.
92
+ Env* env = Env::Default();
93
+ const std::vector<char>& obj = compile_result.aot->object_file_data();
94
+ TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
95
+ StringPiece(obj.data(), obj.size())));
96
+ HeaderOpts header_opts;
97
+ header_opts.gen_name_to_index = flags.gen_name_to_index;
98
+ header_opts.gen_program_shape = flags.gen_program_shape;
99
+ if (flags.cpp_class.empty()) {
100
+ return errors::InvalidArgument("Must specify --cpp_class");
101
+ }
102
+ TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
103
+ &header_opts.namespaces));
104
+ string header;
105
+ TF_RETURN_IF_ERROR(
106
+ GenerateHeader(header_opts, config, compile_result, &header));
107
+ TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
108
+ return Status::OK();
109
+ }
110
+
111
+ } // end namespace tfcompile
112
+ } // end namespace tensorflow
113
+
114
+ int main(int argc, char** argv) {
115
+ tensorflow::tfcompile::MainFlags flags;
116
+ flags.target_triple = "x86_64-pc-linux";
117
+ flags.out_object = "out.o";
118
+ flags.out_header = "out.h";
119
+ flags.entry_point = "entry";
120
+
121
+ std::vector<tensorflow::Flag> flag_list;
122
+ AppendMainFlags(&flag_list, &flags);
123
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
124
+
125
+ tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
126
+ usage += tensorflow::Flags::Usage(argv[0], flag_list);
127
+ bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
128
+ QCHECK(parsed_flags_ok) << "\n" << usage;
129
+
130
+ tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
131
+ QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
132
+ "other than flags\n\n"
133
+ << usage;
134
+ tensorflow::Status status = tensorflow::tfcompile::Main(flags);
135
+ if (status.code() == tensorflow::error::INVALID_ARGUMENT) {
136
+ std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"
137
+ << usage;
138
+ return 1;
139
+ } else {
140
+ TF_QCHECK_OK(status);
141
+ }
142
+ return 0;
143
+ }