jstzwj 7 år sedan
förälder
incheckning
a8cd5f7025

+ 4 - 0
traph/include/traph/core/tensor.h

@@ -2,9 +2,12 @@
 #define TRAPH_CORE_TENSOR_H_
 
 #include <traph/core/type.h>
+#include <traph/core/index.h>
+
 
 namespace traph
 {
+    template<class T>
     class ContiguousStorageBase
     {
     public:
@@ -14,6 +17,7 @@ namespace traph
         virtual void resize_(idx_type size) = 0;
     };
 
+    template<class T>
     class TensorBase
     {
     public:

+ 6 - 1
traph/include/traph/nn/variable.h

@@ -1,11 +1,16 @@
 #ifndef TRAPH_NN_VARIABLE_H_
 #define TRAPH_NN_VARIABLE_H_
 
+#include <traph/core/tensor.h>
+
 namespace traph
 {
+    template<class T>
     class Variable
     {
-        
+    private:
+        TensorBase<T> _data;
+        TensorBase<T> _grad;
     };
 }
 

+ 3 - 2
traph/include/traph/tensor/tensor.h

@@ -8,12 +8,13 @@
 #include<traph/core/type.h>
 #include<traph/core/index.h>
 #include<traph/core/utils.h>
+#include<traph/core/tensor.h>
 
 namespace traph
 {
     // The real representation of all tensors.
     template<typename T>
-    class TensorStorage
+    class TensorStorage: public ContiguousStorageBase<T>
     {
     public:
         using DoubleStorage = TensorStorage<f64>;
@@ -130,7 +131,7 @@ namespace traph
 
     // ndarray
     template<typename T>
-    class Tensor
+    class Tensor: public TensorBase<T>
     {
     private:
         std::unique_ptr<TensorStorage<T>> _rep;