|
|
@@ -6,6 +6,7 @@
|
|
|
#include <memory>
|
|
|
#include <functional>
|
|
|
#include <stdexcept>
|
|
|
+#include <algorithm>
|
|
|
|
|
|
|
|
|
#include<traph/core/type.h>
|
|
|
@@ -63,7 +64,7 @@ namespace traph
|
|
|
return *this;
|
|
|
}
|
|
|
|
|
|
- // size
|
|
|
+ virtual T* data_ptr() override {return data.get();}
|
|
|
// Number of elements recorded for this storage (the `len` member).
idx_type size() const override
{
    return len;
}
|
|
|
// Size in bytes of one stored element, i.e. sizeof(T).
size_type element_size() const override
{
    return sizeof(T);
}
|
|
|
|
|
|
@@ -147,8 +148,10 @@ namespace traph
|
|
|
virtual void resize_(const DimVector& dims) override;
|
|
|
virtual void sin_() override;
|
|
|
virtual DimVector size() const override;
|
|
|
- virtual StorageBase<T>& storage() const override;
|
|
|
+ virtual idx_type size(idx_type i) const override;
|
|
|
+ virtual std::shared_ptr<StorageBase<T>> storage() const override;
|
|
|
virtual DimVector stride() const override;
|
|
|
+ virtual idx_type stride(idx_type i) const override;
|
|
|
virtual TensorInterfacePtr sum() const override;
|
|
|
};
|
|
|
|
|
|
@@ -324,7 +327,43 @@ namespace traph
|
|
|
template<typename T>
|
|
|
void Tensor<T>::add_(TensorInterfacePtr other)
|
|
|
{
|
|
|
-
|
|
|
+ // check tensor other type
|
|
|
+
|
|
|
+ // check broadcast
|
|
|
+
|
|
|
+ // ok, get lhs, rhs
|
|
|
+ Tensor<T> * lhs = this;
|
|
|
+ Tensor<T> * rhs = dynamic_cast<Tensor<T> *>(other.get());
|
|
|
+ std::function<void(Tensor<T>&, Tensor<T> *, Tensor<T> *, idx_type, idx_type,idx_type, idx_type, idx_type)> add_impl =
|
|
|
+ [&](Tensor<T>& result, Tensor<T> * lhs, Tensor<T> * rhs, idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx, idx_type result_idx) {
|
|
|
+
|
|
|
+ auto result_storage = std::dynamic_pointer_cast<TensorStorage<T>>(result.storage())->data_ptr();
|
|
|
+ auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<T>>(lhs->storage())->data_ptr();
|
|
|
+ auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<T>>(rhs->storage())->data_ptr();
|
|
|
+
|
|
|
+ if (lhs_dim < -(lhs->size().size()) && rhs_dim < -(rhs->size().size()))
|
|
|
+ {
|
|
|
+ result_storage[result_idx] = lhs_storage[lhs_idx] + rhs_storage[rhs_idx];
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
|
|
|
+ idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
|
|
|
+ idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
|
|
|
+
|
|
|
+ for (idx_type i = 0; i < max_shape_size; ++i)
|
|
|
+ {
|
|
|
+ add_impl(result, lhs, rhs, lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx, result_idx);
|
|
|
+
|
|
|
+ if(lsh_shape_size > 1)
|
|
|
+ lhs_idx += lhs->stride(lhs_dim);
|
|
|
+ if (rsh_shape_size > 1)
|
|
|
+ rhs_idx += rhs->stride(rhs_dim);
|
|
|
+ result_idx += result.stride(std::min(lhs_dim, rhs_dim));
|
|
|
+ }
|
|
|
+ };
|
|
|
+ Tensor<T> result(broadcast_shape(*lhs, *rhs));
|
|
|
+ add_impl(result, lhs, rhs, -1, -1, lhs->offset(), rhs->offset(), 0);
|
|
|
}
|
|
|
template<typename T>
|
|
|
void Tensor<T>::apply_(std::function<T(T)> f)
|
|
|
@@ -402,10 +441,32 @@ namespace traph
|
|
|
}
|
|
|
template<typename T>
|
|
|
DimVector Tensor<T>::size() const { return _dimensions;}
|
|
|
- template<typename T>
|
|
|
- StorageBase<T>& Tensor<T>::storage() const { return *(_rep.get()); }
|
|
|
+ template<typename T>
|
|
|
+ idx_type Tensor<T>::size(idx_type i) const
|
|
|
+ {
|
|
|
+ auto shape_size = _dimensions.size();
|
|
|
+ if (i >= 0 && i < _dimensions.size())
|
|
|
+ return _dimensions[i];
|
|
|
+ else if (i <= -1 && i >= -_dimensions.size())
|
|
|
+ return _dimensions[shape_size + i];
|
|
|
+ else
|
|
|
+ throw std::runtime_error("Dimension out of range");
|
|
|
+ }
|
|
|
+ template<typename T>
|
|
|
+ std::shared_ptr<StorageBase<T>> Tensor<T>::storage() const { return _rep; }
|
|
|
template<typename T>
|
|
|
DimVector Tensor<T>::stride() const { return _strides; }
|
|
|
+ template<typename T>
|
|
|
+ idx_type Tensor<T>::stride(idx_type i) const
|
|
|
+ {
|
|
|
+ auto stride_size = _strides.size();
|
|
|
+ if (i >= 0 && i < _strides.size())
|
|
|
+ return _strides[i];
|
|
|
+ else if (i <= -1 && i >= -_strides.size())
|
|
|
+ return _strides[stride_size + i];
|
|
|
+ else
|
|
|
+ throw std::runtime_error("Stride out of range");
|
|
|
+ }
|
|
|
template<typename T>
|
|
|
TensorInterfacePtr Tensor<T>::sum() const
|
|
|
{
|