JasonWang 7 years ago
parent
commit
c3f8bace94

+ 21 - 0
traph/include/traph/core/operation.h

@@ -58,6 +58,27 @@ namespace traph
             return {output_grad};
         }
     };
+
+	// Elementwise addition node for the autograd graph.
+	// forward computes inputs[0] + inputs[1]; backward routes the upstream
+	// gradient through unchanged to both operands (d(a+b)/da = d(a+b)/db = 1).
+	class AddOp : public OpBase
+	{
+	public:
+		virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
+		{
+			assert(inputs.size() == 2);
+
+			TensorInterfacePtr left_input = inputs[0];
+			TensorInterfacePtr right_input = inputs[1];
+			// Deep-copy the left operand so neither input tensor is mutated,
+			// then accumulate the right operand into the copy in place.
+			// NOTE(review): assumes both inputs have identical shape — whether
+			// add_ broadcasts is not visible here; confirm.
+			TensorInterfacePtr result = left_input->clone();
+            result->add_(right_input);
+
+			return result;
+		}
+
+		// Returns one gradient per input, in input order. Both entries alias
+		// the SAME underlying tensor (the shared_ptr is returned twice), so a
+		// caller that mutates one in place would affect the other.
+		virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
+		{
+			return { output_grad, output_grad };
+		}
+	};
 }
 
 #endif

+ 4 - 0
traph/include/traph/core/tensor.h

@@ -15,6 +15,7 @@ namespace traph
     class StorageBase
     {
     public:
+        virtual std::shared_ptr<StorageBase<T>> clone() const = 0;
         virtual T* data_ptr() = 0;
         virtual const T* data_ptr() const = 0;
         virtual size_type element_size() const = 0;
@@ -27,6 +28,7 @@ namespace traph
     class ContiguousStorageBase: public StorageBase<T>
     {
     public:
+        virtual std::shared_ptr<StorageBase<T>> clone() const = 0;
         virtual T* data_ptr() = 0;
         virtual const T* data_ptr() const = 0;
         virtual size_type element_size() const = 0;
@@ -47,6 +49,7 @@ namespace traph
 
     public:
         virtual void add_(TensorInterfacePtr other) = 0;
+        virtual TensorInterfacePtr clone() const = 0;
         virtual void cos_() = 0;
         virtual std::shared_ptr<TensorBase<f32>> create_grad() = 0;
         virtual device_id device() = 0;
@@ -86,6 +89,7 @@ namespace traph
     public:
         virtual void add_(TensorInterfacePtr other) = 0;
         virtual void apply_(std::function<T(T)> f) = 0;
+        virtual TensorInterfacePtr clone() const = 0;
         virtual void cos_() = 0;
         virtual std::shared_ptr<TensorBase<f32>> create_grad() = 0;
         virtual T* data_ptr() = 0;

+ 11 - 0
traph/include/traph/core/type.h

@@ -33,6 +33,17 @@ namespace traph
         vulkan,
         opengl
     };
+
+    // Scalar element types a tensor may hold.
+    // NOTE(review): a plain (unscoped) enum leaks BYTE/CHAR/... into the
+    // enclosing namespace; consider `enum class dtype` if callers permit.
+    enum dtype
+    {
+        BYTE,
+        CHAR,
+        SHORT,
+        INT,
+        LONG,
+        FLOAT,
+        DOUBLE
+    };
 }
 
 #endif

+ 26 - 0
traph/include/traph/nn/arithmetic.h

@@ -40,6 +40,32 @@ namespace traph
         return result;
     }
 
+	// Autograd-aware elementwise addition of two Variables.
+	// Always computes left->_data + right->_data via AddOp::forward; if either
+	// operand requires grad, additionally wires the result into the graph:
+	// zeroed grad buffer, grad_fn = the AddOp, inputs = {left, right}.
+	// The result is never a leaf.
+	// TODO(review): the op->forward(...) call is duplicated in both branches —
+	// it could be hoisted above the if/else.
+	template<class T>
+	VariablePtr<T> add(VariablePtr<T> left, VariablePtr<T> right)
+	{
+		VariablePtr<T> result(new Variable<T>);
+		std::shared_ptr<AddOp> op(new AddOp);
+		if (left->_requires_grad || right->_requires_grad)
+		{
+			std::vector<VariableInterfacePtr> result_inputs{ left, right };
+			result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ left->_data, right->_data }));
+			// Gradient buffer starts at zero so backward() can accumulate into it.
+			result->_grad = result->_data->create_grad();
+			result->_grad->fill_(0);
+			result->_requires_grad = true;
+			result->_leaf = false;
+			result->_grad_fn = op;
+			result->_inputs = result_inputs;
+		}
+		else
+		{
+			// No grad tracking needed: compute the value only, no graph edges.
+			result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ left->_data, right->_data }));
+			result->_requires_grad = false;
+			result->_leaf = false;
+		}
+
+		return result;
+	}
+
 }
 
 #endif

