|
|
@@ -128,12 +128,14 @@ namespace traph
|
|
|
dim.push_back(b.size()[1]);
|
|
|
std::shared_ptr<Tensor<f32>> result(new Tensor<f32>(dim));
|
|
|
|
|
|
+
|
|
|
+
|
|
|
#ifdef TRAPH_BUILD_EIGEN
|
|
|
// copy data
|
|
|
- Eigen::Map<const Eigen::Matrix<f32, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> eigen_a(a.data_ptr() + a.offset(), a.size()[0], a.size()[1]);
|
|
|
- Eigen::Map<const Eigen::Matrix<f32, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> eigen_b(b.data_ptr() + b.offset(), b.size()[0], b.size()[1]);
|
|
|
+ Eigen::Map<const Eigen::Matrix<f32, Eigen::Dynamic, Eigen::Dynamic>, 0, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>> eigen_a(a.data_ptr() + a.offset(), a.size()[0], a.size()[1], Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>(a.stride(1), a.stride(0)));
|
|
|
+ Eigen::Map<const Eigen::Matrix<f32, Eigen::Dynamic, Eigen::Dynamic>, 0, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>> eigen_b(b.data_ptr() + b.offset(), b.size()[0], b.size()[1], Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>(b.stride(1), b.stride(0)));
|
|
|
|
|
|
- Eigen::Matrix<f32, Eigen::Dynamic, Eigen::Dynamic> eigen_c = eigen_a * eigen_b;
|
|
|
+ Eigen::Matrix<f32, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigen_c = eigen_a * eigen_b;
|
|
|
// copy to result
|
|
|
std::copy(eigen_c.data(), eigen_c.data() + a.size()[0] * b.size()[1], result->data_ptr());
|
|
|
#elif defined TRAPH_BUILD_MKL
|
|
|
@@ -172,7 +174,7 @@ namespace traph
|
|
|
Eigen::Map<const Eigen::Matrix<f64, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> eigen_a(a.data_ptr() + a.offset(), a.size()[0], a.size()[1]);
|
|
|
Eigen::Map<const Eigen::Matrix<f64, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> eigen_b(b.data_ptr() + b.offset(), b.size()[0], b.size()[1]);
|
|
|
|
|
|
- Eigen::Matrix<f64, Eigen::Dynamic, Eigen::Dynamic> eigen_c = eigen_a * eigen_b;
|
|
|
+ Eigen::Matrix<f64, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> eigen_c = eigen_a * eigen_b;
|
|
|
// copy to result
|
|
|
std::copy(eigen_c.data(), eigen_c.data() + a.size()[0] * b.size()[1], result->data_ptr());
|
|
|
#elif defined TRAPH_BUILD_MKL
|