|
|
@@ -144,6 +144,66 @@ namespace traph
|
|
|
Tensor<f32> matmul(const Tensor<f32> &a, const Tensor<f32> &b);
|
|
|
|
|
|
Tensor<f64> matmul(const Tensor<f64> &a, const Tensor<f64> &b);
|
|
|
+
|
|
|
+ template<class T>
|
|
|
+ Tensor<T> mean(const Tensor<T> &input)
|
|
|
+ {
|
|
|
+ T result{};
|
|
|
+ idx_type flat_size = input.size().flat_size();
|
|
|
+ idx_type offset = input.offset();
|
|
|
+ const T *input_data = input.data();
|
|
|
+ for (idx_type i = offset; i < offset + flat_size; ++i)
|
|
|
+ {
|
|
|
+ result += input_data[i];
|
|
|
+ }
|
|
|
+
|
|
|
+ return result / flat_size;
|
|
|
+ }
|
|
|
+
|
|
|
+ template<class T>
|
|
|
+ Tensor<T> mul(const Tensor<T> &input, T value)
|
|
|
+ {
|
|
|
+ Tensor<T> result(input.size());
|
|
|
+ idx_type flat_size = input.size().flat_size();
|
|
|
+ idx_type offset = input.offset();
|
|
|
+ const T *input_data = input.data();
|
|
|
+ T *result_data = result.data();
|
|
|
+ for (idx_type i = 0; i < flat_size; ++i)
|
|
|
+ {
|
|
|
+ result_data[i] += input_data[i + offset] + value;
|
|
|
+ }
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ template<class T>
|
|
|
+ void mul_check(const Tensor<T> &input, const Tensor<T> & other)
|
|
|
+ {
|
|
|
+ if(!strict_same_shape(input, other))
|
|
|
+ throw std::runtime_error("mul: Two tensor must have the same shape or be broadcastable.");
|
|
|
+ }
|
|
|
+
|
|
|
+ template<class T>
|
|
|
+ Tensor<T> mul(const Tensor<T> &input, const Tensor<T> & other)
|
|
|
+ {
|
|
|
+ // check
|
|
|
+ mul_check(input, other);
|
|
|
+
|
|
|
+ Tensor<T> result(input.size());
|
|
|
+ idx_type flat_size = input.size().flat_size();
|
|
|
+ idx_type input_offset = input.offset();
|
|
|
+ idx_type other_offset = other.offset();
|
|
|
+ const T *input_data = input.data();
|
|
|
+ const T *other_data = other.data();
|
|
|
+ T *result_data = result.data();
|
|
|
+
|
|
|
+ for (idx_type i = 0; i < flat_size; ++i)
|
|
|
+ {
|
|
|
+ result_data[i] += input_data[i + input_offset] + other_data[i + other_offset];
|
|
|
+ }
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#endif
|