tensor.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import pytraph.core.dtype
  2. import pytraph.core.traph_tensor
  class Storage(object):
      """Placeholder for tensor storage; it declares no attributes or methods yet."""
  class Tensor(object):
      """Python-side wrapper that holds an optional native tensor object.

      The wrapped object lives in ``self._inner_tensor``; the base class leaves
      it unset (``None``) and subclasses or factory functions attach one.
      """

      def __init__(self):
          # Nothing is attached yet; callers assign the native tensor later.
          self._inner_tensor = None

      def __str__(self):
          """Return the native tensor's ``to_string()`` text, or "None" if unset."""
          inner = self._inner_tensor
          if inner is None:
              return "None"
          return inner.to_string()
  class FloatTensor(Tensor):
      """Tensor wrapper whose inner object is a ``pytraph.core.traph_tensor.FloatTensor``."""

      def __init__(self):
          # Run the base initialiser first so any state Tensor sets up is
          # present, then attach a default-constructed (no-argument) native
          # float tensor. The two-argument super() form keeps Python 2
          # compatibility, matching the file's explicit ``(object)`` bases.
          super(FloatTensor, self).__init__()
          self._inner_tensor = pytraph.core.traph_tensor.FloatTensor()
  def tensor(obj):
      """Placeholder factory intended to build a tensor from ``obj``.

      Only the input-type check exists so far. A plain ``list`` (exact type)
      is accepted but not converted. Any other object causes the message
      'unsupported obj type' to be printed to stdout. The function returns
      ``None`` in every case.
      """
      # TODO: convert list input into a tensor; currently unimplemented.
      if type(obj) is not list:
          print('unsupported obj type')
      return None
  def _dim_vector(shape):
      """Validate ``shape`` and copy its entries into a native ``DimVector``.

      Args:
          shape: tuple (or tuple subclass, e.g. a namedtuple) of dimension sizes.

      Returns:
          A ``pytraph.core.traph_tensor.DimVector`` with one ``push_back`` per
          entry of ``shape``, in order.

      Raises:
          RuntimeError: if ``shape`` is not a tuple.
      """
      # RuntimeError (rather than TypeError) is kept so existing callers that
      # catch it keep working; isinstance also admits tuple subclasses.
      if not isinstance(shape, tuple):
          raise RuntimeError('The type of shape shall be tuple.')
      dim = pytraph.core.traph_tensor.DimVector()
      for extent in shape:
          dim.push_back(extent)
      return dim


  def _filled(shape, value):
      """Return a new ``FloatTensor`` of ``shape`` initialised via native ``fill_(value)``.

      Raises:
          RuntimeError: if ``shape`` is not a tuple.
      """
      dim = _dim_vector(shape)
      ret = FloatTensor()
      # Replace the default-constructed inner tensor with one sized by ``dim``.
      ret._inner_tensor = pytraph.core.traph_tensor.FloatTensor(dim)
      ret._inner_tensor.fill_(value)
      return ret


  def zeros(shape):
      """Return a new ``FloatTensor`` of ``shape`` filled with 0.

      Args:
          shape: tuple of dimension sizes.

      Raises:
          RuntimeError: if ``shape`` is not a tuple.
      """
      return _filled(shape, 0)


  def ones(shape):
      """Return a new ``FloatTensor`` of ``shape`` filled with 1.

      Args:
          shape: tuple of dimension sizes.

      Raises:
          RuntimeError: if ``shape`` is not a tuple.
      """
      return _filled(shape, 1)