Преглед изворни кода

Add mean operation and fix pow backpropagation bug

JasonWang пре 6 година
родитељ
комит
094e33a2ff

+ 4 - 0
traph/include/traph/core/tensor.h

@@ -36,6 +36,8 @@ namespace traph
         virtual DataType dtype() const = 0;
         virtual std::shared_ptr<TensorInterface> inverse() const = 0;
         virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const = 0;
+        virtual std::shared_ptr<TensorInterface> mean() const = 0;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) = 0;
         virtual void neg_() = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;
@@ -89,7 +91,9 @@ namespace traph
         virtual std::shared_ptr<TensorInterface> inverse() const = 0;
         virtual T item() const = 0;
         virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const = 0;
+        virtual TensorInterfacePtr mean() const = 0;
         virtual void mul_(T value) = 0;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) = 0;
         virtual void neg_() = 0;
         virtual idx_type offset() const = 0;
 		virtual layout_type order() const = 0;

+ 21 - 0
traph/include/traph/core/type.h

@@ -53,6 +53,27 @@ namespace traph
         std::variant<u8, i8, i16, i32, i64, f32, f64> _scalar;
         DataType _dtype;
     public:
+		ScalarType(u8 v)
+			:_scalar(v) {}
+
+		ScalarType(i8 v)
+			:_scalar(v) {}
+
+		ScalarType(i16 v)
+			:_scalar(v) {}
+
+		ScalarType(i32 v)
+			:_scalar(v) {}
+
+		ScalarType(i64 v)
+			:_scalar(v) {}
+
+		ScalarType(f32 v)
+			:_scalar(v) {}
+
+		ScalarType(f64 v)
+			:_scalar(v) {}
+
         DataType dtype() const
         {
             return _dtype;

+ 2 - 0
traph/include/traph/nn/function.h

@@ -150,6 +150,8 @@ namespace traph
 
 	BINARY_OP(matmul, MatmulOp)
 
+	UNARY_OP(mean, MeanOp)
+
 	VariableInterfacePtr pow(VariableInterfacePtr input, float exp)
 	{
 		DimVector result_dim;

+ 2 - 2
traph/include/traph/nn/layers/linear.h

@@ -18,9 +18,9 @@ namespace traph
         {
             _in_features = in_features;
             _out_features = out_features;
-            _weight = zeros<f32>({out_features, in_features}, true);
+            _weight = randn<f32>({out_features, in_features}, true);
             if(bias)
-                _bias = zeros<f32>({out_features}, true);
+                _bias = randn<f32>({out_features}, true);
             
             register_parameter("weight", _weight);
             register_parameter("bias", _bias);

+ 1 - 2
traph/include/traph/nn/layers/loss.h

@@ -31,8 +31,7 @@ namespace traph
             }
             else if(_reduction == MSELossReduction::MEAN)
             {
-                // fixme: use mean if it impled
-                ret = sum(pow(sub(input, target), 2.f));
+				ret = mean(pow(sub(input, target), 2.f));
             }
             else
             {

+ 27 - 1
traph/include/traph/nn/module.h

@@ -15,9 +15,25 @@ namespace traph
     class Module
     {
     private:
+        std::string _name;
         std::vector<std::pair<std::string, std::shared_ptr<VariableInterface>>> _parameters;
-        std::vector<std::shared_ptr<Module>> _children;
+        std::vector<std::pair<std::string, std::shared_ptr<Module>>> _children;
     public:
+        void add_module(const std::string& name, std::shared_ptr<Module> module)
+        {
+            _children.push_back(std::make_pair(name, module));
+        }
+
+        std::vector<std::shared_ptr<Module>> modules()
+        {
+            std::vector<std::shared_ptr<Module>> result;
+
+            for (const auto &m : _children)
+                if(m.second)
+                    result.push_back(m.second);
+
+            return result;
+        }
 
         std::vector<std::pair<std::string, std::shared_ptr<VariableInterface>>> named_parameters(bool recurse)
         {
@@ -47,6 +63,16 @@ namespace traph
                 for (const auto &p : _parameters)
 					if(p.second)
 						result.push_back(p.second);
+
+                for (const auto &m : _children)
+                {
+                    if(m.second)
+                    {
+                        auto child_params = m.second->parameters(true);
+                        for(auto &child_param: child_params)
+                            result.push_back(child_param);
+                    }
+                }
             }
             else
             {

+ 29 - 0
traph/include/traph/nn/operation.h

@@ -88,6 +88,33 @@ namespace traph
 		}
 	};
 
+	class MeanOp : public OpBase
+	{
+	public:
+		virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
+		{
+			assert(inputs.size() == 1);
+
+			TensorInterfacePtr input = inputs[0];
+			TensorInterfacePtr result = input->mean();
+
+			context.save(input);
+
+			return result;
+		}
+
+		virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
+		{
+			auto saved_tensors = context.get_saved_tensors();
+			assert(saved_tensors.size() == 1);
+
+			auto flat_size = saved_tensors[0]->size().flat_size();
+			auto result = std::dynamic_pointer_cast<TensorBase<f32>>(output_grad->clone());
+			result->mul_(1.f/flat_size);
+			return { result };
+		}
+	};
+
 	class PowOp: public OpBase
 	{
 	private:
@@ -119,6 +146,8 @@ namespace traph
 			
 			//FIXME x^n = n*x^(n-1)
 			cloned_x->mul_(_exp);
+			cloned_x->mul_(output_grad);
+			
 			return { cloned_x };
 		}
 	};

+ 2 - 1
traph/include/traph/nn/variable.h

@@ -158,11 +158,12 @@ namespace traph
 			}
 		}
 
