// float_tensor.cpp — CPU Tensor<f32> implementation for traph
  1. #include <traph/tensor/float_tensor.h>
  2. namespace traph
  3. {
  4. // definition
  5. // private
  6. void Tensor<f32>::auto_strides()
  7. {
  8. idx_type dim_num = _dimensions.size();
  9. _strides.resize(dim_num);
  10. idx_type stride = 1;
  11. if(_order == layout_type::column_major)
  12. {
  13. for (idx_type i = dim_num - 1; i >= 0; --i)
  14. {
  15. _strides[i] = stride;
  16. stride *= _dimensions[i];
  17. }
  18. }
  19. else
  20. {
  21. for (idx_type i = 0; i < dim_num; ++i)
  22. {
  23. _strides[i] = stride;
  24. stride *= _dimensions[i];
  25. }
  26. }
  27. }
  28. void Tensor<f32>::reduce_impl(f32& result, idx_type dim, idx_type idx, std::function<f32(f32,f32)> f) const
  29. {
  30. idx_type dim_size = _dimensions.size();
  31. idx_type step_len = _strides[dim];
  32. idx_type step_num = _dimensions[dim];
  33. for(idx_type i = 0; i < step_num; ++i)
  34. {
  35. if(dim == dim_size - 1)
  36. result = f(result, _rep->data[idx]);
  37. else
  38. reduce_impl(result, dim + 1, idx, f);
  39. idx += step_len;
  40. }
  41. }
  42. f32 Tensor<f32>::reduce_dim_kernel(idx_type begin, idx_type step_len, idx_type step_num, std::function<f32(f32,f32)> f) const
  43. {
  44. f32 result{};
  45. for(idx_type i = 0; i < step_num; ++i)
  46. {
  47. result = f(result, _rep->data[begin]);
  48. begin += step_len;
  49. }
  50. return result;
  51. }
  52. void Tensor<f32>::reduce_dim_impl(Tensor<f32>& result, idx_type dim, idx_type reduce_dim,
  53. idx_type this_idx, idx_type result_idx,
  54. std::function<f32(f32,f32)> f) const
  55. {
  56. idx_type dim_size = _dimensions.size();
  57. if(dim == dim_size)
  58. {
  59. result._rep->data[result_idx] =
  60. reduce_dim_kernel(this_idx, _strides[reduce_dim], _dimensions[reduce_dim], f);
  61. return;
  62. }
  63. if(dim == reduce_dim)
  64. {
  65. reduce_dim_impl(result, dim + 1, reduce_dim, this_idx,result_idx, f);
  66. }
  67. else
  68. {
  69. for(idx_type i = 0; i < _dimensions[dim]; ++i)
  70. {
  71. reduce_dim_impl(result, dim + 1, reduce_dim, this_idx,result_idx, f);
  72. this_idx += _strides[dim];
  73. result_idx += result._strides[dim];
  74. }
  75. }
  76. }
  77. // public
// Default-construct an empty (rank-0) tensor: fresh storage, zero
// offset, column-major layout, no dimensions or strides.
Tensor<f32>::Tensor()
    :_rep(new TensorStorage<f32>),
    _dimensions(), _offset(0), _strides(), _order(layout_type::column_major)
{
}
// Construct a tensor of the given shape. Strides are derived from the
// shape (column-major default) and the storage is sized to hold every
// element (flat_size()).
Tensor<f32>::Tensor(const DimVector& dimensions)
    :_rep(new TensorStorage<f32>),
    _dimensions(dimensions), _offset(0), _strides(), _order(layout_type::column_major)
{
    auto_strides();
    _rep->resize_(_dimensions.flat_size());
}
// Construct a tensor of the given shape with an explicit memory layout;
// strides are derived from the shape and `order` by auto_strides().
Tensor<f32>::Tensor(const DimVector& dimensions, layout_type order)
    :_rep(new TensorStorage<f32>),
    _dimensions(dimensions), _offset(0), _strides(), _order(order)
{
    auto_strides();
    _rep->resize_(_dimensions.flat_size());
}
  97. Tensor<f32>::Tensor(const DimVector& dimensions, const DimVector& strides)
  98. :_rep(new TensorStorage<f32>),
  99. _dimensions(dimensions), _offset(0), _strides(strides), _order(layout_type::column_major)
  100. {
  101. auto_strides();
  102. _rep->resize_(_dimensions.flat_size());
  103. }
  104. Tensor<f32>::Tensor(const DimVector& dimensions, const DimVector& strides, layout_type order)
  105. :_rep(new TensorStorage<f32>),
  106. _dimensions(dimensions), _offset(0), _strides(strides), _order(order)
  107. {
  108. auto_strides();
  109. _rep->resize_(_dimensions.flat_size());
  110. }
  111. Tensor<f32>::Tensor(const f32& t)
  112. :_rep(new TensorStorage<f32>),
  113. _dimensions(), _offset(0), _strides()
  114. {
  115. _dimensions.resize(1);
  116. auto_strides();
  117. }
// Element-wise in-place addition: *this += *other, broadcasting `other`
// onto this tensor's shape. Throws if `other` is not a float tensor or
// if broadcasting would need to expand *this.
void Tensor<f32>::add_(TensorInterfacePtr other)
{
    // check tensor other type
    if(other->dtype() != DataType::FLOAT)
        throw std::runtime_error("expected type float tensor");
    // check broadcast.shape = this.shape: broadcasting may only expand
    // `other`; the combined shape must equal this tensor's own shape.
    auto shape = broadcast_shape(this->size(), other->size());
    if(shape != this->size())
        throw std::runtime_error("The size of tensor a must match the size of tensor b");
    // ok, get lhs, rhs
    Tensor<f32> * lhs = this;
    Tensor<f32> * rhs = dynamic_cast<Tensor<f32> *>(other.get());
    // Recursive walk with dimensions counted NEGATIVELY from the last
    // axis (-1 = innermost), so tensors of different rank align at their
    // trailing dimensions as broadcasting requires.
    std::function<void(idx_type, idx_type, idx_type, idx_type)> add_impl =
        [&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
        auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
        auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
        // Extent of the current axis, or 1 once past the smaller
        // tensor's rank (a virtual broadcast dimension).
        idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
        idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
        idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
        for (idx_type i = 0; i < max_shape_size; ++i)
        {
            if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
            {
                // Innermost level on both sides: accumulate one scalar.
                lhs_storage[lhs_idx] += rhs_storage[rhs_idx];
            }
            else
            {
                add_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
            }
            // Only advance along axes with real extent; size-1
            // (broadcast) axes keep reusing the same index.
            if(lsh_shape_size > 1)
                lhs_idx += lhs->stride(lhs_dim);
            if (rsh_shape_size > 1)
                rhs_idx += rhs->stride(rhs_dim);
        }
    };
    add_impl(-1, -1, lhs->offset(), rhs->offset());
}
  155. void Tensor<f32>::apply_(std::function<f32(f32)> f)
  156. {
  157. // sort stride for cache optimization
  158. DimVector cloned_stride(_strides);
  159. DimVector sorted_stride(_strides.size());
  160. for(int i = 0; i<_strides.size(); ++i)
  161. sorted_stride[i] = i;
  162. for (int i = 0; i < cloned_stride.size() - 1; i++)
  163. for (int j = 0; j < cloned_stride.size() - 1 - i; j++)
  164. if (cloned_stride[j] < cloned_stride[j + 1])
  165. {
  166. std::swap(cloned_stride[j], cloned_stride[j+1]);
  167. std::swap(sorted_stride[j], sorted_stride[j+1]);
  168. }
  169. std::function<void(idx_type, idx_type, std::function<f32(f32)>)> apply_impl =
  170. [&](idx_type dim_idx, idx_type idx, std::function<f32(f32)> f){
  171. idx_type dim = sorted_stride[dim_idx];
  172. idx_type dim_size = _dimensions.size();
  173. idx_type step_len = _strides[dim];
  174. idx_type step_num = _dimensions[dim];
  175. for(idx_type i = 0; i < step_num; ++i)
  176. {
  177. if(dim_idx == dim_size - 1)
  178. _rep->data[idx] = f(_rep->data[idx]);
  179. else
  180. apply_impl(dim_idx + 1, idx, f);
  181. idx += step_len;
  182. }
  183. };
  184. if(_dimensions.size() > 0)
  185. apply_impl(0, _offset, f);
  186. }
  187. TensorInterfacePtr Tensor<f32>::clone() const
  188. {
  189. std::shared_ptr<Tensor<f32>> cloned_tensor(new Tensor<f32>);
  190. cloned_tensor->_rep = std::dynamic_pointer_cast<TensorStorage<f32>>(_rep->clone());
  191. cloned_tensor->_dimensions = _dimensions;
  192. cloned_tensor->_offset = _offset;
  193. cloned_tensor->_strides = _strides;
  194. cloned_tensor->_order = _order;
  195. return cloned_tensor;
  196. }
  197. void Tensor<f32>::cos_()
  198. {
  199. apply_([](f32 a)->f32 {return std::cos(a); });
  200. }
  201. std::shared_ptr<TensorBase<f32>> Tensor<f32>::create_grad()
  202. {
  203. return std::shared_ptr<TensorBase<f32>>(new Tensor<f32>(_dimensions));
  204. }
  205. f32* Tensor<f32>::data_ptr()
  206. {
  207. return _rep->data_ptr();
  208. }
  209. const f32* Tensor<f32>::data_ptr() const
  210. {
  211. return _rep->data_ptr();
  212. }
  213. device_id Tensor<f32>::device() { return 0; }
  214. DataType Tensor<f32>::dtype() const
  215. {
  216. return DataType::FLOAT;
  217. }
  218. bool Tensor<f32>::equal(std::shared_ptr<TensorInterface> other) const
  219. {
  220. if(other->platform() != this->platform())
  221. throw std::runtime_error("equal: Two tensors must be the same platform");
  222. if(other->dtype() != this->dtype())
  223. return false;
  224. if(other->size() != this->size())
  225. return false;
  226. std::shared_ptr<Tensor<f32>> other_ptr = std::dynamic_pointer_cast<Tensor<f32>>(other);
  227. std::function<bool(idx_type, f32*, f32*)> equal_impl =
  228. [&](idx_type dim, f32* lhs_idx, f32* rhs_idx){
  229. idx_type dim_size = _dimensions.size();
  230. for(idx_type i = 0; i < _dimensions[dim]; ++i)
  231. {
  232. if(dim == dim - 1)
  233. {
  234. if(*lhs_idx != *rhs_idx) return false;
  235. }
  236. else
  237. {
  238. if(!equal_impl(dim + 1, lhs_idx, rhs_idx)) return false;
  239. }
  240. lhs_idx += _strides[dim];
  241. rhs_idx += other_ptr->stride(dim);
  242. }
  243. return true;
  244. };
  245. return equal_impl(0, _rep->data_ptr() + _offset, other_ptr->data_ptr() + other_ptr->offset());
  246. }
  247. std::shared_ptr<TensorInterface> Tensor<f32>::inverse() const
  248. {
  249. return std::dynamic_pointer_cast<TensorInterface>(inverse_impl(*this));
  250. }
  251. void Tensor<f32>::fill_(f32 value)
  252. {
  253. apply_([&value](f32 a)->f32 {return value; });
  254. }
  255. f32 Tensor<f32>::item() const
  256. {
  257. if(_dimensions.flat_size() == 1)
  258. {
  259. return _rep->data[_offset];
  260. }
  261. else
  262. {
  263. throw std::runtime_error("item: only one element tensors can be converted to scalars");
  264. }
  265. }
  266. std::shared_ptr<TensorInterface> Tensor<f32>::matmul(std::shared_ptr<TensorInterface> mat) const
  267. {
  268. auto right_matrix = std::dynamic_pointer_cast<Tensor<f32>>(mat);
  269. return matmul_impl(*this, *right_matrix);
  270. }
  271. TensorInterfacePtr Tensor<f32>::mean() const
  272. {
  273. DimVector d(1);
  274. d[0] = 1;
  275. TensorPtr<f32> result(new Tensor<f32>(d));
  276. auto flat_size = _dimensions.flat_size();
  277. result->_rep->data[0] = reduce([](f32 a, f32 b)->f32 {return a + b; });
  278. result->_rep->data[0] /= flat_size;
  279. return std::dynamic_pointer_cast<TensorInterface>(result);
  280. }
  281. void Tensor<f32>::mul_(f32 value)
  282. {
  283. apply_([value](f32 a)->f32 {return a*value; });
  284. }
// Element-wise in-place multiplication: *this *= *other, broadcasting
// `other` onto this tensor's shape. Throws if `other` is not a float
// tensor or if broadcasting would need to expand *this. Mirrors the
// recursion in add_.
void Tensor<f32>::mul_(std::shared_ptr<TensorInterface> other)
{
    // check tensor other type
    if(other->dtype() != DataType::FLOAT)
        throw std::runtime_error("expected type float tensor");
    // check broadcast.shape = this.shape: broadcasting may only expand
    // `other`; the combined shape must equal this tensor's own shape.
    auto shape = broadcast_shape(this->size(), other->size());
    if(shape != this->size())
        throw std::runtime_error("The size of tensor a must match the size of tensor b");
    // ok, get lhs, rhs
    Tensor<f32> * lhs = this;
    Tensor<f32> * rhs = dynamic_cast<Tensor<f32> *>(other.get());
    // Recursive walk with dimensions counted NEGATIVELY from the last
    // axis (-1 = innermost), so tensors of different rank align at their
    // trailing dimensions as broadcasting requires.
    std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
        [&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
        auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
        auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
        // Extent of the current axis, or 1 once past the smaller
        // tensor's rank (a virtual broadcast dimension).
        idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
        idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
        idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
        for (idx_type i = 0; i < max_shape_size; ++i)
        {
            if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
            {
                // Innermost level on both sides: multiply one scalar.
                lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
            }
            else
            {
                mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
            }
            // Only advance along axes with real extent; size-1
            // (broadcast) axes keep reusing the same index.
            if(lsh_shape_size > 1)
                lhs_idx += lhs->stride(lhs_dim);
            if (rsh_shape_size > 1)
                rhs_idx += rhs->stride(rhs_dim);
        }
    };
    mul_impl(-1, -1, lhs->offset(), rhs->offset());
}
  322. idx_type Tensor<f32>::ndimension() const
  323. {
  324. return _dimensions.size();
  325. }
  326. void Tensor<f32>::neg_()
  327. {
  328. apply_([](f32 a)->f32 {return -a; });
  329. }
  330. idx_type Tensor<f32>::offset() const { return _offset; }
  331. layout_type Tensor<f32>::order() const { return _order; }
  332. std::shared_ptr<TensorInterface> Tensor<f32>::permute(const DimVector& dims) const
  333. {
  334. // check dims
  335. if(dims.size() != _strides.size())
  336. throw std::runtime_error("permute dimension must have the same size");
  337. std::vector<int> check_vec(dims.size(), 0);
  338. for(int i = 0; i < dims.size();++i)
  339. if(dims[i] >= 0 && dims[i] < dims.size())
  340. check_vec[dims[i]] = 1;
  341. else
  342. throw std::runtime_error("permute dimension must in ndimension range");
  343. for(int i = 0; i < check_vec.size();++i)
  344. {
  345. if(check_vec[i] != 1)
  346. throw std::runtime_error("permute dimension error");
  347. }
  348. // permute
  349. std::shared_ptr<Tensor<f32>> result(new Tensor<f32>);
  350. result->_rep = _rep;
  351. result->_dimensions = _dimensions;
  352. result->_offset = _offset;
  353. result->_strides = _strides;
  354. result->_order = _order;
  355. for(int i=0; i<dims.size(); ++i)
  356. {
  357. result->_dimensions[i] = _dimensions[dims[i]];
  358. result->_strides[i] = _strides[dims[i]];
  359. }
  360. return result;
  361. }
  362. PlatformType Tensor<f32>::platform() const { return PlatformType::CPU; }
  363. void Tensor<f32>::pow_(f32 exp)
  364. {
  365. apply_([&exp](f32 a)->f32 {return std::pow(a, exp); });
  366. }
  367. f32 Tensor<f32>::reduce(std::function<f32(f32, f32)> f) const
  368. {
  369. f32 result{};
  370. reduce_impl(result, 0, _offset, f);
  371. return result;
  372. }
  373. TensorInterfacePtr Tensor<f32>::reduce_dim(idx_type dim, std::function<f32(f32, f32)> f) const
  374. {
  375. DimVector reduced_dim = _dimensions;
  376. reduced_dim.erase(dim); // check dim?
  377. TensorBasePtr<f32> result(new Tensor<f32>(reduced_dim));
  378. TensorPtr<f32> raw_result = std::dynamic_pointer_cast<Tensor<f32>>(result);
  379. reduce_dim_impl(*(raw_result.get()), 0, dim, _offset, raw_result->_offset, f);
  380. return std::dynamic_pointer_cast<TensorInterface>(result);
  381. }
// TODO: not implemented — currently a silent no-op. Intended contract
// (by analogy with resize_): change the view to `dims` without
// reallocating, presumably requiring dims.flat_size() to equal the
// current flat size — confirm before implementing.
void Tensor<f32>::reshape_(const DimVector& dims)
{
}
  385. void Tensor<f32>::resize_(const DimVector& dims)
  386. {
  387. _dimensions = dims;
  388. _rep->resize_(dims.flat_size());
  389. auto_strides();
  390. }
  391. std::shared_ptr<TensorInterface> Tensor<f32>::select(const SliceVector& slice) const
  392. {
  393. std::shared_ptr<Tensor<f32>> result(new Tensor<f32>);
  394. result->_rep = _rep;
  395. // dimension
  396. DimVector dim;
  397. std::fesetround(FE_TONEAREST);
  398. for (idx_type i = 0; i < slice.size(); ++i)
  399. {
  400. auto& each = slice[i];
  401. dim.push_back(
  402. std::lrint(std::ceil((each.end.value_or(_dimensions[i]) - each.start.value_or(0)) / (float)each.step.value_or(1)))
  403. );
  404. }
  405. result->_dimensions = dim;
  406. // offset
  407. idx_type new_offset = 1;
  408. for (idx_type i = 0; i < slice.size(); ++i)
  409. {
  410. new_offset *= _strides[i] * slice[i].start.value_or(0);
  411. }
  412. result->_offset = _offset + new_offset;
  413. // strides
  414. DimVector strides;
  415. for (idx_type i = 0; i < slice.size(); ++i)
  416. {
  417. strides.push_back(_strides[i] * slice[i].step.value_or(1));
  418. }
  419. result->_strides = strides;
  420. result->_order = _order;
  421. return std::dynamic_pointer_cast<TensorInterface>(result);
  422. }
  423. void Tensor<f32>::sin_()
  424. {
  425. apply_([](f32 a)->f32 {return std::sin(a); });
  426. }
  427. DimVector Tensor<f32>::size() const { return _dimensions;}
  428. idx_type Tensor<f32>::size(idx_type i) const
  429. {
  430. auto shape_size = _dimensions.size();
  431. if (i >= 0 && i < _dimensions.size())
  432. return _dimensions[i];
  433. else if (i <= -1 && i >= -_dimensions.size())
  434. return _dimensions[shape_size + i];
  435. else
  436. throw std::runtime_error("Dimension out of range");
  437. }
  438. std::shared_ptr<StorageBase<f32>> Tensor<f32>::storage() const { return _rep; }
  439. DimVector Tensor<f32>::stride() const { return _strides; }
  440. idx_type Tensor<f32>::stride(idx_type i) const
  441. {
  442. auto stride_size = _strides.size();
  443. if (i >= 0 && i < _strides.size())
  444. return _strides[i];
  445. else if (i <= -1 && i >= -_strides.size())
  446. return _strides[stride_size + i];
  447. else
  448. throw std::runtime_error("Stride out of range");
  449. }
  450. void Tensor<f32>::sub_(std::shared_ptr<TensorInterface> other)
  451. {
  452. Tensor<f32> * lhs = this;
  453. Tensor<f32> * rhs = dynamic_cast<Tensor<f32> *>(other.get());
  454. std::function<void(Tensor<f32> *, Tensor<f32> *, idx_type, idx_type,idx_type, idx_type)> sub_impl =
  455. [&](Tensor<f32> * lhs, Tensor<f32> * rhs, idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
  456. auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
  457. auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
  458. if (lhs_dim < -(lhs->size().size()) && rhs_dim < -(rhs->size().size()))
  459. {
  460. lhs_storage[lhs_idx] -= rhs_storage[rhs_idx];
  461. return;
  462. }
  463. idx_type lhs_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
  464. idx_type rhs_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
  465. idx_type max_shape_size = std::max(lhs_shape_size, rhs_shape_size);
  466. for (idx_type i = 0; i < max_shape_size; ++i)
  467. {
  468. sub_impl(lhs, rhs, lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
  469. if(lhs_shape_size > 1)
  470. lhs_idx += lhs->stride(lhs_dim);
  471. if (rhs_shape_size > 1)
  472. rhs_idx += rhs->stride(rhs_dim);
  473. }
  474. };
  475. sub_impl(lhs, rhs, -1, -1, lhs->offset(), rhs->offset());
  476. }
  477. TensorInterfacePtr Tensor<f32>::sum() const
  478. {
  479. DimVector d(1);
  480. d[0] = 1;
  481. TensorPtr<f32> result(new Tensor<f32>(d));
  482. result->_rep->data[0] = reduce([](f32 a, f32 b)->f32 {return a + b; });
  483. return std::dynamic_pointer_cast<TensorInterface>(result);
  484. }
// Render the tensor as a nested "[...]" string: scalars separated by
// commas, sub-tensors by ",\n". Traversal uses stride(dim) and
// offset(), so strided views print correctly.
std::string Tensor<f32>::to_string() const
{
    std::function<std::string(const Tensor<f32>&, idx_type, idx_type)> to_string_impl =
        [&](const Tensor<f32>& t, idx_type dim, idx_type idx)->std::string {
        std::string result;
        if (dim == t.size().size())
        {
            // Past the last dimension: print one scalar element.
            result += std::to_string(t.data_ptr()[idx]);
            return result;
        }
        for (idx_type i = 0; i < t.size(dim); ++i)
        {
            // Non-innermost dims wrap each sub-tensor in brackets and
            // separate siblings with ",\n".
            if (dim != t.size().size() - 1 && i != 0) result += ",\n";
            if(dim != t.size().size() - 1) result += "[";
            result += to_string_impl(t, dim + 1, idx);
            // Innermost dim: comma-separate the scalars.
            if (i != t.size(dim) - 1 && dim == t.size().size() - 1)
                result += ",";
            if (dim != t.size().size() - 1) result += "]";
            idx += t.stride(dim);
        }
        return result;
    };
    std::string result;
    result += "[" + to_string_impl(*this, 0, offset()) + "]";
    return result;
}
  511. void Tensor<f32>::transpose_(idx_type dim0, idx_type dim1)
  512. {
  513. if(dim0 != dim1 &&
  514. _dimensions.in_range(dim0) &&
  515. _dimensions.in_range(dim1))
  516. {
  517. std::swap(_dimensions[dim0], _dimensions[dim1]);
  518. std::swap(_strides[dim0], _strides[dim1]);
  519. }
  520. }
  521. std::shared_ptr<TensorInterface> Tensor<f32>::transpose(idx_type dim0, idx_type dim1)
  522. {
  523. std::shared_ptr<Tensor<f32>> result(new Tensor<f32>);
  524. result->_rep = _rep;
  525. result->_dimensions = _dimensions;
  526. result->_offset = _offset;
  527. result->_strides = _strides;
  528. result->_order = _order;
  529. result->transpose_(dim0, dim1);
  530. return result;
  531. }
  532. }