Explorar el Código

add dtype method

JasonWang hace 6 años
padre
commit
af6b54b5d3

+ 5 - 0
.vscode/settings.json

@@ -80,5 +80,10 @@
         "future": "cpp",
         "queue": "cpp",
         "cfenv": "cpp"
+    },
+    "files.exclude": {
+        "**/.git": true,
+        "**/build": true,
+        "**/traph/contrib": true,
     }
 }

+ 2 - 0
traph/include/traph/core/tensor.h

@@ -33,6 +33,7 @@ namespace traph
         virtual void cos_() = 0;
         virtual std::shared_ptr<TensorBase<f32>> create_grad() = 0;
         virtual device_id device() = 0;
+        virtual DataType dtype() const = 0;
         virtual std::shared_ptr<TensorInterface> inverse() const = 0;
         virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const = 0;
         virtual void neg_() = 0;
@@ -83,6 +84,7 @@ namespace traph
         virtual T* data_ptr() = 0;
         virtual const T* data_ptr() const = 0;
         virtual device_id device() = 0;
+        virtual DataType dtype() const = 0;
         virtual void fill_(T value) = 0;
         virtual std::shared_ptr<TensorInterface> inverse() const = 0;
         virtual T item() const = 0;

+ 1 - 1
traph/include/traph/core/type.h

@@ -34,7 +34,7 @@ namespace traph
         opengl
     };
 
-    enum dtype
+    enum DataType
     {
         BYTE,
         CHAR,

+ 0 - 2
traph/include/traph/core/variable.h

@@ -29,7 +29,6 @@ namespace traph
         virtual std::vector<VariableInterfacePtr>& inputs() = 0;
         virtual void inputs_(const std::vector<VariableInterfacePtr>& i) = 0;
         virtual bool is_leaf() const = 0;
-		virtual void leaf_(bool state) = 0;
         virtual std::shared_ptr<VariableInterface> new_empty(const DimVector& size, bool requires_grad) const = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
@@ -75,7 +74,6 @@ namespace traph
         virtual void inputs_(const std::vector<VariableInterfacePtr>& i) = 0;
         virtual bool is_leaf() const = 0;
         virtual T item() const = 0;
-		virtual void leaf_(bool state) = 0;
         virtual std::shared_ptr<VariableInterface> new_empty(const DimVector& size, bool requires_grad) const = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;

+ 4 - 18
traph/include/traph/nn/function.h

@@ -24,7 +24,6 @@ namespace traph
 			dim.push_back(i);
 
 		std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
-		result->leaf_(true);
 
 		return result;
 	}
@@ -36,8 +35,7 @@ namespace traph
 		for (auto i : l)
 			dim.push_back(i);
 
-		std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
-		result->leaf_(true);
+		std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
 		std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(0);
 
 		return result;
@@ -50,8 +48,7 @@ namespace traph
 		for (auto i : l)
 			dim.push_back(i);
 
-		std::shared_ptr<VariableInterface> result(new Variable<T>(dim, false));
-		result->leaf_(true);
+		std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
 		std::dynamic_pointer_cast<TensorBase<T>>(result->data())->fill_(1);
 
 		return result;
@@ -61,7 +58,6 @@ namespace traph
 	VariableInterfacePtr empty_like(VariableInterfacePtr input, bool requires_grad = false)
 	{
 		std::shared_ptr<VariableInterface> result(new Variable<T>(input->size(), false));
-		result->leaf_(true);
 
 		return result;
 	}
@@ -74,22 +70,20 @@ namespace traph
 
         VariableInterfacePtr result = input->new_empty(result_dim, true);
         std::shared_ptr<SumOp> op(new SumOp);
+
+		result->data_(op->forward({ input->data() }));
         if(input->requires_grad())
         {
 			std::vector<VariableInterfacePtr> result_inputs { input };
-            result->data_(op->forward({ input->data() }));
 			result->grad_(result->data()->create_grad());
 			result->grad()->fill_(0);
             result->requires_grad_(true);
-            result->leaf_(false);
             result->grad_fn_(op);
             result->inputs_(result_inputs);
         }
         else
         {
-            result->data_(op->forward({ input->data() }));
             result->requires_grad_(false);
-            result->leaf_(false);
         }
 
         return result;