-		// TODO:retain_graph
+		// TODO:retain_graph, retain_all_grad
 		for (int i = static_cast<int>(sorted_node.size()) - 1; i >= 0; --i)
 		{
 			_grad_fn = nullptr;
 			_inputs.clear();
+			// _grad = std::shared_ptr<TensorBase<f32>>(nullptr);
 		}
 	}
 

+ 3 - 1
traph/include/traph/tensor/byte_tensor.h

@@ -67,7 +67,9 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual u8 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(u8 value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(u8 value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 3 - 1
traph/include/traph/tensor/char_tensor.h

@@ -65,7 +65,9 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i8 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(i8 value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(i8 value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 3 - 1
traph/include/traph/tensor/double_tensor.h

@@ -65,7 +65,9 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual f64 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(f64 value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(f64 value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 3 - 1
traph/include/traph/tensor/float_tensor.h

@@ -66,7 +66,9 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual f32 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(f32 value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(f32 value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 3 - 1
traph/include/traph/tensor/int_tensor.h

@@ -65,7 +65,9 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i32 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(i32 value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(i32 value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 3 - 1
traph/include/traph/tensor/long_tensor.h

@@ -65,7 +65,9 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i64 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(i64 value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(i64 value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
 		virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 3 - 1
traph/include/traph/tensor/short_tensor.h

@@ -65,7 +65,9 @@ namespace traph
 		virtual std::shared_ptr<TensorInterface> inverse() const override;
 		virtual i16 item() const override;
 		virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(i16 value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(i16 value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 3 - 1
traph/include/traph/tensor/tensor.h

@@ -69,7 +69,9 @@ namespace traph
         virtual std::shared_ptr<TensorInterface> inverse() const override;
         virtual T item() const override;
         virtual std::shared_ptr<TensorInterface> matmul(std::shared_ptr<TensorInterface> mat) const override;
-		virtual void mul_(T value) override;
+		virtual TensorInterfacePtr mean() const override;
+        virtual void mul_(T value) override;
+        virtual void mul_(std::shared_ptr<TensorInterface> other) override;
         virtual void neg_() override;
         virtual idx_type offset() const override;
 		virtual layout_type order() const override;

+ 0 - 1
traph/source/nn/CMakeLists.txt

@@ -8,7 +8,6 @@ SET(NN_LIST
 	${HEADER_PATH}/autograd.h
 	${HEADER_PATH}/variable.h
 	${SOURCE_PATH}/variable.cpp
-	${HEADER_PATH}/graph.h
 	${HEADER_PATH}/executor.h
 	${SOURCE_PATH}/executor.cpp
 	${HEADER_PATH}/function.h

+ 55 - 0
traph/source/tensor/byte_tensor.cpp

@@ -280,11 +280,66 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    TensorInterfacePtr Tensor<u8>::mean() const
+    {
+        DimVector d(1);
+        d[0] = 1;
+
+        TensorPtr<u8> result(new Tensor<u8>(d));
+        auto flat_size = _dimensions.flat_size();
+        result->_rep->data[0] = reduce_([](u8 a, u8 b)->u8 {return a + b; });
+        result->_rep->data[0] /= flat_size;
+        return std::dynamic_pointer_cast<TensorInterface>(result);
+    }
+
     void Tensor<u8>::mul_(u8 value)
     {
         apply_([value](u8 a)->u8 {return a*value; });
     }
 
+    void Tensor<u8>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        // check tensor other type
+        if(other->dtype() != DataType::BYTE)
+            throw std::runtime_error("expected type byte tensor");
+		// check broadcast.shape = this.shape
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		// ok, get lhs, rhs
+		Tensor<u8> * lhs = this;
+		Tensor<u8> * rhs = dynamic_cast<Tensor<u8> *>(other.get());
+		std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
+			[&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
+
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+                if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
+                {
+                    lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
+                }
+                else
+                {
+                    mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
+                }
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+			}
+		};
+
+		mul_impl(-1, -1, lhs->offset(), rhs->offset());
+    }
+
     void Tensor<u8>::neg_()
     {
         apply_([](u8 a)->u8 {return -a; });

+ 55 - 0
traph/source/tensor/char_tensor.cpp

@@ -280,11 +280,66 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    TensorInterfacePtr Tensor<i8>::mean() const
+    {
+        DimVector d(1);
+        d[0] = 1;
+
+        TensorPtr<i8> result(new Tensor<i8>(d));
+        auto flat_size = _dimensions.flat_size();
+        result->_rep->data[0] = reduce_([](i8 a, i8 b)->i8 {return a + b; });
+        result->_rep->data[0] /= flat_size;
+        return std::dynamic_pointer_cast<TensorInterface>(result);
+    }
+
     void Tensor<i8>::mul_(i8 value)
     {
         apply_([value](i8 a)->i8 {return a*value; });
     }
 
+    void Tensor<i8>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        // check tensor other type
+        if(other->dtype() != DataType::CHAR)
+            throw std::runtime_error("expected type char tensor");
+		// check broadcast.shape = this.shape
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		// ok, get lhs, rhs
+		Tensor<i8> * lhs = this;
+		Tensor<i8> * rhs = dynamic_cast<Tensor<i8> *>(other.get());
+		std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
+			[&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
+
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+                if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
+                {
+                    lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
+                }
+                else
+                {
+                    mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
+                }
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+			}
+		};
+
+		mul_impl(-1, -1, lhs->offset(), rhs->offset());
+    }
+
     void Tensor<i8>::neg_()
     {
         apply_([](i8 a)->i8 {return -a; });

+ 55 - 0
traph/source/tensor/double_tensor.cpp

@@ -281,11 +281,66 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    TensorInterfacePtr Tensor<f64>::mean() const
+    {
+        DimVector d(1);
+        d[0] = 1;
+
+        TensorPtr<f64> result(new Tensor<f64>(d));
+        auto flat_size = _dimensions.flat_size();
+        result->_rep->data[0] = reduce_([](f64 a, f64 b)->f64 {return a + b; });
+        result->_rep->data[0] /= flat_size;
+        return std::dynamic_pointer_cast<TensorInterface>(result);
+    }
+
     void Tensor<f64>::mul_(f64 value)
     {
         apply_([value](f64 a)->f64 {return a*value; });
     }
 
+    void Tensor<f64>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        // check tensor other type
+        if(other->dtype() != DataType::DOUBLE)
+            throw std::runtime_error("expected type double tensor");
+		// check broadcast.shape = this.shape
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		// ok, get lhs, rhs
+		Tensor<f64> * lhs = this;
+		Tensor<f64> * rhs = dynamic_cast<Tensor<f64> *>(other.get());
+		std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
+			[&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
+
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+                if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
+                {
+                    lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
+                }
+                else
+                {
+                    mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
+                }
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+			}
+		};
+
+		mul_impl(-1, -1, lhs->offset(), rhs->offset());
+    }
+
     void Tensor<f64>::neg_()
     {
         apply_([](f64 a)->f64 {return -a; });

+ 55 - 0
traph/source/tensor/float_tensor.cpp

@@ -282,11 +282,66 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    TensorInterfacePtr Tensor<f32>::mean() const
+    {
+        DimVector d(1);
+        d[0] = 1;
+
+        TensorPtr<f32> result(new Tensor<f32>(d));
+        auto flat_size = _dimensions.flat_size();
+        result->_rep->data[0] = reduce_([](f32 a, f32 b)->f32 {return a + b; });
+        result->_rep->data[0] /= flat_size;
+        return std::dynamic_pointer_cast<TensorInterface>(result);
+    }
+
     void Tensor<f32>::mul_(f32 value)
     {
         apply_([value](f32 a)->f32 {return a*value; });
     }
 
+    void Tensor<f32>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        // check tensor other type
+        if(other->dtype() != DataType::FLOAT)
+            throw std::runtime_error("expected type float tensor");
+		// check broadcast.shape = this.shape
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		// ok, get lhs, rhs
+		Tensor<f32> * lhs = this;
+		Tensor<f32> * rhs = dynamic_cast<Tensor<f32> *>(other.get());
+		std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
+			[&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
+
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+                if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
+                {
+                    lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
+                }
+                else
+                {
+                    mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
+                }
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+			}
+		};
+
+		mul_impl(-1, -1, lhs->offset(), rhs->offset());
+    }
+
 
     void Tensor<f32>::neg_()
     {

+ 54 - 0
traph/source/tensor/int_tensor.cpp

@@ -281,11 +281,65 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    TensorInterfacePtr Tensor<i32>::mean() const
+    {
+        DimVector d(1);
+        d[0] = 1;
+
+        TensorPtr<i32> result(new Tensor<i32>(d));
+        auto flat_size = _dimensions.flat_size();
+        result->_rep->data[0] = reduce_([](i32 a, i32 b)->i32 {return a + b; });
+        result->_rep->data[0] /= flat_size;
+        return std::dynamic_pointer_cast<TensorInterface>(result);
+    }
+
     void Tensor<i32>::mul_(i32 value)
     {
         apply_([value](i32 a)->i32 {return a*value; });
     }
 
+    void Tensor<i32>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        // check tensor other type
+        if(other->dtype() != DataType::INT)
+            throw std::runtime_error("expected type int tensor");
+		// check broadcast.shape = this.shape
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		// ok, get lhs, rhs
+		Tensor<i32> * lhs = this;
+		Tensor<i32> * rhs = dynamic_cast<Tensor<i32> *>(other.get());
+		std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
+			[&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
+
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+                if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
+                {
+                    lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
+                }
+                else
+                {
+                    mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
+                }
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+			}
+		};
+
+		mul_impl(-1, -1, lhs->offset(), rhs->offset());
+    }
 
     void Tensor<i32>::neg_()
     {

+ 55 - 0
traph/source/tensor/long_tensor.cpp

@@ -281,11 +281,66 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    TensorInterfacePtr Tensor<i64>::mean() const
+    {
+        DimVector d(1);
+        d[0] = 1;
+
+        TensorPtr<i64> result(new Tensor<i64>(d));
+        auto flat_size = _dimensions.flat_size();
+        result->_rep->data[0] = reduce_([](i64 a, i64 b)->i64 {return a + b; });
+        result->_rep->data[0] /= flat_size;
+        return std::dynamic_pointer_cast<TensorInterface>(result);
+    }
+
     void Tensor<i64>::mul_(i64 value)
     {
         apply_([value](i64 a)->i64 {return a*value; });
     }
 
+    void Tensor<i64>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        // check tensor other type
+        if(other->dtype() != DataType::LONG)
+            throw std::runtime_error("expected type long tensor");
+		// check broadcast.shape = this.shape
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		// ok, get lhs, rhs
+		Tensor<i64> * lhs = this;
+		Tensor<i64> * rhs = dynamic_cast<Tensor<i64> *>(other.get());
+		std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
+			[&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
+
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+                if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
+                {
+                    lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
+                }
+                else
+                {
+                    mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
+                }
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+			}
+		};
+
+		mul_impl(-1, -1, lhs->offset(), rhs->offset());
+    }
+
     void Tensor<i64>::neg_()
     {
         apply_([](i64 a)->i64 {return -a; });

+ 55 - 0
traph/source/tensor/short_tensor.cpp

@@ -281,11 +281,66 @@ namespace traph
 		return matmul_impl(*this, *right_matrix);
 	}
 
+    TensorInterfacePtr Tensor<i16>::mean() const
+    {
+        DimVector d(1);
+        d[0] = 1;
+
+        TensorPtr<i16> result(new Tensor<i16>(d));
+        auto flat_size = _dimensions.flat_size();
+        result->_rep->data[0] = reduce_([](i16 a, i16 b)->i16 {return a + b; });
+        result->_rep->data[0] /= flat_size;
+        return std::dynamic_pointer_cast<TensorInterface>(result);
+    }
+
     void Tensor<i16>::mul_(i16 value)
     {
         apply_([value](i16 a)->i16 {return a*value; });
     }
 
+    void Tensor<i16>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        // check tensor other type
+        if(other->dtype() != DataType::SHORT)
+            throw std::runtime_error("expected type short tensor");
+		// check broadcast.shape = this.shape
+        auto shape = broadcast_shape(this->size(), other->size());
+        if(shape != this->size())
+            throw std::runtime_error("The size of tensor a must match the size of tensor b");
+		// ok, get lhs, rhs
+		Tensor<i16> * lhs = this;
+		Tensor<i16> * rhs = dynamic_cast<Tensor<i16> *>(other.get());
+		std::function<void(idx_type, idx_type, idx_type, idx_type)> mul_impl =
+			[&](idx_type lhs_dim, idx_type rhs_dim, idx_type lhs_idx, idx_type rhs_idx) {
+
+			auto lhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(lhs->storage())->data_ptr();
+			auto rhs_storage = std::dynamic_pointer_cast<TensorStorage<f32>>(rhs->storage())->data_ptr();
+
+			idx_type lsh_shape_size = lhs_dim >= -(lhs->size().size())? lhs->size(lhs_dim) : 1;
+			idx_type rsh_shape_size = rhs_dim >= -(rhs->size().size()) ? rhs->size(rhs_dim) : 1;
+			idx_type max_shape_size = std::max(lsh_shape_size, rsh_shape_size);
+
+			for (idx_type i = 0; i < max_shape_size; ++i)
+			{
+                if (lhs_dim <= -(lhs->size().size()) && rhs_dim <= -(rhs->size().size()))
+                {
+                    lhs_storage[lhs_idx] *= rhs_storage[rhs_idx];
+                }
+                else
+                {
+                    mul_impl(lhs_dim - 1, rhs_dim - 1, lhs_idx, rhs_idx);
+                }
+
+				if(lsh_shape_size > 1)
+					lhs_idx += lhs->stride(lhs_dim);
+				if (rsh_shape_size > 1)
+					rhs_idx += rhs->stride(rhs_dim);
+			}
+		};
+
+		mul_impl(-1, -1, lhs->offset(), rhs->offset());
+    }
+
     void Tensor<i16>::neg_()
     {
         apply_([](i16 a)->i16 {return -a; });

+ 12 - 0
traph/source/tensor/tensor.cpp

@@ -121,12 +121,24 @@ namespace traph
 		throw std::runtime_error("No implement");
     }
 
+    template<typename T>
+    TensorInterfacePtr Tensor<T>::mean() const
+    {
+        throw std::runtime_error("No implement");
+    }
+
     template<typename T>
     void Tensor<T>::mul_(T value)
     {
         throw std::runtime_error("No implement");
     }
 
+    template<typename T>
+    void Tensor<T>::mul_(std::shared_ptr<TensorInterface> other)
+    {
+        throw std::runtime_error("No implement");
+    }
+
     template<typename T>
     void Tensor<T>::neg_()
     {

+ 33 - 7
traph/source/test/main.cpp

@@ -8,6 +8,31 @@
 
 #include <iostream>
 
+using namespace traph;
+
+class MyModel : public Module
+{
+private:
+	std::shared_ptr<Linear> linear1;
+	// std::shared_ptr<Linear> linear2;
+	// std::shared_ptr<Linear> linear3;
+public:
+
+	MyModel()
+		:linear1(new Linear(1024, 512, false))
+		// linear2(new Linear(512, 256, false)),
+		// linear3(new Linear(256, 128, false))
+	{
+		add_module("linear1", linear1);
+		// add_module("linear2", linear2);
+		// add_module("linear3", linear3);
+	}
+	std::shared_ptr<VariableInterface> forward(std::shared_ptr<VariableInterface> input)
+	{
+		return linear1->forward(input);
+	}
+};
+
 int main()
 {
 	/*
@@ -62,13 +87,13 @@ int main()
 
 	int batch_size = 16;
 	
-	auto x = traph::ones<traph::f32>({ batch_size,4 });
-	auto y = traph::ones<traph::f32>({ batch_size,2 });
+	auto x = traph::ones<traph::f32>({ batch_size,1024 });
+	auto y = traph::ones<traph::f32>({ batch_size,512 });
 
-	traph::Linear linear_model(4, 2, false);
-	traph::MSELoss criterion;
-	traph::SGD optimizer(linear_model.parameters(), 0.001f);
-	std::cout << y->data()->to_string() << std::endl;
+	MyModel model;
+	MSELoss criterion;
+	traph::SGD optimizer(model.parameters(), 0.01f);
+	// std::cout << y->data()->to_string() << std::endl;
 
 	std::cout << "Start Training..." << std::endl;
 
@@ -77,7 +102,7 @@ int main()
 		float loss100 = 0.f;
 
 		optimizer.zero_grad();
-		auto out = linear_model.forward(x);
+		auto out = model.forward(x);
 		auto loss = criterion.forward(out, y);
 		loss->backward();
 		optimizer.step();
@@ -86,6 +111,7 @@ int main()
 		std::cout << loss->data()->to_string() << std::endl;
 	}
 	
+
 	//auto a = traph::ones<traph::f32>({ 2,3 });
 	//a->requires_grad_(true);
 	//auto b = traph::ones<traph::f32>({ 3,4 });