23 const std::vector<Tensor>& index_tensors)
44 const std::vector<Tensor>& index_tensors);
49 const Tensor& tensor,
const std::vector<Tensor>& index_tensors);
53 static std::pair<std::vector<Tensor>,
SizeVector>
86 const std::vector<Tensor>& index_tensors);
122 const std::vector<Tensor>& index_tensors,
127 if (indexed_shape.
size() != indexed_strides.
size()) {
129 "Internal error: indexed_shape's ndim {} does not equal to "
130 "indexed_strides' ndim {}",
131 indexed_shape.
size(), indexed_strides.
size());
136 std::vector<Tensor> inputs;
137 inputs.push_back(src);
138 for (
const Tensor& index_tensor : index_tensors) {
139 if (index_tensor.NumDims() != 0) {
140 inputs.push_back(index_tensor);
148 "Internal error: indexed_shape's ndim {} does not equal to "
149 "indexd_strides' ndim {}",
160 "src's dtype {} is not the same as dst's dtype {}.",
184 int64_t index = *(
reinterpret_cast<int64_t*
>(
#define OPEN3D_HOST_DEVICE
Definition CUDAUtils.h:44
#define OPEN3D_ASSERT(...)
Definition Macro.h:48
This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.
Definition AdvancedIndexing.h:20
void RunPreprocess()
Preprocess tensor and index tensors.
Definition AdvancedIndexing.cpp:110
static Tensor RestrideTensor(const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)
Definition AdvancedIndexing.cpp:85
static bool IsIndexSplittedBySlice(const std::vector< Tensor > &index_tensors)
Definition AdvancedIndexing.cpp:17
static std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition AdvancedIndexing.cpp:41
SizeVector output_shape_
Output shape.
Definition AdvancedIndexing.h:96
SizeVector GetIndexedStrides() const
Definition AdvancedIndexing.h:38
std::vector< Tensor > index_tensors_
The processed index tensors.
Definition AdvancedIndexing.h:93
Tensor tensor_
Definition AdvancedIndexing.h:90
static std::vector< Tensor > ExpandBoolTensors(const std::vector< Tensor > &index_tensors)
Expand boolean tensor to integer index.
Definition AdvancedIndexing.cpp:230
static std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim(const std::vector< Tensor > &index_tensors)
Definition AdvancedIndexing.cpp:63
AdvancedIndexPreprocessor(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition AdvancedIndexing.h:22
static Tensor RestrideIndexTensor(const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)
Definition AdvancedIndexing.cpp:100
std::vector< Tensor > GetIndexTensors() const
Definition AdvancedIndexing.h:30
SizeVector indexed_shape_
Definition AdvancedIndexing.h:100
SizeVector GetIndexedShape() const
Definition AdvancedIndexing.h:36
SizeVector GetOutputShape() const
Definition AdvancedIndexing.h:34
Tensor GetTensor() const
Definition AdvancedIndexing.h:28
SizeVector indexed_strides_
Definition AdvancedIndexing.h:104
Definition AdvancedIndexing.h:116
Indexer indexer_
Definition AdvancedIndexing.h:197
AdvancedIndexer(const Tensor &src, const Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides, AdvancedIndexerMode mode)
Definition AdvancedIndexing.h:120
int64_t element_byte_size_
Definition AdvancedIndexing.h:200
int64_t NumWorkloads() const
Definition AdvancedIndexing.h:194
int64_t num_indices_
Definition AdvancedIndexing.h:199
OPEN3D_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
Definition AdvancedIndexing.h:173
int64_t indexed_strides_[MAX_DIMS]
Definition AdvancedIndexing.h:202
AdvancedIndexerMode mode_
Definition AdvancedIndexing.h:198
int64_t indexed_shape_[MAX_DIMS]
Definition AdvancedIndexing.h:201
OPEN3D_HOST_DEVICE int64_t GetIndexedOffset(int64_t workload_idx) const
Definition AdvancedIndexing.h:181
AdvancedIndexerMode
Definition AdvancedIndexing.h:118
OPEN3D_HOST_DEVICE char * GetInputPtr(int64_t workload_idx) const
Definition AdvancedIndexing.h:166
std::string ToString() const
Definition Dtype.h:64
int64_t ByteSize() const
Definition Dtype.h:58
OPEN3D_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
Definition Indexer.h:437
int64_t NumWorkloads() const
Definition Indexer.cpp:406
OPEN3D_HOST_DEVICE char * GetInputPtr(int64_t input_idx, int64_t workload_idx) const
Definition Indexer.h:405
Definition SizeVector.h:69
size_t size() const
Definition SmallVector.h:119
Dtype GetDtype() const
Definition Tensor.h:1163
Definition PinholeCameraIntrinsic.cpp:16