From e598e9233a1d0559cab6d82e0752ce00cf3b887b Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 19 Aug 2019 18:34:57 -0700 Subject: [PATCH] Map batch matmul from PT to TVM. Also added test. Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_operators.py | 14 ++++++++++++++ torch_tvm/operators.cpp | 20 ++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/test/test_operators.py b/test/test_operators.py index 4570d52..a6a850e 100644 --- a/test/test_operators.py +++ b/test/test_operators.py @@ -326,6 +326,20 @@ def linear_no_bias(input, weight): ref_out_no_bias, tvm_out_no_bias = self.runBoth(linear_no_bias, input, weight) assert torch.allclose(ref_out_no_bias, tvm_out_no_bias, rtol=0.01, atol=0.01) + @TVMTest.given( + shape=TVMTest.rand_shape(rank=3, min_dim=4), + out_features=TVMTest.rand_int(3, 6), + ) + def test_bmm(self, shape, out_features): + input = torch.rand(shape) + weight = torch.rand(shape[0], shape[-1], out_features) + + def bmm(input, weight): + return torch.bmm(input, weight) + 2.0 + + ref_out, tvm_out = self.runBoth(bmm, input, weight) + assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) + @TVMTest.given( shape=TVMTest.rand_shape(rank=2, min_dim=4), ) diff --git a/torch_tvm/operators.cpp b/torch_tvm/operators.cpp index b47f059..30c8623 100644 --- a/torch_tvm/operators.cpp +++ b/torch_tvm/operators.cpp @@ -503,6 +503,26 @@ RegisterTVMOperator reg({ tvm::relay::CallNode::make(op, {inputs[0]}, tvm::Attrs(attrs), {}); return out; }}, + {Symbol::fromQualString("aten::bmm"), + [](Node* node, tvm::Array inputs) { + TORCH_INTERNAL_ASSERT(inputs.size()==2); + + auto transpose_attrs = tvm::make_node(); + auto& axes = transpose_attrs->axes; + axes.push_back(0); + axes.push_back(2); + axes.push_back(1); + + auto transposed_weight = tvm::relay::CallNode::make( + tvm::relay::Op::Get("transpose"), + {inputs[1]}, tvm::Attrs(transpose_attrs), {}); + + auto out = tvm::relay::CallNode::make( + tvm::relay::Op::Get("nn.batch_matmul"), + {inputs[0], transposed_weight}, {}, {}); + + return out; + }}, {Symbol::fromQualString("aten::linear"), [](Node* node, tvm::Array inputs) { Value* input = node->input(0);