Selaa lähdekoodia

check shape before add

JasonWang 6 vuotta sitten
vanhempi
sitoutus
3e98967ee9

+ 25 - 1
traph/include/traph/core/index.h

@@ -9,7 +9,7 @@
 #include <utility>
 #include <traph/core/type.h>
 
-#define DIMVECTOR_SMALL_VECTOR_OPTIMIZATION 4
+#define DIMVECTOR_SMALL_VECTOR_OPTIMIZATION 5
 
 namespace traph
 {
@@ -93,6 +93,30 @@ namespace traph
             return *this;
         }
 
+        bool operator==(const DimVector& other) const
+        {
+            if(dim_num != other.dim_num)
+                return false;
+
+            for(idx_type i = 0; i < dim_num; ++i)
+                if(this->operator[](i) != other[i])
+                    return false;
+
+            return true;
+        }
+
+        bool operator!=(const DimVector& other) const
+        {
+            if(dim_num != other.dim_num)
+                return true;
+
+            for(idx_type i = 0; i < dim_num; ++i)
+                if(this->operator[](i) != other[i])
+                    return true;
+
+            return false;
+        }
+
         void erase(idx_type idx)
         {
             if(idx > 0 && idx < dim_num)

+ 8 - 14
traph/include/traph/core/tensor.h

@@ -128,37 +128,31 @@ namespace traph
     template<class T>
     using TensorBaseConstRef = const TensorBase<T>&;
 
-    template<class T>
-    bool broadcastable(const TensorBase<T> &lhs, const TensorBase<T> & rhs)
+    bool broadcastable(const DimVector &lhs, const DimVector & rhs)
     {
-        DimVector lhs_dim = lhs.size();
-        DimVector rhs_dim = rhs.size();
-        if(lhs_dim.size() < 1 || rhs_dim.size() < 1)
+        if(lhs.size() < 1 || rhs.size() < 1)
             return false;
 
-        idx_type min = std::min(lhs_dim.size(), rhs_dim.size());
+        idx_type min = std::min(lhs.size(), rhs.size());
 		for (idx_type i = -1; i >= -min; --i)
-			if (lhs.size(i) != rhs.size(i) && lhs.size(i) != 1 && rhs.size(i) != 1)
+			if (lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1)
 				return false;    
 
         return true;
     }
 
-	template<class T>
-	DimVector broadcast_shape(const TensorBase<T> &lhs, const TensorBase<T> & rhs)
+	DimVector broadcast_shape(const DimVector &lhs, const DimVector & rhs)
 	{
 		bool is_broadcastable = broadcastable(lhs, rhs);
 		if (!is_broadcastable)
 			throw std::runtime_error("The size of tensor a must match the size of tensor b");
-		DimVector lhs_dim = lhs.size();
-		DimVector rhs_dim = rhs.size();
-		auto max_size = std::max(lhs_dim.size(), rhs_dim.size());
+		auto max_size = std::max(lhs.size(), rhs.size());
 		DimVector result_dim(max_size);
 
 		for (idx_type i = -1; i >= -max_size; --i)
 		{
-			idx_type lhs_size = i >= -lhs_dim.size() ? lhs.size(i) : 1;
-			idx_type rhs_size = i >= -rhs_dim.size() ? rhs.size(i) : 1;
+			idx_type lhs_size = i >= -lhs.size() ? lhs[i] : 1;
+			idx_type rhs_size = i >= -rhs.size() ? rhs[i] : 1;
 			result_dim[max_size + i] = std::max(lhs_size, rhs_size);
 		}
 		return result_dim;

+ 57 - 0
traph/include/traph/core/type.h

@@ -2,6 +2,7 @@
 #define TRAPH_CORE_TYPE_H_
 
 #include <variant>
+#include <optional>
 #include <cstdint>
 
 namespace traph
@@ -56,6 +57,62 @@ namespace traph
         {
             return _dtype;
         }
+
+        std::optional<u8> get_byte()
+        {
+            if(_dtype == DataType::BYTE)
+                return std::get<u8>(_scalar);
+            else
+                return std::nullopt;
+        }
+
+        std::optional<i8> get_char()
+        {
+            if(_dtype == DataType::CHAR)
+                return std::get<i8>(_scalar);
+            else
+                return std::nullopt;
+        }
+
+        std::optional<i16> get_short()
+        {
+            if(_dtype == DataType::SHORT)
+                return std::get<i16>(_scalar);
+            else
+                return std::nullopt;
+        }
+
+        std::optional<i32> get_int()
+        {
+            if(_dtype == DataType::INT)
+                return std::get<i32>(_scalar);
+            else
+                return std::nullopt;
+        }
+
+        std::optional<i64> get_long()
+        {
+            if(_dtype == DataType::LONG)
+                return std::get<i64>(_scalar);
+            else
+                return std::nullopt;
+        }
+
+        std::optional<f32> get_float()
+        {
+            if(_dtype == DataType::FLOAT)
+                return std::get<f32>(_scalar);
+            else
+                return std::nullopt;
+        }
+
+        std::optional<f64> get_double()
+        {
+            if(_dtype == DataType::DOUBLE)
+                return std::get<f64>(_scalar);
+            else
+                return std::nullopt;
+        }
     };
 }
 

