| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include "eliminate_reshape_shape_expression.h" |
|
|
| #include <iostream> |
| #include <sstream> |
| #include <algorithm> |
| #include <stack> |
| #include <vector> |
| #include <string> |
|
|
| namespace pnnx { |
|
|
// Returns true only when the entire token `t` is consumed by reading a single
// int from it. Whitespace is NOT skipped (std::noskipws), so " 3" and "3 " are
// rejected. An empty token, or a value that cannot be stored in an int, sets
// failbit and is rejected as well.
// (The misspelled name is kept as-is for existing callers.)
static bool token_is_interger_literal(const std::string& t)
{
    std::istringstream reader(t);
    reader >> std::noskipws;

    int parsed_value;
    reader >> parsed_value;

    if (reader.fail())
        return false;

    // trailing characters (e.g. "3*x") leave the stream short of eof
    return reader.eof();
}
|
|
| static void build_shape(const std::string& expr, std::vector<int>& shape, std::vector<std::string>& expr_tokens) |
| { |
| std::string listexpr = expr.substr(1, expr.size() - 2); |
|
|
| std::string t; |
| std::string et; |
| int level = 0; |
| for (size_t i = 0; i < listexpr.size(); i++) |
| { |
| char ch = listexpr[i]; |
|
|
| if (ch == '(' || ch == '[') |
| { |
| level += 1; |
| t = "-1"; |
| et += ch; |
| } |
| else if (ch == ')' || ch == ']') |
| { |
| level -= 1; |
| t = "-1"; |
| et += ch; |
| } |
| else if (level == 0 && ch == ',') |
| { |
| int dimsize = token_is_interger_literal(t) ? std::stoi(t) : -1; |
| shape.push_back(dimsize); |
| expr_tokens.push_back(et); |
| t.clear(); |
| et.clear(); |
| } |
| else |
| { |
| t += ch; |
| et += ch; |
| } |
| } |
|
|
| if (level == 0 && !t.empty()) |
| { |
| int dimsize = token_is_interger_literal(t) ? std::stoi(t) : -1; |
| shape.push_back(dimsize); |
| } |
|
|
| if (level == 0 && !et.empty()) |
| { |
| expr_tokens.push_back(et); |
| } |
| } |
|
|
// Joins element texts back into a list expression, e.g. {"a", "2"} -> "[a,2]".
// An empty token list yields "[]".
static std::string build_expr(const std::vector<std::string>& expr_tokens)
{
    std::string joined(1, '[');

    const char* separator = "";
    for (const std::string& token : expr_tokens)
    {
        joined += separator;
        joined += token;
        separator = ",";
    }

    joined.push_back(']');
    return joined;
}
|
|
| void eliminate_reshape_shape_expression(Graph& graph) |
| { |
| while (1) |
| { |
| bool matched = false; |
|
|
| for (size_t i = 0; i < graph.ops.size(); i++) |
| { |
| Operator* op = graph.ops[i]; |
|
|
| if (op->type != "Tensor.view" && op->type != "Tensor.reshape") |
| continue; |
|
|
| if (op->inputs.size() != 2) |
| continue; |
|
|
| Operator* op_expr = op->inputs[1]->producer; |
| if (op_expr->type != "pnnx.Expression") |
| continue; |
|
|
| std::string expr = op_expr->params.at("expr").s; |
| if (expr.empty() || expr[0] != '[') |
| continue; |
|
|
| std::vector<int> outshape = op->outputs[0]->shape; |
| if (outshape.empty()) |
| continue; |
|
|
| std::vector<int> shape; |
| std::vector<std::string> expr_tokens; |
| build_shape(expr, shape, expr_tokens); |
|
|
| |
| for (size_t j = 0; j < outshape.size(); j++) |
| { |
| if (outshape[j] != -1) |
| { |
| shape[j] = outshape[j]; |
| expr_tokens[j] = std::to_string(outshape[j]); |
| } |
| } |
|
|
| |
| int dynamic_dim_count = 0; |
| for (size_t j = 0; j < shape.size(); j++) |
| { |
| if (shape[j] == -1) |
| { |
| dynamic_dim_count += 1; |
| } |
| } |
|
|
| if (dynamic_dim_count > 1) |
| { |
| op_expr->params["expr"] = build_expr(expr_tokens); |
| continue; |
| } |
|
|
| matched = true; |
|
|
| op->params["shape"] = shape; |
|
|
| op->inputs.resize(1); |
| op_expr->outputs[0]->remove_consumer(op); |
|
|
| if (op_expr->outputs[0]->consumers.size() == 0) |
| { |
| |
| for (auto x : op_expr->inputs) |
| { |
| x->remove_consumer(op_expr); |
| } |
|
|
| Operand* op_expr_out = op_expr->outputs[0]; |
|
|
| graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_expr_out)); |
| delete op_expr_out; |
|
|
| op_expr->inputs.clear(); |
| op_expr->outputs.clear(); |
|
|
| graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_expr)); |
| delete op_expr; |
| } |
|
|
| break; |
| } |
|
|
| if (!matched) |
| break; |
| } |
|
|
| for (size_t i = 0; i < graph.ops.size(); i++) |
| { |
| Operator* op = graph.ops[i]; |
|
|
| if (op->type != "Tensor.view" && op->type != "Tensor.reshape") |
| continue; |
|
|
| if (op->inputs.size() != 1) |
| continue; |
|
|
| std::vector<int> outshape = op->outputs[0]->shape; |
| if (outshape.empty()) |
| continue; |
|
|
| std::vector<int> shape = op->params.at("shape").ai; |
|
|
| |
| for (size_t j = 0; j < outshape.size(); j++) |
| { |
| if (outshape[j] != -1) |
| { |
| shape[j] = outshape[j]; |
| } |
| } |
|
|
| op->params["shape"] = shape; |
| } |
| } |
|
|
| } |
|
|