|
|
@@ -40,25 +40,6 @@ namespace traph
|
|
|
virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) = 0;
|
|
|
};
|
|
|
|
|
|
- class SumOp: public OpBase
|
|
|
- {
|
|
|
- public:
|
|
|
- virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
|
|
|
- {
|
|
|
- assert(inputs.size() == 1);
|
|
|
-
|
|
|
- TensorInterfacePtr input = inputs[0];
|
|
|
- TensorInterfacePtr result = input->sum();
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
-
|
|
|
- virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
|
|
|
- {
|
|
|
- return {output_grad};
|
|
|
- }
|
|
|
- };
|
|
|
-
|
|
|
class AddOp : public OpBase
|
|
|
{
|
|
|
public:
|
|
|
@@ -191,6 +172,7 @@ namespace traph
|
|
|
|
|
|
virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
|
|
|
{
|
|
|
+ // fixme: bug
|
|
|
TensorBasePtr<f32> result = std::dynamic_pointer_cast<TensorBase<f32>>(output_grad->clone());
|
|
|
result->cos_();
|
|
|
return { result };
|
|
|
@@ -221,6 +203,25 @@ namespace traph
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+ class SumOp: public OpBase
|
|
|
+ {
|
|
|
+ public:
|
|
|
+ virtual TensorInterfacePtr forward(std::vector<TensorInterfacePtr> inputs) override
|
|
|
+ {
|
|
|
+ assert(inputs.size() == 1);
|
|
|
+
|
|
|
+ TensorInterfacePtr input = inputs[0];
|
|
|
+ TensorInterfacePtr result = input->sum();
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ virtual std::vector<TensorBasePtr<f32>> backward(TensorBasePtr<f32> output_grad) override
|
|
|
+ {
|
|
|
+ return {output_grad};
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
class TransposeOp : public OpBase
|
|
|
{
|
|
|
private:
|