JasonWang 7 rokov pred
rodič
commit
de721ed71b

+ 6 - 1
traph/include/traph/core/tensor.h

@@ -16,11 +16,11 @@ namespace traph
     {
     public:
         virtual T* data_ptr() = 0;
+        virtual const T* data_ptr() const = 0;
         virtual size_type element_size() const = 0;
         virtual void fill_(T v) = 0;
         virtual void resize_(idx_type size) = 0;
         virtual idx_type size() const = 0;
-
     };
 
     template<class T>
@@ -28,6 +28,7 @@ namespace traph
     {
     public:
         virtual T* data_ptr() = 0;
+        virtual const T* data_ptr() const = 0;
         virtual size_type element_size() const = 0;
         virtual void fill_(T v) = 0;
         virtual void resize_(idx_type size) = 0;
@@ -60,6 +61,7 @@ namespace traph
 		virtual DimVector stride() const = 0;
 		virtual idx_type stride(idx_type i) const = 0;
         virtual TensorInterfacePtr sum() const = 0;
+        virtual std::string to_string() const = 0;
     };
 
     using TensorInterfacePtr = std::shared_ptr<TensorInterface>;
@@ -86,6 +88,8 @@ namespace traph
         virtual void apply_(std::function<T(T)> f) = 0;
         virtual void cos_() = 0;
         virtual std::shared_ptr<TensorBase<f32>> create_grad() = 0;
+        virtual T* data_ptr() = 0;
+        virtual const T* data_ptr() const = 0;
         virtual device_id device() = 0;
         virtual void fill_(T value) = 0;
         virtual T item() const = 0;
@@ -103,6 +107,7 @@ namespace traph
 		virtual DimVector stride() const = 0;
 		virtual idx_type stride(idx_type i) const = 0;
         virtual TensorInterfacePtr sum() const = 0;
+        virtual std::string to_string() const = 0;
     };
 
     template<class T>

+ 2 - 0
traph/include/traph/core/variable.h

@@ -19,6 +19,7 @@ namespace traph
 
     public:
         virtual void backward() = 0;
+        virtual TensorInterfacePtr data() = 0;
         virtual device_id device() = 0;
         virtual TensorBasePtr<f32> grad() = 0;
         virtual std::shared_ptr<OpBase> grad_fn() = 0;
@@ -55,6 +56,7 @@ namespace traph
         using ByteVariableBase = VariableBase<u8>;
     public:
         virtual void backward() = 0;
+        virtual TensorInterfacePtr data() = 0;
         virtual device_id device() = 0;
         virtual void fill_(T value) = 0;
         virtual TensorBasePtr<f32> grad() = 0;

+ 7 - 0
traph/include/traph/nn/variable.h

@@ -50,6 +50,7 @@ namespace traph
 		friend std::shared_ptr<Variable<T>> sum(std::shared_ptr<Variable<T>> input);
 
         virtual void backward() override;
+        virtual TensorInterfacePtr data() override;
         virtual device_id device() override;
         virtual void fill_(T value) override;
         virtual TensorBasePtr<f32> grad() override;
@@ -175,6 +176,12 @@ namespace traph
 
 	}
 
+    template<typename T>
+    TensorInterfacePtr Variable<T>::data()
+    {
+        return std::dynamic_pointer_cast<TensorInterface>(_data);
+    }
+
 	template<typename T>
 	device_id Variable<T>::device()
 	{

+ 52 - 9
traph/include/traph/tensor/tensor.h

@@ -65,6 +65,7 @@ namespace traph
         }
 
         virtual T* data_ptr() override {return data.get();}
+        virtual const T* data_ptr() const override {return data.get();}
         virtual idx_type size() const override {return len;}
         virtual size_type element_size() const override {return sizeof(T);}
 
@@ -136,6 +137,8 @@ namespace traph
         virtual void apply_(std::function<T(T)> f) override;
         virtual void cos_() override;
         virtual std::shared_ptr<TensorBase<f32>> create_grad() override;
+        virtual T* data_ptr() override;
+        virtual const T* data_ptr() const override;
         virtual device_id device() override;
         virtual void fill_(T value) override;
         virtual T item() const override;
@@ -153,6 +156,7 @@ namespace traph
 		virtual DimVector stride() const override;
 		virtual idx_type stride(idx_type i) const override;
         virtual TensorInterfacePtr sum() const override;
+        virtual std::string to_string() const override;
     };
 
     using DoubleTensor = Tensor<f64>;
