jstzwj 7 лет назад
Родитель
Сommit
e616b2e00a

+ 1 - 1
.vscode/settings.json

@@ -1,5 +1,5 @@
 {
-    "python.pythonPath": "C:\\Users\\jstzw\\AppData\\Local\\Programs\\Python\\Python36\\python.exe",
+    "python.pythonPath": "C:\\ProgramData\\Anaconda3\\python.exe",
     "files.associations": {
         "algorithm": "cpp",
         "array": "cpp",

+ 23 - 1
traph/include/traph/core/tensor.h

@@ -32,8 +32,30 @@ namespace traph
         virtual idx_type size() const = 0;
     };
 
+    class TensorInterface
+    {
+    public:
+        using TensorInterfacePtr = std::shared_ptr<TensorInterface>;
+        using TensorInterfaceRef = TensorInterface&;
+        using TensorInterfaceConstRef = const TensorInterface&;
+
+    public:
+        virtual device_id device() = 0;
+        virtual idx_type offset() const = 0;
+		virtual layout_type order() const = 0;
+        virtual platform_type platform() = 0;
+        virtual void reshape_(const DimVector& dims) = 0;
+        virtual void resize_(const DimVector& dims) = 0;
+		virtual DimVector size() const = 0;
+		virtual DimVector stride() const = 0;
+    };
+
+    using TensorInterfacePtr = std::shared_ptr<TensorInterface>;
+    using TensorInterfaceRef = TensorInterface&;
+    using TensorInterfaceConstRef = const TensorInterface&;
+
     template<class T>
-    class TensorBase
+    class TensorBase: public TensorInterface
     {
     public:
         using TensorBasePtr = std::shared_ptr<TensorBase<T>>;

+ 24 - 0
traph/include/traph/core/variable.h

@@ -8,6 +8,30 @@
 
 namespace traph
 {
+    class VariableInterface
+    {
+    public:
+        using VariableInterfacePtr = std::shared_ptr<VariableInterface>;
+        using VariableInterfaceRef = VariableInterface&;
+        using VariableInterfaceConstRef = const VariableInterface&;
+
+    public:
+        virtual void backward() = 0;
+        virtual device_id device() = 0;
+        virtual idx_type offset() const = 0;
+		virtual layout_type order() const = 0;
+        virtual platform_type platform() = 0;
+        virtual void requires_grad_(bool requires_grad) = 0;
+        virtual void reshape_(const DimVector& dims) = 0;
+        virtual void resize_(const DimVector& dims) = 0;
+		virtual DimVector size() const = 0;
+		virtual DimVector stride() const = 0;
+    };
+
+    using VariableInterfacePtr = std::shared_ptr<VariableInterface>;
+    using VariableInterfaceRef = VariableInterface&;
+    using VariableInterfaceConstRef = const VariableInterface&;
+
     template<class T>
     class VariableBase
     {

+ 55 - 2
traph/include/traph/nn/operation.h

@@ -4,17 +4,70 @@
 #include <utility>
 #include <cmath>
 #include <string>
+#include <vector>
+#include <memory>
+#include <cassert>
 
 #include <traph/core/type.h>
 #include <traph/core/index.h>
 #include <traph/core/utils.h>
 #include <traph/core/variable.h>
-#include <traph/tensor/tensor.h>
 #include <traph/nn/variable.h>
+#include <traph/core/tensor.h>
+#include <traph/tensor/tensor.h>
+#include <traph/nn/graph.h>
 
 namespace traph
 {
-    
+    class OpContext
+    {
+    private:
+        std::vector<TensorInterfacePtr> _saved_tensors;
+    public:
+        void save(TensorInterfacePtr tensor)
+        {
+            _saved_tensors.push_back(tensor);
+        }
+
+        std::vector<TensorInterfacePtr> get_saved_tensors() const
+        {
+            return _saved_tensors;
+        }
+    };
+
+    class OpBase
+    {
+    public:
+        OpContext context;
+    };
+
+    template<class T>
+    class OpInterface: public OpBase
+    {
+    public:
+        virtual TensorBasePtr<T> forward(std::vector<TensorBasePtr<T>> inputs) = 0;
+        virtual std::vector<TensorBasePtr<T>> backward(TensorBasePtr<T> output_grad) = 0;
+    };
+
+    template<class T>
+    class SumOp: public OpInterface<T>
+    {
+    public:
+        virtual TensorBasePtr<T> forward(std::vector<TensorBasePtr<T>> inputs) override
+        {
+            assert(inputs.size() == 1);
+            
+            TensorBasePtr<T> input = inputs[0];
+            TensorBasePtr<T> result = input->sum();
+
+			return result;
+        }
+
+        virtual std::vector<TensorBasePtr<T>> backward(TensorBasePtr<T> output_grad) override
+        {
+            return {output_grad};
+        }
+    };
 }
 
 #endif

+ 5 - 2
traph/include/traph/nn/variable.h

@@ -4,11 +4,13 @@
 #include <memory>
 #include <functional>
 #include <initializer_list>
+#include <vector>
 
 #include <traph/core/index.h>
 #include <traph/core/tensor.h>
 #include <traph/core/variable.h>
 #include <traph/tensor/tensor.h>
+#include <traph/nn/operation.h>
 
 namespace traph
 {
@@ -32,11 +34,12 @@ namespace traph
         std::shared_ptr<TensorBase<T>> _grad;
         bool _requires_grad;
         bool _leaf;
-        std::function<TensorBasePtr<T>(TensorBasePtr<T>)> _grad_fn;
+        std::shared_ptr<OpInterface<T>> _grad_fn;
+        std::vector<VariableInterface> _inputs;
     public:
         Variable()
             :_data(new Tensor<T>), _grad(nullptr),
-            _requires_grad(false), _leaf(is_leaf),
+            _requires_grad(false), _leaf(false),
             _grad_fn(nullptr)
         {