@@ -108,7 +102,6 @@ namespace traph
 			result->grad_(result->data()->create_grad());
 			result->grad()->fill_(0);
 			result->requires_grad_(true);
-			result->leaf_(false);
 			result->grad_fn_(op);
 			result->inputs_(result_inputs);
 		}
@@ -116,7 +109,6 @@ namespace traph
 		{
 			result->data_(op->forward({ left->data(), right->data() }));
 			result->requires_grad_(false);
-			result->leaf_(false);
 		}
 
 		return result;
@@ -135,7 +127,6 @@ namespace traph
 			result->grad_(result->data()->create_grad());
 			result->grad()->fill_(0);
 			result->requires_grad_(true);
-			result->leaf_(false);
 			result->grad_fn_(op);
 			result->inputs_(result_inputs);
 		}
@@ -143,7 +134,6 @@ namespace traph
 		{
 			result->data_(op->forward({ left->data(), right->data() }));
 			result->requires_grad_(false);
-			result->leaf_(false);
 		}
 
 		return result;
@@ -160,7 +150,6 @@ namespace traph
 
 		std::vector<VariableInterfacePtr> result_inputs{ input };
 		result->data_(op->forward({ input->data() }));
-		result->leaf_(false);
 
 		if (input->requires_grad())
 		{
@@ -188,7 +177,6 @@ namespace traph
 
 		std::vector<VariableInterfacePtr> result_inputs{ input };
 		result->data_(op->forward({ input->data() }));
-		result->leaf_(false);
 
 		if (input->requires_grad())
 		{
@@ -213,7 +201,6 @@ namespace traph
         VariableInterfacePtr result = left->new_empty(result_dim, true);
 		std::shared_ptr<SubOp> op(new SubOp);
 		result->data_(op->forward({ left->data(), right->data() }));
-		result->leaf_(false);
 		if (left->requires_grad() || right->requires_grad())
 		{
 			std::vector<VariableInterfacePtr> result_inputs{ left, right };
@@ -241,7 +228,6 @@ namespace traph
 
 		std::vector<VariableInterfacePtr> result_inputs{ input };
 		result->data_(op->forward({ input->data() }));
-		result->leaf_(false);
 
 		if (input->requires_grad())
 		{

+ 4 - 28
traph/include/traph/nn/variable.h

@@ -28,7 +28,6 @@ namespace traph
         std::shared_ptr<TensorBase<T>> _data;
         std::shared_ptr<TensorBase<f32>> _grad;
         bool _requires_grad;
-        bool _leaf;
         std::shared_ptr<OpBase> _grad_fn;
         std::vector<VariableInterfacePtr> _inputs;
         // std::vector<std::weak_ptr<VariableInterface>> _outputs;
@@ -36,7 +35,6 @@ namespace traph
         Variable();
         Variable(std::shared_ptr<TensorBase<T>> data);
         Variable(const DimVector& dim);
-        Variable(const DimVector& dim, bool is_leaf);
         Variable(std::initializer_list<idx_type> l);
 
 		Variable(const Variable& other) = delete;
@@ -74,7 +72,6 @@ namespace traph
 		virtual void inputs_(const std::vector<VariableInterfacePtr>& i) override;
         virtual bool is_leaf() const override;
         virtual T item() const override;
-		virtual void leaf_(bool state) override;
 		virtual std::shared_ptr<VariableInterface> new_empty(const DimVector& size, bool requires_grad) const override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;
@@ -116,7 +113,7 @@ namespace traph
 	template<typename T>
 	Variable<T>::Variable()
 		:_data(new Tensor<T>), _grad(nullptr),
-		_requires_grad(false), _leaf(false),
+		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 
@@ -125,7 +122,7 @@ namespace traph
 	template<typename T>
 	Variable<T>::Variable(std::shared_ptr<TensorBase<T>> data)
 		:_data(data), _grad(nullptr),
-		_requires_grad(false), _leaf(false),
+		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 	}
@@ -133,30 +130,15 @@ namespace traph
 	template<typename T>
 	Variable<T>::Variable(const DimVector& dim)
 		:_data(new Tensor<T>(dim)), _grad(nullptr),
-		_requires_grad(false), _leaf(false),
+		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 	}
 
-	template<typename T>
-	Variable<T>::Variable(const DimVector& dim, bool is_leaf)
-		:_data(new Tensor<T>(dim)), _grad(nullptr),
-		_requires_grad(false), _leaf(is_leaf),
-		_grad_fn(nullptr), _inputs()
-	{
-		if (is_leaf)
-		{
-			_requires_grad = true;
-
-			_grad = _data->create_grad();
-			_grad->fill_(0);
-		}
-	}
-
 	template<typename T>
 	Variable<T>::Variable(std::initializer_list<idx_type> l)
 		:_data(new Tensor<T>()), _grad(nullptr),
-		_requires_grad(false), _leaf(false),
+		_requires_grad(false),
 		_grad_fn(nullptr), _inputs()
 	{
 		DimVector dim;
@@ -270,12 +252,6 @@ namespace traph
 		return _data->item();
 	}
 
-	template<typename T>
-	void Variable<T>::leaf_(bool state)
-	{
-		_leaf = state;
-	}
-
 	template<typename T>
 	std::shared_ptr<VariableInterface> Variable<T>::new_empty(const DimVector& size, bool requires_grad) const
 	{

+ 3 - 0
traph/include/traph/tensor/byte_tensor.h

@@ -4,6 +4,8 @@
 #include <utility>
 #include <cmath>
 
+
+#include <traph/core/type.h>
 #include <traph/tensor/tensor.h>
 
 namespace traph
@@ -62,6 +64,7 @@ namespace traph
 		virtual u8* data_ptr() override;
 		virtual const u8* data_ptr() const override;
 		virtual device_id device() override;
+        virtual DataType dtype() const override;
 		virtual void fill_(u8 value) override;
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual u8 item() const override;

+ 1 - 0
traph/include/traph/tensor/char_tensor.h

@@ -62,6 +62,7 @@ namespace traph
 		virtual i8* data_ptr() override;
 		virtual const i8* data_ptr() const override;
 		virtual device_id device() override;
+        virtual DataType dtype() const override;
 		virtual void fill_(i8 value) override;
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i8 item() const override;

+ 1 - 0
traph/include/traph/tensor/double_tensor.h

@@ -62,6 +62,7 @@ namespace traph
 		virtual f64* data_ptr() override;
 		virtual const f64* data_ptr() const override;
 		virtual device_id device() override;
+        virtual DataType dtype() const override;
 		virtual void fill_(f64 value) override;
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual f64 item() const override;

+ 1 - 0
traph/include/traph/tensor/float_tensor.h

@@ -63,6 +63,7 @@ namespace traph
 		virtual f32* data_ptr() override;
 		virtual const f32* data_ptr() const override;
 		virtual device_id device() override;
+        virtual DataType dtype() const override;
 		virtual void fill_(f32 value) override;
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual f32 item() const override;

+ 1 - 0
traph/include/traph/tensor/int_tensor.h

@@ -62,6 +62,7 @@ namespace traph
 		virtual i32* data_ptr() override;
 		virtual const i32* data_ptr() const override;
 		virtual device_id device() override;
+        virtual DataType dtype() const override;
 		virtual void fill_(i32 value) override;
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i32 item() const override;

+ 1 - 0
traph/include/traph/tensor/long_tensor.h

@@ -62,6 +62,7 @@ namespace traph
 		virtual i64* data_ptr() override;
 		virtual const i64* data_ptr() const override;
 		virtual device_id device() override;
+        virtual DataType dtype() const override;
 		virtual void fill_(i64 value) override;
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i64 item() const override;

+ 1 - 0
traph/include/traph/tensor/short_tensor.h

@@ -62,6 +62,7 @@ namespace traph
 		virtual i16* data_ptr() override;
 		virtual const i16* data_ptr() const override;
 		virtual device_id device() override;
+        virtual DataType dtype() const override;
 		virtual void fill_(i16 value) override;
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i16 item() const override;

+ 1 - 0
traph/include/traph/tensor/tensor.h

@@ -64,6 +64,7 @@ namespace traph
         virtual T* data_ptr() override;
         virtual const T* data_ptr() const override;
         virtual device_id device() override;
+        virtual DataType dtype() const override;
         virtual void fill_(T value) override;
         virtual std::shared_ptr<TensorInterface> inverse() const override;
         virtual T item() const override;

+ 5 - 0
traph/source/tensor/byte_tensor.cpp

@@ -229,6 +229,11 @@ namespace traph
 
     device_id Tensor<u8>::device() { return 0; }
 
+    DataType Tensor<u8>::dtype() const
+    {
+        return DataType::BYTE;
+    }
+
 	std::shared_ptr<TensorInterface> Tensor<u8>::inverse() const
 	{
 		throw std::runtime_error("No implement");

+ 5 - 0
traph/source/tensor/char_tensor.cpp

@@ -229,6 +229,11 @@ namespace traph
 
     device_id Tensor<i8>::device() { return 0; }
 
+    DataType Tensor<i8>::dtype() const
+    {
+        return DataType::CHAR;
+    }
+
 	std::shared_ptr<TensorInterface> Tensor<i8>::inverse() const
 	{
 		throw std::runtime_error("No implement");

+ 5 - 0
traph/source/tensor/double_tensor.cpp

@@ -229,6 +229,11 @@ namespace traph
 
     device_id Tensor<f64>::device() { return 0; }
 
+    DataType Tensor<f64>::dtype() const
+    {
+        return DataType::DOUBLE;
+    }
+
 	std::shared_ptr<TensorInterface> Tensor<f64>::inverse() const
 	{
 		return std::dynamic_pointer_cast<TensorInterface>(inverse_impl(*this));

+ 5 - 0
traph/source/tensor/float_tensor.cpp

@@ -230,6 +230,11 @@ namespace traph
 
     device_id Tensor<f32>::device() { return 0; }
 
+    DataType Tensor<f32>::dtype() const
+    {
+        return DataType::FLOAT;
+    }
+
 	std::shared_ptr<TensorInterface> Tensor<f32>::inverse() const
 	{
 		return std::dynamic_pointer_cast<TensorInterface>(inverse_impl(*this));

+ 5 - 0
traph/source/tensor/int_tensor.cpp

@@ -229,6 +229,11 @@ namespace traph
 
     device_id Tensor<i32>::device() { return 0; }
 
+    DataType Tensor<i32>::dtype() const
+    {
+        return DataType::INT;
+    }
+
 	std::shared_ptr<TensorInterface> Tensor<i32>::inverse() const
 	{
 		throw std::runtime_error("No implement");

+ 5 - 0
traph/source/tensor/long_tensor.cpp

@@ -229,6 +229,11 @@ namespace traph
 
     device_id Tensor<i64>::device() { return 0; }
 
+    DataType Tensor<i64>::dtype() const
+    {
+        return DataType::LONG;
+    }
+
 	std::shared_ptr<TensorInterface> Tensor<i64>::inverse() const
 	{
 		throw std::runtime_error("No implement");

+ 5 - 0
traph/source/tensor/short_tensor.cpp

@@ -229,6 +229,11 @@ namespace traph
 
     device_id Tensor<i16>::device() { return 0; }
 
+    DataType Tensor<i16>::dtype() const
+    {
+        return DataType::SHORT;
+    }
+
 	std::shared_ptr<TensorInterface> Tensor<i16>::inverse() const
 	{
 		throw std::runtime_error("No implement");

+ 6 - 0
traph/source/tensor/tensor.cpp

@@ -91,6 +91,12 @@ namespace traph
     }
     template<typename T>
     device_id Tensor<T>::device() { throw std::runtime_error("No implement"); }
+
+    template<typename T>
+    DataType Tensor<T>::dtype() const
+    {
+        throw std::runtime_error("No implement");
+    }
     template<typename T>
     void Tensor<T>::fill_(T value)
     {