@@ -329,21 +333,20 @@ namespace traph
     {
 		// check tensor other type
 
-		// check broadcast
+		// check broadcast.shape = this.shape
 
 		// ok, get lhs, rhs
 		Tensor<T> * lhs = this;
 		Tensor<T> * rhs = dynamic_cast<Tensor<T> *>(other.get());
-		std::function<void(Tensor<T>&, Tensor<T> *, Tensor<T> *, idx_type, idx_type,idx_type, idx_type, idx_type)> add_impl =
-			[&](Tensor<T>& result, Tensor<T> * lhs, Tensor<T> * rhs, idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx, idx_type result_idx) {
+		std::function<void(Tensor<T> *, Tensor<T> *, idx_type, idx_type,idx_type, idx_type)> add_impl =
+			[&](Tensor<T> * lhs, Tensor<T> * rhs, idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
 
-			auto result_storage = std::dynamic_pointer_cast<TensorStorage<T>>(result.storage())->data_ptr();
 			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<T>>(lhs->storage())->data_ptr();
 			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<T>>(rhs->storage())->data_ptr();
 
 			if (lhs_dim < -(lhs->size().size()) && rhs_dim < -(rhs->size().size()))
 			{
-				result_storage[result_idx] = lhs_storage[lhs_idx] + rhs_storage[rhs_idx];
+				lhs_storage[lhs_idx] += rhs_storage[rhs_idx];
 				return;
 			}
 
@@ -353,17 +356,16 @@ namespace traph
 
 			for (idx_type i = 0; i < max_shape_size; ++i)
 			{
-				add_impl(result, lhs, rhs, lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx, result_idx);
+				add_impl(lhs, rhs, lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
 
 				if(lsh_shape_size > 1)
 					lhs_idx += lhs->stride(lhs_dim);
 				if (rsh_shape_size > 1)
 					rhs_idx += rhs->stride(rhs_dim);
-				result_idx += result.stride(std::min(lhs_dim, rhs_dim));
 			}
 		};
-		Tensor<T> result(broadcast_shape(*lhs, *rhs));
-		add_impl(result, lhs, rhs, -1, -1, lhs->offset(), rhs->offset(), 0);
+
+		add_impl(lhs, rhs, -1, -1, lhs->offset(), rhs->offset());
     }
     template<typename T>
     void Tensor<T>::apply_(std::function<T(T)> f)
@@ -381,6 +383,16 @@ namespace traph
         return std::shared_ptr<TensorBase<f32>>(new Tensor<f32>(_dimensions));
     }
     template<typename T>
+    T* Tensor<T>::data_ptr()
+    {
+        return _rep->data_ptr();
+    }
+    template<typename T>
+    const T* Tensor<T>::data_ptr() const
+    {
+        return _rep->data_ptr();
+    }
+    template<typename T>
     device_id Tensor<T>::device() { return 0; }
     template<typename T>
     void Tensor<T>::fill_(T value)
@@ -477,6 +489,37 @@ namespace traph
         result->_rep->data[0] = reduce_([](T a, T b)->T {return a + b; });
         return std::dynamic_pointer_cast<TensorInterface>(result);
     }
+    template<typename T>
+    std::string Tensor<T>::to_string() const
+    {
+        std::function<std::string(const Tensor<T>&, idx_type, idx_type)> to_string_impl =
+			[&](const Tensor<T>& t, idx_type dim, idx_type idx)->std::string {
+            std::string result;
+			if (dim == t.size().size())
+            {
+                result += std::to_string(t.data_ptr()[idx]);
+				return result;
+            }
+
+			for (idx_type i = 0; i < t.size(dim); ++i)
+			{
+				if (dim != t.size().size() - 1 && i != 0) result += ",\n";
+				if(dim != t.size().size() - 1)	result += "[";
+				result += to_string_impl(t, dim + 1, idx);
+				if (i != t.size(dim) - 1 && dim == t.size().size() - 1)
+					result += ",";
+				if (dim != t.size().size() - 1) result += "]";
+
+				idx += t.stride(dim);
+			}
+
+			return result;
+		};
+
+		std::string result;
+		result += "[" + to_string_impl(*this, 0, offset()) + "]";
+		return result;
+    }
 }
 
 #endif // !TRAPH_TENSOR

+ 2 - 1
traph/source/test/main.cpp

@@ -35,13 +35,14 @@ int main()
 	std::cout << b;
 	*/
 	// auto a = traph::Variable<traph::f32>({ 2, 3 });
+
 	auto a = traph::ones<traph::f32>({ 2,3 });
 	a->requires_grad_(true);
 	auto b = traph::sum<traph::f32>(a);
 
 	b->backward();
 
-	std::cout << b->item();
+	std::cout << a->grad()->to_string();
 
     return 0;
 }