+ 6 - 1
traph/include/traph/nn/variable.h

@@ -49,6 +49,9 @@ namespace traph
 		template<class T>
 		friend std::shared_ptr<Variable<T>> sum(std::shared_ptr<Variable<T>> input);
 
+		template<class T>
+		friend std::shared_ptr<Variable<T>> add(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
+
         virtual void backward() override;
         virtual TensorInterfacePtr data() override;
         virtual device_id device() override;
@@ -130,6 +133,7 @@ namespace traph
 			_requires_grad = true;
 
 			_grad = _data->create_grad();
+			_grad->fill_(0);
 		}
 	}
 
@@ -155,6 +159,7 @@ namespace traph
 
 	}
 
+	// fixme: remove no requires_grad
 	template<typename T>
 	void Variable<T>::backward()
 	{
@@ -167,7 +172,7 @@ namespace traph
 			if (cur_node->is_leaf()) continue;
 			std::vector<TensorBasePtr<f32>> back_grad = cur_node->grad_fn()->backward(cur_node->grad());
 
-			assert(back_grad.size() == _inputs.size());
+			assert(back_grad.size() == cur_node->inputs().size());
 			for (int i = 0; i < cur_node->inputs().size(); ++i)
 			{
 				cur_node->inputs()[i]->grad()->add_(back_grad[i]);

+ 28 - 8
traph/include/traph/tensor/tensor.h

@@ -64,6 +64,15 @@ namespace traph
             return *this;
         }
 
+        // Deep copy of the backing buffer: allocates a fresh array of `len`
+        // elements and memcpy's the contents.
+        // NOTE(review): memcpy is only correct for trivially copyable T —
+        // fine for the arithmetic types this library targets, but worth a
+        // static_assert. Also, TensorStorage -> StorageBase is an upcast, so
+        // static_pointer_cast (or an implicit conversion) would suffice here;
+        // dynamic_pointer_cast is unnecessary.
+        virtual std::shared_ptr<StorageBase<T>> clone() const override
+        {
+            std::shared_ptr<TensorStorage<T>> cloned_storage(new TensorStorage<T>);
+            cloned_storage->data = std::unique_ptr<T[]>(new T[len]);
+            std::memcpy(cloned_storage->data.get(), data.get(), len * sizeof(T));
+            cloned_storage->len = len;
+
+            return std::dynamic_pointer_cast<StorageBase<T>>(cloned_storage);
+        }
         virtual T* data_ptr() override {return data.get();}
         virtual const T* data_ptr() const override {return data.get();}
         virtual idx_type size() const override {return len;}
@@ -101,8 +110,6 @@ namespace traph
         idx_type _offset;
 		DimVector _strides;
         layout_type _order;
-
-        bool _requires_grad;
     public:
         using TensorPtr = std::shared_ptr<Tensor<T>>;
         using TensorRef = Tensor<T>&;
@@ -135,6 +142,7 @@ namespace traph
 
         virtual void add_(TensorInterfacePtr other) override;
         virtual void apply_(std::function<T(T)> f) override;
+        virtual TensorInterfacePtr clone() const override;
         virtual void cos_() override;
         virtual std::shared_ptr<TensorBase<f32>> create_grad() override;
         virtual T* data_ptr() override;
@@ -274,7 +282,7 @@ namespace traph
     template<typename T>
     Tensor<T>::Tensor()
         :_rep(new TensorStorage<T>),
-        _dimensions(), _offset(0), _strides(), _order(layout_type::column_major), _requires_grad(false)
+        _dimensions(), _offset(0), _strides(), _order(layout_type::column_major)
     {
     }
 
@@ -282,7 +290,7 @@ namespace traph
     template<typename T>
     Tensor<T>::Tensor(const DimVector& dimensions)
         :_rep(new TensorStorage<T>),
-        _dimensions(dimensions), _offset(0), _strides(), _order(layout_type::column_major), _requires_grad(false)
+        _dimensions(dimensions), _offset(0), _strides(), _order(layout_type::column_major)
     {
         auto_strides();
         
@@ -292,7 +300,7 @@ namespace traph
     template<typename T>
     Tensor<T>::Tensor(const DimVector& dimensions, layout_type order)
         :_rep(new TensorStorage<T>),
-        _dimensions(dimensions), _offset(0), _strides(), _order(order), _requires_grad(false)
+        _dimensions(dimensions), _offset(0), _strides(), _order(order)
     {
         auto_strides();
 
@@ -302,7 +310,7 @@ namespace traph
     template<typename T>
     Tensor<T>::Tensor(const DimVector& dimensions, const DimVector& strides)
         :_rep(new TensorStorage<T>),
-        _dimensions(dimensions), _offset(0), _strides(strides), _order(layout_type::column_major), _requires_grad(false)
+        _dimensions(dimensions), _offset(0), _strides(strides), _order(layout_type::column_major)
     {
         auto_strides();
 
@@ -312,7 +320,7 @@ namespace traph
     template<typename T>
     Tensor<T>::Tensor(const DimVector& dimensions, const DimVector& strides, layout_type order)
         :_rep(new TensorStorage<T>),
-        _dimensions(dimensions), _offset(0), _strides(strides), _order(order), _requires_grad(false)
+        _dimensions(dimensions), _offset(0), _strides(strides), _order(order)
     {
         auto_strides();
 
@@ -322,7 +330,7 @@ namespace traph
     template<typename T>
     Tensor<T>::Tensor(const T& t)
         :_rep(new TensorStorage<T>),
-        _dimensions(), _offset(0), strides(), _order(order), _requires_grad(false)
+        _dimensions(), _offset(0), strides(), _order(order)
     {
         _dimensions.resize(1);
         auto_strides();
@@ -373,6 +381,18 @@ namespace traph
         apply_impl(0, _offset, f);
     }
     template<typename T>
+    // Deep copy: duplicates the entire underlying storage and copies all view
+    // metadata (dimensions, offset, strides, layout order) verbatim.
+    // NOTE(review): because the WHOLE storage is cloned while offset/strides
+    // are preserved, cloning a view of a large buffer copies the full buffer,
+    // not just the viewed region — confirm this is intended.
+    // NOTE(review): the dynamic_pointer_cast result is not checked; if
+    // _rep->clone() ever returns a non-TensorStorage subtype, _rep becomes
+    // null silently.
+    template<typename T>
+    TensorInterfacePtr Tensor<T>::clone() const
+    {
+        std::shared_ptr<Tensor<T>> cloned_tensor(new Tensor<T>);
+        cloned_tensor->_rep = std::dynamic_pointer_cast<TensorStorage<T>>(_rep->clone());
+        cloned_tensor->_dimensions = _dimensions;
+        cloned_tensor->_offset = _offset;
+        cloned_tensor->_strides = _strides;
+        cloned_tensor->_order = _order;
+        
+        return cloned_tensor;
+    }
+    template<typename T>
     void Tensor<T>::cos_()
     {
         apply_([](T a)->T {return std::cos(a); });

+ 5 - 1
traph/source/nn/executor.cpp

@@ -14,10 +14,14 @@ namespace traph
             {
                 std::vector<VariableInterfacePtr> cur_inputs = (*it)->inputs();
 				std::vector<VariableInterface*> cur_raw_inputs;
+
+				std::set<VariableInterface*> sorted_cur_raw_inputs(cur_raw_inputs.begin(), cur_raw_inputs.end());
+				std::set<VariableInterface*> sorted_visited_nodes(visited_nodes.begin(), visited_nodes.end());
+
 				for (auto &each : cur_inputs)
 					cur_raw_inputs.push_back(each.get());
                 std::vector<VariableInterface*> cur_inputs_no_visited;
-                std::set_difference(cur_raw_inputs.begin(), cur_raw_inputs.end(), visited_nodes.begin(), visited_nodes.end(),
+                std::set_difference(sorted_cur_raw_inputs.begin(), sorted_cur_raw_inputs.end(), sorted_visited_nodes.begin(), sorted_visited_nodes.end(),
                         std::inserter(cur_inputs_no_visited, cur_inputs_no_visited.begin()));
                 if(cur_inputs_no_visited.empty())
                 {

+ 4 - 2
traph/source/test/main.cpp

@@ -38,9 +38,11 @@ int main()
 
 	auto a = traph::ones<traph::f32>({ 2,3 });
 	a->requires_grad_(true);
-	auto b = traph::sum<traph::f32>(a);
+	auto b = traph::ones<traph::f32>({ 2,3 });
+	auto c = traph::add<traph::f32>(a, b);
+	auto d = traph::sum<traph::f32>(c);
 
-	b->backward();
+	d->backward();
 
 	std::cout << a->grad()->to_string();