/* WaveGRU: > Embed > GRU > O1 > O2 > Sampling > ... */ #include #include #include #include #include #include #include "sparse_matmul/sparse_matmul.h" namespace py = pybind11; using namespace std; using fvec = std::vector; using ivec = std::vector; using fndarray = py::array_t; using indarray = py::array_t; using mat = csrblocksparse::CsrBlockSparseMatrix; using vec = csrblocksparse::CacheAlignedVector; using masked_mat = csrblocksparse::MaskedSparseMatrix; mat create_mat(int h, int w) { auto m = masked_mat(w, h, 0.90, 4, 4, 0.0, true); auto a = mat(m); return a; } struct WaveGRU { int input_dim; int embed_dim; int hidden_dim; mat m1, m2, m3; vec b1, b2, b3; vec z, r, hh; vec fco1, fco2; vec o1b, o2b; vec t; vec h; mat o1, o2; std::vector embed; WaveGRU(int input_dim, int embed_dim, int hidden_dim) : input_dim(input_dim), embed_dim(embed_dim), hidden_dim(hidden_dim), b1(hidden_dim), b2(hidden_dim), b3(hidden_dim), z(hidden_dim), r(hidden_dim), hh(hidden_dim), fco1(hidden_dim), fco2(256), t(hidden_dim + input_dim + embed_dim), h(hidden_dim), o1b(hidden_dim), o2b(256) { m1 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim); m2 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim); m3 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim); o1 = create_mat(hidden_dim, hidden_dim); o2 = create_mat(hidden_dim, 256); embed = std::vector(); for (int i = 0; i < 256; i++) { embed.emplace_back(embed_dim); embed[i].FillRandom(); } } void load_embed(fndarray embed_weights) { auto a_embed = embed_weights.unchecked<2>(); for (int i = 0; i < 256; i++) { for (int j = 0; j < embed_dim; j++) embed[i][j] = a_embed(i, j); } } mat load_linear(vec& bias, fndarray w, indarray mask, fndarray b) { auto w_ptr = static_cast(w.request().ptr); auto mask_ptr = static_cast(mask.request().ptr); auto rb = b.unchecked<1>(); // load bias, scale by 1/4 for (int i = 0; i < rb.shape(0); i++) bias[i] = rb(i) / 4; // load weights masked_mat mm(w.shape(0), w.shape(1), mask_ptr, w_ptr); mat mmm(mm); return mmm; } void load_weights(fndarray m1, indarray m1_mask, fndarray b1, fndarray m2, indarray m2_mask, fndarray b2, fndarray m3, indarray m3_mask, fndarray b3, fndarray o1, indarray o1_mask, fndarray o1b, fndarray o2, indarray o2_mask, fndarray o2b) { this->m1 = load_linear(this->b1, m1, m1_mask, b1); this->m2 = load_linear(this->b2, m2, m2_mask, b2); this->m3 = load_linear(this->b3, m3, m3_mask, b3); this->o1 = load_linear(this->o1b, o1, o1_mask, o1b); this->o2 = load_linear(this->o2b, o2, o2_mask, o2b); } std::vector inference(fndarray ft, float temperature) { auto rft = ft.unchecked<2>(); std::vector xs; for (int i = 0; i < rft.shape(0); i++) { xs.emplace_back(input_dim); for (int j = 0; j < input_dim; j++) xs[i][j] = rft(i, j); } int value = 127; std::vector signal(xs.size()); h.FillZero(); for (int index = 0; index < xs.size(); index++) { for (int i = 0; i < embed_dim; i++) t[i] = embed[value][i]; for (int i = 0; i < input_dim; i++) t[embed_dim + i] = xs[index][i]; for (int i = 0; i < hidden_dim; i++) t[embed_dim + input_dim + i] = h[i]; m1.SpMM_bias(t, b1, &z, false); m2.SpMM_bias(t, b2, &r, false); z.Sigmoid(); r.Sigmoid(); for (int i = 0; i < hidden_dim; i++) { t[embed_dim + input_dim + i] = h[i] * r[i]; } m3.SpMM_bias(t, b3, &hh, false); hh.Tanh(); for (int i = 0; i < hidden_dim; i++) { h[i] = (1. - z[i]) * h[i] + z[i] * hh[i]; } o1.SpMM_bias(h, o1b, &fco1, true); o2.SpMM_bias(fco1, o2b, &fco2, false); value = fco2.Sample(temperature); signal[index] = value; } return signal; } }; PYBIND11_MODULE(wavegru_mod, m) { py::class_(m, "WaveGRU") .def(py::init()) .def("load_embed", &WaveGRU::load_embed) .def("load_weights", &WaveGRU::load_weights) .def("inference", &WaveGRU::inference); }