tensor.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import pytraph.core.dtype
  2. import pytraph.core.traph_tensor
  3. class Storage(object):
  4. pass
  5. class Tensor(object):
  6. def __init__(self):
  7. self._inner_tensor = None
  8. def __str__(self):
  9. if self._inner_tensor is not None:
  10. return self._inner_tensor.to_string()
  11. else:
  12. return "None"
  13. def __getitem__(self, given):
  14. slice_vector = pytraph.core.traph_tensor.SliceVector()
  15. if isinstance(given, slice):
  16. slice_vector.push_back(pytraph.core.traph_tensor.Slice(given.start, given.step, given.stop))
  17. elif isinstance(given, tuple):
  18. for each_slice in given:
  19. if isinstance(given, slice):
  20. slice_vector.push_back(pytraph.core.traph_tensor.Slice(each_slice.start, each_slice.step, each_slice.stop))
  21. else:
  22. slice_vector.push_back(pytraph.core.traph_tensor.Slice(each_slice, 1, each_slice+1))
  23. else:
  24. slice_vector.push_back(pytraph.core.traph_tensor.Slice(given, 1, given+1))
  25. return self._inner_tensor.select(slice_vector)
  26. def __setitem__(self,key,value):
  27. self.dict[key] = value
  28. class FloatTensor(Tensor):
  29. def __init__(self):
  30. self._inner_tensor = pytraph.core.traph_tensor.FloatTensor()
  31. def tensor(obj):
  32. if type(obj) == list:
  33. pass
  34. else:
  35. print('unsupported obj type')
  36. def zeros(shape):
  37. if type(shape) != tuple:
  38. raise RuntimeError('The type of shape shall be tuple.')
  39. ret = FloatTensor()
  40. dim = pytraph.core.traph_tensor.DimVector()
  41. for each in shape:
  42. dim.push_back(each)
  43. ret._inner_tensor = pytraph.core.traph_tensor.FloatTensor(dim)
  44. ret._inner_tensor.fill_(0)
  45. return ret
  46. def ones(shape):
  47. if type(shape) != tuple:
  48. raise RuntimeError('The type of shape shall be tuple.')
  49. ret = FloatTensor()
  50. dim = pytraph.core.traph_tensor.DimVector()
  51. for each in shape:
  52. dim.push_back(each)
  53. ret._inner_tensor = pytraph.core.traph_tensor.FloatTensor(dim)
  54. ret._inner_tensor.fill_(1)
  55. return ret