JasonWang 6 роки тому
батько
коміт
80ddef0f09

+ 24 - 0
traph/include/traph/core/operation.h

@@ -80,6 +80,30 @@ namespace traph
 		}
 	};
 
+	class SelectOp : public OpBase
+	{
+	public:
+		SliceVector slice;
+		void set_slice(const SliceVector& s)
+		{
+			slice = s;
+		}
+		virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
+		{
+			assert(inputs.size() == 1);
+
+			TensorInterfacePtr input = inputs[0];
+			
+			return input->select(slice);
+		}
+
+		virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
+		{
+			TensorBasePtr<f32> result = std::dynamic_pointer_cast<TensorBase<f32>>(output_grad->select(slice));
+			return { result };
+		}
+	};
+
 	class SinOp : public OpBase
 	{
 	public:

+ 5 - 0
traph/include/traph/core/slice.h

@@ -37,6 +37,11 @@ namespace traph
         idx_type start;
         idx_type step;
         idx_type end;
+
+		Slice(idx_type start, idx_type step, idx_type end)
+			:start(start), step(step), end(end)
+		{
+		}
     };
 
     using SliceVector = std::vector<Slice>;

+ 29 - 0
traph/include/traph/nn/arithmetic.h

@@ -66,6 +66,35 @@ namespace traph
 		return result;
 	}
 
+	
+	template<class T>
+	VariablePtr<T> select(VariablePtr<T> input, const SliceVector& slice)
+	{
+		VariablePtr<T> result(new Variable<T>);
+		std::shared_ptr<SelectOp> op(new SelectOp);
+		op->set_slice(slice);
+
+		std::vector<VariableInterfacePtr> result_inputs{ input };
+		result->_data = std::dynamic_pointer_cast<TensorBase<T>>(op->forward({ input->_data }));
+		result->_leaf = false;
+
+		if (input->requires_grad())
+		{
+			result->_grad = result->_data->create_grad();
+			result->_grad->fill_(0);
+			result->_requires_grad = true;
+			result->_grad_fn = op;
+			result->_inputs = result_inputs;
+		}
+		else
+		{
+			result->_requires_grad = false;
+		}
+
+		return result;
+	}
+
+
 	template<class T>
 	VariablePtr<T> sin(VariablePtr<T> input)
 	{

+ 5 - 2
traph/include/traph/nn/variable.h

@@ -52,6 +52,9 @@ namespace traph
 		template<class T>
 		friend std::shared_ptr<Variable<T>> add(std::shared_ptr<Variable<T>> left, std::shared_ptr<Variable<T>> right);
 
+		template<class T>
+		friend std::shared_ptr<Variable<T>> select(std::shared_ptr<Variable<T>> input, const SliceVector& slice);
+
 		template<class T>
 		friend std::shared_ptr<Variable<T>> sin(std::shared_ptr<Variable<T>> input);
 
@@ -310,7 +313,7 @@ namespace traph
 		for (auto i : l)
 			dim.push_back(i);
 
-        std::shared_ptr<Variable<T>> result(new Variable<T>(dim, true));
+        std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
         result->fill_(0);
 
         return result;
@@ -323,7 +326,7 @@ namespace traph
 		for (auto i : l)
 			dim.push_back(i);
 
-        std::shared_ptr<Variable<T>> result(new Variable<T>(dim, true));
+        std::shared_ptr<Variable<T>> result(new Variable<T>(dim, false));
         result->fill_(1);
 
         return result;

+ 0 - 3
traph/include/traph/tensor/tensor.h

@@ -405,9 +405,6 @@ namespace traph
     {
         std::shared_ptr<Tensor<T>> result(new Tensor<T>);
         result->_rep = _rep;
-        auto compute_offset = [](){
-
-        };
 
         // dimension
         DimVector dim;

+ 7 - 2
traph/source/test/main.cpp

@@ -35,7 +35,7 @@ int main()
 	std::cout << b;
 	*/
 	// auto a = traph::Variable<traph::f32>({ 2, 3 });
-
+/*
 	auto a = traph::ones<traph::f32>({ 2,3,2 });
 	a->requires_grad_(true);
 	auto b = traph::sin<traph::f32>(a);
@@ -47,6 +47,11 @@ int main()
 	e->backward();
 
 	std::cout << a->grad()->to_string();
-
+*/
+	auto a = traph::ones<traph::f32>({ 2,3 });
+	traph::SliceVector slice;
+	slice.push_back(traph::Slice(0, 1, 1));
+	slice.push_back(traph::Slice(0, 1, 2));
+	auto b = traph::select(a, slice);
     return 0;
 }