JasonWang 7 lat temu
rodzic
commit
76a4c58f82

+ 30 - 10
traph/include/traph/core/tensor.h

@@ -15,6 +15,7 @@ namespace traph
     class StorageBase
     {
     public:
+        virtual T* data_ptr() = 0;
         virtual size_type element_size() const = 0;
         virtual void fill_(T v) = 0;
         virtual void resize_(idx_type size) = 0;
@@ -26,6 +27,7 @@ namespace traph
     class ContiguousStorageBase: public StorageBase<T>
     {
     public:
+        virtual T* data_ptr() = 0;
         virtual size_type element_size() const = 0;
         virtual void fill_(T v) = 0;
         virtual void resize_(idx_type size) = 0;
@@ -54,7 +56,9 @@ namespace traph
         virtual void resize_(const DimVector& dims) = 0;
         virtual void sin_() = 0;
 		virtual DimVector size() const = 0;
+		virtual idx_type size(idx_type i) const = 0;
 		virtual DimVector stride() const = 0;
+		virtual idx_type stride(idx_type i) const = 0;
         virtual TensorInterfacePtr sum() const = 0;
     };
 
@@ -94,8 +98,10 @@ namespace traph
         virtual void resize_(const DimVector& dims) = 0;
         virtual void sin_() = 0;
 		virtual DimVector size() const = 0;
-        virtual StorageBase<T>& storage() const = 0;
+		virtual idx_type size(idx_type i) const = 0;
+        virtual std::shared_ptr<StorageBase<T>> storage() const = 0;
 		virtual DimVector stride() const = 0;
+		virtual idx_type stride(idx_type i) const = 0;
         virtual TensorInterfacePtr sum() const = 0;
     };
 
@@ -115,19 +121,33 @@ namespace traph
             return false;
 
         idx_type min = std::min(lhs_dim.size(), rhs_dim.size());
-        for(idx_type i = 0; i<min;++i)
-        {
-            if(lhs_dim[i] != rhs_dim[i] &&
-                lhs_dim[i] != 1 &&
-                rhs_dim[i] != 1)
-            {
-                return false;
-            }
-        }
+		for (idx_type i = -1; i >= -min; --i)
+			if (lhs.size(i) != rhs.size(i) && lhs.size(i) != 1 && rhs.size(i) != 1)
+				return false;    
 
         return true;
     }
 
+	template<class T>
+	DimVector broadcast_shape(const TensorBase<T> &lhs, const TensorBase<T> & rhs)
+	{
+		bool is_broadcastable = broadcastable(lhs, rhs);
+		if (!is_broadcastable)
+			throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		DimVector lhs_dim = lhs.size();
+		DimVector rhs_dim = rhs.size();
+		auto max_size = std::max(lhs_dim.size(), rhs_dim.size());
+		DimVector result_dim(max_size);
+
+		for (idx_type i = -1; i >= -max_size; --i)
+		{
+			idx_type lhs_size = i >= -lhs_dim.size() ? lhs.size(i) : 1;
+			idx_type rhs_size = i >= -rhs_dim.size() ? rhs.size(i) : 1;
+			result_dim[max_size + i] = std::max(lhs_size, rhs_size);
+		}
+		return result_dim;
+	}
+
     template<class T>
     bool strict_same_shape(const TensorBase<T> &lhs, const TensorBase<T> & rhs)
     {

+ 1 - 1
traph/include/traph/core/variable.h

@@ -69,7 +69,7 @@ namespace traph
         virtual void reshape_(const DimVector& dims) = 0;
         virtual void resize_(const DimVector& dims) = 0;
 		virtual DimVector size() const = 0;
-        virtual StorageBase<T>& storage() const = 0;
+        virtual std::shared_ptr<StorageBase<T>> storage() const = 0;
 		virtual DimVector stride() const = 0;
     };
 

+ 3 - 3
traph/include/traph/nn/variable.h

@@ -64,7 +64,7 @@ namespace traph
         virtual void reshape_(const DimVector& dims) override;
         virtual void resize_(const DimVector& dims) override;
 		virtual DimVector size() const override;
-        virtual StorageBase<T>& storage() const override;
+        virtual std::shared_ptr<StorageBase<T>> storage() const override;
 		virtual DimVector stride() const override;
     };
 
@@ -160,7 +160,7 @@ namespace traph
 		_grad->fill_(1);
 
 		std::vector<VariableInterface*> sorted_node = Executor::topology_sort(dynamic_cast<VariableInterface*>(this));
