jstzwj 6 éve
szülő
commit
0c949dec70

+ 11 - 2
traph/include/traph/core/operation.h

@@ -93,14 +93,23 @@ namespace traph
 			assert(inputs.size() == 1);
 
 			TensorInterfacePtr input = inputs[0];
+			auto grad = input->create_grad();
+			grad->fill_(0);
+			auto zero_grad = std::dynamic_pointer_cast<TensorInterface>(grad);
+
+			context.save(zero_grad);
 			
 			return input->select(slice);
 		}
 
 		virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
 		{
-			TensorBasePtr<f32> result = std::dynamic_pointer_cast<TensorBase<f32>>(output_grad->select(slice));
-			return { result };
+			auto saved_tensors = context.get_saved_tensors();
+			assert(saved_tensors.size() == 1);
+			auto grad = std::dynamic_pointer_cast<TensorBase<f32>>(saved_tensors[0]);
+			auto selected_grad = std::dynamic_pointer_cast<TensorBase<f32>>(grad->select(slice));
+			selected_grad->add_(output_grad);
+			return { grad };
 		}
 	};
 

+ 7 - 4
traph/include/traph/tensor/tensor.h

@@ -408,9 +408,12 @@ namespace traph
 
         // dimension
         DimVector dim;
-        for(auto& each:slice)
+        for(idx_type i = 0; i<slice.size(); ++i)
         {
-            dim.push_back(std::ceil((each.end - each.start)/(float)each.step));
+			auto& each = slice[i];
+            dim.push_back(
+				std::ceil((each.end.value_or(_dimensions[i]) - each.start.value_or(0))/(float)each.step.value_or(1))
+			);
         }
         result->_dimensions = dim;
 
@@ -418,7 +421,7 @@ namespace traph
         idx_type new_offset =1;
         for(idx_type i = 0; i < slice.size(); ++i)
         {
-            new_offset *= _strides[i] * slice[i].start;
+            new_offset *= _strides[i] * slice[i].start.value_or(0);
         }
         result->_offset = _offset + new_offset;
 
@@ -426,7 +429,7 @@ namespace traph
         DimVector strides;
         for(idx_type i = 0; i < slice.size(); ++i)
         {
-            strides.push_back(_strides[i] * slice[i].step);
+            strides.push_back(_strides[i] * slice[i].step.value_or(1));
         }
         result->_strides = strides;
 

+ 1 - 0
traph/source/test/main.cpp

@@ -53,5 +53,6 @@ int main()
 	slice.push_back(traph::Slice(0, 1, 1));
 	slice.push_back(traph::Slice(0, 1, 2));
 	auto b = traph::select(a, slice);
+	std::cout << b->data()->to_string();
     return 0;
 }