JasonWang 6 anni fa
parent
commit
0454eaa6b2

+ 1 - 1
.vscode/settings.json

@@ -1,5 +1,5 @@
 {
-    "python.pythonPath": "C:\\Users\\wangjun\\AppData\\Local\\Programs\\Python\\Python37\\python.exe",
+    "python.pythonPath": "C:\\Users\\jstzw\\AppData\\Local\\Programs\\Python\\Python37\\python.exe",
     "files.associations": {
         "algorithm": "cpp",
         "array": "cpp",

+ 5 - 0
CMakeLists.txt

@@ -16,6 +16,11 @@ SET(TRAPH_PATH_HEADER ${TRAPH_PATH_INCLUDE}/traph CACHE STRING "Adds a path to T
 SET(TRAPH_PATH_SOURCE ${TRAPH_PATH}/traph/source CACHE STRING "Adds a path to TRAPH source" FORCE)
 SET(TRAPH_PATH_DEPENDENCIES ${TRAPH_PATH}/traph/contrib CACHE STRING "Adds a path to TRAPH dependencies" FORCE)
 
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+
 find_package(Boost)
 if(Boost_FOUND)
 	message(STATUS Boost found: ${Boost_INCLUDE_DIRS})

+ 18 - 0
python/pytraph/core/tensor.py

@@ -15,6 +15,24 @@ class Tensor(object):
         else:
             return "None"
 
+    def __getitem__(self, given):
+        slice_vector = pytraph.core.traph_tensor.SliceVector()
+        if isinstance(given, slice):
+            slice_vector.push_back(pytraph.core.traph_tensor.Slice(given.start, given.step, given.stop))
+        elif isinstance(given, tuple):
+            for each_slice in given:
+                if isinstance(given, slice):
+                    slice_vector.push_back(pytraph.core.traph_tensor.Slice(each_slice.start, each_slice.step, each_slice.stop))
+                else:
+                    slice_vector.push_back(pytraph.core.traph_tensor.Slice(each_slice, 1, each_slice+1))
+        else:
+            slice_vector.push_back(pytraph.core.traph_tensor.Slice(given, 1, given+1))
+
+        return self._inner_tensor.select(slice_vector)
+
+    def __setitem__(self,key,value):
+        self.dict[key] = value
+
 class FloatTensor(Tensor):
     def __init__(self):
         self._inner_tensor = pytraph.core.traph_tensor.FloatTensor()

+ 19 - 8
traph/include/traph/core/slice.h

@@ -2,6 +2,7 @@
 #define TRAPH_TENSOR_SLICE_H_
 
 #include <vector>
+#include <optional>
 
 #include <traph/core/type.h>
 #include <traph/core/index.h>
@@ -12,9 +13,9 @@ namespace traph
     class BasicSlice
     {
     public:
-        idx_type start;
-        idx_type step;
-        idx_type end;
+        std::optional<idx_type> start;
+		std::optional<idx_type> step;
+		std::optional<idx_type> end;
     };
 
     /*
@@ -34,12 +35,22 @@ namespace traph
     class Slice
     {
     public:
-        idx_type start;
-        idx_type step;
-        idx_type end;
+		std::optional<idx_type> start;
+		std::optional<idx_type> step;
+		std::optional<idx_type> end;
 
-		Slice(idx_type start, idx_type step, idx_type end)
-			:start(start), step(step), end(end)
+        Slice()
+			:start(), step(), end()
+		{
+		}
+
+        Slice(idx_type start, idx_type end)
+			:start(start), step(), end(end)
+		{
+		}
+
+		Slice(idx_type start, idx_type end, idx_type step)
+			:start(start), end(end), step(step)
 		{
 		}
     };

+ 13 - 4
traph/source/interface/traph_tensor.i

@@ -6,16 +6,25 @@
 #endif
 %}
 
+%include "std_vector.i"
+
+namespace std {
+  %template(IntVector) vector<int>;
+  %template(DoubleVector) vector<double>;
+};
+
 %{
     #include <string>
-    #include<traph/core/type.h>
-    #include<traph/core/index.h>
-    #include<traph/tensor/tensor.h>
-    #include<traph/tensor/tensor_storage.h>
+    #include <traph/core/type.h>
+    #include <traph/core/index.h>
+    #include <traph/core/slice.h>
+    #include <traph/tensor/tensor.h>
+    #include <traph/tensor/tensor_storage.h>
     
     using namespace traph;
 %}
 
+%template(SliceVector) std::vector<Slice>;
 
 typedef float f32;
 typedef double f64;