-		for (int i = sorted_node.size() - 1; i >= 0; --i)
+		for (int i = static_cast<int>(sorted_node.size()) - 1; i >= 0; --i)
 		{
 			VariableInterface* cur_node = sorted_node[i];
 			if (cur_node->is_leaf()) continue;
@@ -269,7 +269,7 @@ namespace traph
 	}
 
 	template<typename T>
-	StorageBase<T>& Variable<T>::storage() const
+	std::shared_ptr<StorageBase<T>> Variable<T>::storage() const
 	{
 		return _data->storage();
 	}

+ 66 - 5
traph/include/traph/tensor/tensor.h

@@ -6,6 +6,7 @@
 #include <memory>
 #include <functional>
 #include <stdexcept>
+#include <algorithm>
 
 
 #include<traph/core/type.h>
@@ -63,7 +64,7 @@ namespace traph
             return *this;
         }
 
-        // size
+        virtual T* data_ptr() override {return data.get();}
         virtual idx_type size() const override {return len;}
         virtual size_type element_size() const override {return sizeof(T);}
 
@@ -147,8 +148,10 @@ namespace traph
         virtual void resize_(const DimVector& dims) override;
         virtual void sin_() override;
 		virtual DimVector size() const override;
-        virtual StorageBase<T>& storage() const override;
+		virtual idx_type size(idx_type i) const override;
+        virtual std::shared_ptr<StorageBase<T>> storage() const override;
 		virtual DimVector stride() const override;
+		virtual idx_type stride(idx_type i) const override;
         virtual TensorInterfacePtr sum() const override;
     };
 
@@ -324,7 +327,43 @@ namespace traph
     template<typename T>
     void Tensor<T>::add_(TensorInterfacePtr other)
     {
-        
+		// check tensor other type
+
+		// check broadcast
+
+		// ok, get lhs, rhs
+		Tensor<T> * lhs = this;
+		Tensor<T> * rhs = dynamic_cast<Tensor<T> *>(other.get());
+		std::function<void(Tensor<T>&, Tensor<T> *, Tensor<T> *, idx_type, idx_type,idx_type, idx_type, idx_type)> add_impl =
+			[&](Tensor<T>& result, Tensor<T> * lhs, Tensor<T> * rhs, idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx, idx_type result_idx) {
+
+			auto result_storage = std::dynamic_pointer_cast<TensorStorage<T>>(result.storage())->data_ptr();
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<T>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<T>>(rhs->storage())->data_ptr();
+
+			if (lhs_dim < -(lhs->size().size()) && rhs_dim < -(rhs->size().size()))
+			{
+				result_storage[result_idx] = lhs_storage[lhs_idx] + rhs_storage[rhs_idx];
+				return;
+			}
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+				add_impl(result, lhs, rhs, lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx, result_idx);
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+				result_idx += result.stride(std::min(lhs_dim, rhs_dim));
+			}
+		};
+		Tensor<T> result(broadcast_shape(*lhs, *rhs));
+		add_impl(result, lhs, rhs, -1, -1, lhs->offset(), rhs->offset(), 0);
     }
     template<typename T>
     void Tensor<T>::apply_(std::function<T(T)> f)
@@ -402,10 +441,32 @@ namespace traph
     }
     template<typename T>
     DimVector Tensor<T>::size() const { return _dimensions;}
-    template<typename T>
-    StorageBase<T>& Tensor<T>::storage() const { return *(_rep.get()); }
+	template<typename T>
+	idx_type Tensor<T>::size(idx_type i) const
+	{ 
+		auto shape_size = _dimensions.size();
+		if (i >= 0 && i < _dimensions.size())
+			return _dimensions[i];
+		else if (i <= -1 && i >= -_dimensions.size())
+			return _dimensions[shape_size + i];
+		else
+			throw std::runtime_error("Dimension out of range");
+	}
+    template<typename T>
+	std::shared_ptr<StorageBase<T>>  Tensor<T>::storage() const { return _rep; }
     template<typename T>
     DimVector Tensor<T>::stride() const { return _strides; }
+	template<typename T>
+	idx_type Tensor<T>::stride(idx_type i) const
+	{
+		auto stride_size = _strides.size();
+		if (i >= 0 && i < _strides.size())
+			return _strides[i];
+		else if (i <= -1 && i >= -_strides.size())
+			return _strides[stride_size + i];
+		else
+			throw std::runtime_error("Stride out of range");
+	}
     template<typename T>
     TensorInterfacePtr Tensor<T>::sum() const
     {