浏览代码

add idx_type support

JasonWang 7 年之前
父节点
当前提交
0a095af46c
共有 2 个文件被更改,包括 64 次插入34 次删除
  1. 11 34
      traph/include/traph/core/tensor.h
  2. 53 0
      traph/source/interface/traph_tensor.i

+ 11 - 34
traph/include/traph/core/tensor.h

@@ -131,46 +131,23 @@ namespace traph
             auto_strides();
         }
 
-        Tensor(nested_initializer_list_t<T, 1> t)
-            :rep(new TensorStorage<T>),
-            dimensions(), offset(0), strides(), order(order)
-        {
-            dimensions.resize(1);
-            auto_strides();
-        }
-        Tensor(nested_initializer_list_t<T, 2> t)
-            :rep(new TensorStorage<T>),
-            dimensions(), offset(0), strides(), order(order)
+        void reshape(const DimVector& dims)
         {
-            dimensions.resize(2);
-            auto_strides();
-        }
 
-        Tensor(nested_initializer_list_t<T, 3> t)
-            :rep(new TensorStorage<T>),
-            dimensions(), offset(0), strides(), order(order)
-        {
-            dimensions.resize(3);
-            auto_strides();
-        }
-        Tensor(nested_initializer_list_t<T, 4> t)
-            :rep(new TensorStorage<T>),
-            dimensions(), offset(0), strides(), order(order)
-        {
-            dimensions.resize(4);
-            auto_strides();
-        }
-        Tensor(nested_initializer_list_t<T, 5> t)
-            :rep(new TensorStorage<T>),
-            dimensions(), offset(0), strides(), order(order)
-        {
-            dimensions.resize(5);
-            auto_strides();
         }
 
-        void reshape(std::initializer_list<idx_type> t)
+        T& index(const DimVector& dims)
         {
+            idx_type pos = 0;
+
+            for(idx_type i = 0; i < dimensions.size(); ++i)
+            {
+                pos += dimensions[i] * strides[i];
+            }
+
+            pos += offset;
 
+            return rep->data[pos];
         }
 
         Tensor& operator ()(idx_type idx)

+ 53 - 0
traph/source/interface/traph_tensor.i

@@ -2,9 +2,59 @@
 %{
     #include<traph/core/type.h>
     #include<traph/core/tensor.h>
+    #include<traph/core/index.h>
     using namespace traph;
 %}
 
+typedef float f32;
+typedef double f64;
+typedef std::int8_t i8;
+typedef std::int16_t i16;
+typedef std::int32_t i32;
+typedef std::int64_t i64;
+typedef std::uint8_t u8;
+typedef std::uint16_t u16;
+typedef std::uint32_t u32;
+typedef std::uint64_t u64;
+typedef i32 idx_type;
+typedef i32 size_type;
+
+%typemap(in) idx_type {
+  $1 = PyInt_AsLong($input);
+}
+
+%typemap(out) idx_type {
+  $result = PyInt_FromLong($1);
+}
+
+%typemap(in) size_type {
+  $1 = PyInt_AsLong($input);
+}
+
+%typemap(out) size_type {
+  $result = PyInt_FromLong($1);
+}
+
+class DimVector
+{
+private:
+    std::unique_ptr<idx_type[]> data;
+    idx_type stack_data[DIMVECTOR_SMALL_VECTOR_OPTIMIZATION];
+    idx_type dim_num;
+public:
+    DimVector();
+    DimVector(idx_type size);
+    DimVector(const DimVector& other);
+    DimVector(DimVector&& other);
+    DimVector& operator=(const DimVector& other) noexcept;
+    DimVector& operator=(DimVector&& other) noexcept;
+    void push_back(idx_type idx);
+    void resize(idx_type size);
+    idx_type size() const;
+    idx_type& operator[](idx_type dim);
+    idx_type operator[](idx_type dim) const;
+};
+
 template<class T>
 class TensorStorage
 {
@@ -32,6 +82,9 @@ public:
     Tensor(const DimVector& dimensions, layout_type order);
     Tensor(const DimVector& dimensions, const DimVector& strides);
     Tensor(const DimVector& dimensions, const DimVector& strides, layout_type order);
+
+    void reshape(const DimVector& dims);
+    T& index(const DimVector& dims);
 };
 
 %template(tensor_f32) Tensor<f32>;