Browse Source

add slice

JasonWang 7 years ago
parent
commit
e7f06cb0f1

+ 21 - 0
traph/include/traph/core/arithmetic.h

@@ -0,0 +1,21 @@
+#ifndef TRAPH_ARITHMETIC_H_
+#define TRAPH_ARITHMETIC_H_
+
+#include <utility>
+
+#include <traph/core/type.h>
+#include <traph/core/index.h>
+#include <traph/core/utils.h>
+#include <traph/core/tensor.h>
+
+namespace traph
+{
+    template<class T>
+    Tensor<T> add(const Tensor<T> &t, T v)
+    {
+        
+    }
+
+}
+
+#endif

+ 17 - 0
traph/include/traph/core/autograd.h

@@ -0,0 +1,17 @@
+#ifndef TRAPH_AUTOGRAD_H_
+#define TRAPH_AUTOGRAD_H_
+
+#include <vector>
+
+namespace traph
+{
+    class Graph
+    {
+    private:
+    public:
+
+    };
+}
+
+
+#endif

+ 45 - 0
traph/include/traph/core/slice.h

@@ -0,0 +1,45 @@
+#ifndef TRAPH_SLICE_H_
+#define TRAPH_SLICE_H_
+
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include <traph/core/type.h>
+#include <traph/core/index.h>
+#include <traph/core/utils.h>
+#include <traph/core/tensor.h>
+
+namespace traph
+{
+    class BasicSlice
+    {
+    public:
+        idx_type start;
+        idx_type step;
+        idx_type end;
+    };
+
+    class AdvancedSlice
+    {
+    public:
+        std::vector<idx_type> indices;
+    };
+
+    enum SliceMode
+    {
+        BASIC,
+        ADVANCED
+    };
+
+    class Slice
+    {
+    public:
+        std::variant<BasicSlice, AdvancedSlice> slice;
+        SliceMode mode;
+    };
+
+    using SliceVector = std::vector<Slice>;
+}
+
+#endif

+ 38 - 25
traph/include/traph/core/tensor.h

@@ -140,6 +140,8 @@ namespace traph
         idx_type offset;
 		DimVector strides;
         layout_type order;
+
+        bool requires_grad;
     private:
         void auto_strides()
         {
@@ -174,41 +176,41 @@ namespace traph
     public:
         Tensor()
             :rep(new TensorStorage<T>),
-            dimensions(), offset(0), strides(), order(layout_type::column_major)
+            dimensions(), offset(0), strides(), order(layout_type::column_major), requires_grad(false)
         {
         }
 
         explicit Tensor(const DimVector& dimensions)
             :rep(new TensorStorage<T>),
-            dimensions(dimensions), offset(0), strides(), order(layout_type::column_major)
+            dimensions(dimensions), offset(0), strides(), order(layout_type::column_major), requires_grad(false)
         {
             auto_strides();
         }
 
         explicit Tensor(const DimVector& dimensions, layout_type order)
             :rep(new TensorStorage<T>),
-            dimensions(dimensions), offset(0), strides(), order(order)
+            dimensions(dimensions), offset(0), strides(), order(order), requires_grad(false)
         {
             auto_strides();
         }
 
         explicit Tensor(const DimVector& dimensions, const DimVector& strides)
             :rep(new TensorStorage<T>),
-            dimensions(dimensions), offset(0), strides(strides), order(layout_type::column_major)
+            dimensions(dimensions), offset(0), strides(strides), order(layout_type::column_major), requires_grad(false)
         {
             auto_strides();
         }
 
         explicit Tensor(const DimVector& dimensions, const DimVector& strides, layout_type order)
             :rep(new TensorStorage<T>),
-            dimensions(dimensions), offset(0), strides(strides), order(order)
+            dimensions(dimensions), offset(0), strides(strides), order(order), requires_grad(false)
         {
             auto_strides();
         }
 
         Tensor(const T& t)
             :rep(new TensorStorage<T>),
-            dimensions(), offset(0), strides(), order(order)
+            dimensions(), offset(0), strides(), order(order), requires_grad(false)
         {
             dimensions.resize(1);
             auto_strides();
@@ -219,7 +221,8 @@ namespace traph
             dimensions(other.dimensions),
             offset(other.offset),
             strides(other.strides),
-            order(other.order)
+            order(other.order),
+            requires_grad(other.requires_grad)
         {
         }
 
@@ -228,7 +231,8 @@ namespace traph
             dimensions(other.dimensions),
             offset(other.offset),
             strides(other.strides),
-            order(other.order)
+            order(other.order),
+            requires_grad(other.requires_grad)
         {
         }
 
@@ -244,6 +248,32 @@ namespace traph
             return result;
         }
         // op
+        void add_(T value)
+        {
+            idx_type i = offset;
+            for(idx_type dim = 0;dim < dimensions.size();++dim)
+            {
+                for(idx_type step = 0; step < dimension[dim];++step)
+                {
+                    rep->data[i] = rep->data[i] + value;
+                    i += strides[dim];
+                }
+            }
+        }
+
+        void fill_(T value)
+        {
+            idx_type i = offset;
+            for(idx_type dim = 0;dim < dimensions.size();++dim)
+            {
+                for(idx_type step = 0; step < dimension[dim];++step)
+                {
+                    rep->data[i] = value;
+                    i += strides[dim];
+                }
+            }
+        }
+
         void abs_()
         {
             idx_type i = offset;
@@ -283,22 +313,5 @@ namespace traph
 
             return rep->data[pos];
         }
-
-        Tensor& operator ()(idx_type idx)
-        {
-            
-        }
-
-        template<class... Args>
-        Tensor& operator ()(idx_type idx, Args... args)
-        {
-            
-        }
-
-        template<class... Args>
-        const Tensor& operator ()(idx_type idx, Args... args) const
-        {
-            
-        }
     };
 }

+ 16 - 0
traph/include/traph/core/view.h

@@ -0,0 +1,16 @@
+#ifndef TRAPH_VIEW_H_
+#define TRAPH_VIEW_H_
+
+#include <utility>
+
+#include <traph/core/type.h>
+#include <traph/core/index.h>
+#include <traph/core/utils.h>
+#include <traph/core/tensor.h>
+
+namespace traph
+{
+    
+}
+
+#endif

+ 4 - 1
traph/source/test/main.cpp

@@ -2,6 +2,9 @@
 
 int main()
 {
-    traph::Tensor<float> t;
+    traph::Tensor<float> a = traph::zeros({2, 2});
+    traph::Tensor<float> w = traph::ones({3, 2});
+    traph::Tensor<float> result = traph::matmul(w, a);
+    result.backward();
     return 0;
 }