+ 42 - 21
traph/include/traph/nn/function.h

@@ -2,6 +2,7 @@
 #define TRAPH_NN_FUNCTION_H_
 
 #include <utility>
+#include <random>
 #include <cmath>
 
 #include <traph/core/type.h>
@@ -16,27 +17,27 @@
 namespace traph
 {
 
-#define UNARY_OP(name, op_name)                                           \
-	VariableInterfacePtr name(VariableInterfacePtr input)                 \
-	{                                                                     \
-		DimVector result_dim;                                             \
-        VariableInterfacePtr result = input->new_empty(result_dim, true); \
-		std::shared_ptr<op_name> op(new op_name);                         \
-		std::vector<VariableInterfacePtr> result_inputs{ input };         \
-		result->data_(op->forward({ input->data() }));                    \
-		if (input->requires_grad())                                       \
-		{                                                                 \
-			result->grad_(result->data()->create_grad());                 \
-			result->grad()->fill_(0);                                     \
-			result->requires_grad_(true);                                 \
-			result->grad_fn_(op);                                         \
-			result->inputs_(result_inputs);                               \
-		}                                                                 \
-		else                                                              \
-		{                                                                 \
-			result->requires_grad_(false);                                \
-		}                                                                 \
-		return result;                                                    \
// UNARY_OP(name, op_name): stamps out a free function `name` that applies
// the autograd operation `op_name` to a single variable. The output tensor
// comes from op->forward({input->data()}); when the input requires grad,
// the result is wired into the autograd graph (zero-filled grad buffer,
// grad_fn set to the op, inputs recorded), otherwise the result is detached.
// Comments must live outside the macro body: a `//` comment on a macro line
// would swallow the trailing `\` continuation.
// NOTE(review): result_dim is default-constructed (empty) when passed to
// new_empty; presumably new_empty only builds metadata and data_() installs
// the real storage afterwards — TODO confirm.
#define UNARY_OP(name, op_name)                                                            \
	VariableInterfacePtr name(VariableInterfacePtr input)                                  \
	{                                                                                      \
		DimVector result_dim;                                                              \
        VariableInterfacePtr result = input->new_empty(result_dim, true);                  \
		std::shared_ptr<op_name> op(new op_name);                                          \
		std::vector<VariableInterfacePtr> result_inputs{ input };                          \
		result->data_(op->forward({ input->data() }));                                     \
		if (input->requires_grad())                                                        \
		{                                                                                  \
			result->grad_(result->data()->create_grad());                                  \
			result->grad()->fill_(0);                                                      \
			result->requires_grad_(true);                                                  \
			result->grad_fn_(op);                                                          \
			result->inputs_(result_inputs);                                                \
		}                                                                                  \
		else                                                                               \
		{                                                                                  \
			result->requires_grad_(false);                                                 \
		}                                                                                  \
		return result;                                                                     \
	}
 
 #define BINARY_OP(name, op_name)                                                           \
@@ -102,6 +103,26 @@ namespace traph
 		return result;
 	}
 
+	template<typename T>
+	VariableInterfacePtr randn(std::initializer_list<idx_type> l, bool requires_grad = false)
+	{
+		DimVector dim;
+		for (auto i : l)
+			dim.push_back(i);
+
+		std::random_device rd{};
+		std::mt19937 gen{rd()};
+		std::normal_distribution<> d{0,1};
+
+		std::shared_ptr<VariableInterface> result(new Variable<T>(dim));
+		std::shared_ptr<TensorBase<T>> result_data = std::dynamic_pointer_cast<TensorBase<T>>(result->data());
+		result_data->apply_([&d, &gen](T n){
+			return d(gen);
+		});
+
+		return result;
+	}
+
 	template<typename T>
 	VariableInterfacePtr empty_like(VariableInterfacePtr input, bool requires_grad = false)
 	{

+ 5 - 2
traph/source/tensor/byte_tensor.cpp

@@ -154,9 +154,12 @@ namespace traph
     void Tensor<u8>::add_(TensorInterfacePtr other)
     {
 		// check tensor other type
-
+        if(other->dtype() != DataType::BYTE)
+            throw std::runtime_error("expected type byte tensor");
 		// check broadcast.shape = this.shape
-
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
 		// ok, get lhs, rhs
 		Tensor<u8> * lhs = this;
 		Tensor<u8> * rhs = dynamic_cast<Tensor<u8> *>(other.get());

+ 5 - 2
traph/source/tensor/char_tensor.cpp

@@ -154,9 +154,12 @@ namespace traph
     void Tensor<i8>::add_(TensorInterfacePtr other)
     {
 		// check tensor other type
-
+        if(other->dtype() != DataType::CHAR)
+            throw std::runtime_error("expected type char tensor");
 		// check broadcast.shape = this.shape
-
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
 		// ok, get lhs, rhs
 		Tensor<i8> * lhs = this;
 		Tensor<i8> * rhs = dynamic_cast<Tensor<i8> *>(other.get());

+ 5 - 2
traph/source/tensor/double_tensor.cpp

@@ -154,9 +154,12 @@ namespace traph
     void Tensor<f64>::add_(TensorInterfacePtr other)
     {
 		// check tensor other type
-
+        if(other->dtype() != DataType::DOUBLE)
+            throw std::runtime_error("expected type double tensor");
 		// check broadcast.shape = this.shape
-
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
 		// ok, get lhs, rhs
 		Tensor<f64> * lhs = this;
 		Tensor<f64> * rhs = dynamic_cast<Tensor<f64> *>(other.get());

+ 5 - 2
traph/source/tensor/float_tensor.cpp

@@ -154,9 +154,12 @@ namespace traph
     void Tensor<f32>::add_(TensorInterfacePtr other)
     {
 		// check tensor other type
-
+        if(other->dtype() != DataType::FLOAT)
+            throw std::runtime_error("expected type float tensor");
 		// check broadcast.shape = this.shape
-
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
 		// ok, get lhs, rhs
 		Tensor<f32> * lhs = this;
 		Tensor<f32> * rhs = dynamic_cast<Tensor<f32> *>(other.get());

+ 5 - 2
traph/source/tensor/int_tensor.cpp

@@ -154,9 +154,12 @@ namespace traph
     void Tensor<i32>::add_(TensorInterfacePtr other)
     {
 		// check tensor other type
-
+        if(other->dtype() != DataType::INT)
+            throw std::runtime_error("expected type int tensor");
 		// check broadcast.shape = this.shape
-
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
 		// ok, get lhs, rhs
 		Tensor<i32> * lhs = this;
 		Tensor<i32> * rhs = dynamic_cast<Tensor<i32> *>(other.get());

+ 5 - 2
traph/source/tensor/long_tensor.cpp

@@ -154,9 +154,12 @@ namespace traph
     void Tensor<i64>::add_(TensorInterfacePtr other)
     {
 		// check tensor other type
-
+        if(other->dtype() != DataType::LONG)
+            throw std::runtime_error("expected type long tensor");
 		// check broadcast.shape = this.shape
-
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
 		// ok, get lhs, rhs
 		Tensor<i64> * lhs = this;
 		Tensor<i64> * rhs = dynamic_cast<Tensor<i64> *>(other.get());

+ 5 - 2
traph/source/tensor/short_tensor.cpp

@@ -154,9 +154,12 @@ namespace traph
     void Tensor<i16>::add_(TensorInterfacePtr other)
     {
 		// check tensor other type
-
+        if(other->dtype() != DataType::SHORT)
+            throw std::runtime_error("expected type short tensor");
 		// check broadcast.shape = this.shape
-
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
 		// ok, get lhs, rhs
 		Tensor<i16> * lhs = this;
 		Tensor<i16> * rhs = dynamic_cast<Tensor<i16> *>(other.get());

+ 4 - 3
traph/source/test/main.cpp

@@ -58,8 +58,8 @@ int main()
 	*/
 
 	int batch_size = 16;
-	auto x = traph::ones<traph::f32>({ batch_size,4 });
-	auto y = traph::zeros<traph::f32>({ batch_size,2 });
+	auto x = traph::randn<traph::f32>({ batch_size,4 });
+	auto y = traph::randn<traph::f32>({ batch_size,2 });
 
 	traph::Linear linear_model(4, 2, false);
 	traph::MSELoss loss;
@@ -68,7 +68,8 @@ int main()
 	auto result = loss.forward(out, y);
 
 	result->backward();
-	std::cout << linear_model.parameters(true)[0]->grad()->to_string();
+	std::cout << x->data()->to_string() << std::endl;
+	std::cout << linear_model.parameters(true)[0]->grad()->to_string() << std::endl;
 
     return 0;
 }