| Author | Pearu Peterson |
| Created | 2021-05-06 |
The aim of this blog post is to define the invariants of PyTorch tensors with CSR layout.
A tensor with CSR layout has the following members (as defined by the sparse_csr_tensor constructor):
- crow_indices contains the compressed row indices information
- col_indices contains column indices
- values contains the values of tensor elements
- size defines the shape of tensor
- dtype defines the dtype of tensor elements
- layout holds the layout parameter
- device holds the device of values storage
- pin_memory defines if cuda storage uses pinned memory

1.1 crow_indices.dtype == indices_dtype
1.2 col_indices.dtype == indices_dtype
1.3 indices_dtype is int32 (default) or int64
1.4 values.dtype == dtype
1.5 dtype is float32 (default), or float64, or int8, …, or int64
2.1 crow_indices.layout == torch.strided
2.2 col_indices.layout == torch.strided
2.3 values.layout == torch.strided
2.4 layout == torch.sparse_csr
3.1 size == (nrows, ncols), that is, a CSR tensor represents a 2-dimensional tensor
3.2 crow_indices.dim() == 1
3.3 col_indices.dim() == 1
3.4 values.dim() == 1
3.5 crow_indices.stride() == (1,) or crow_indices.is_contiguous()
3.6 col_indices.stride() == (1,) or col_indices.is_contiguous()
3.7 values.stride() == (1,) or values.is_contiguous()
3.8 crow_indices.size() == (nrows+1,) or crow_indices.numel() == nrows + 1
3.9 col_indices.size() == (nnz,) or col_indices.numel() == nnz
3.10 values.size() == (nnz,) or values.numel() == nnz
3.11 numel() == nrows * ncols is the number of indexable elements
4.1 device is CPU or CUDA
4.2 crow_indices.device == device
4.3 col_indices.device == device
4.4 values.device == device
5.1 crow_indices[0] == 0
5.2 crow_indices[nrows] == nnz
5.3 0 <= crow_indices[i] - crow_indices[i-1] <= ncols for all i=1,...,nrows
5.4 0 <= col_indices.min()
5.5 col_indices.max() < ncols
5.6 col_indices[crow_indices[i-1]:crow_indices[i]] must be sorted and with distinct values for all i=1,...,nrows (required by cuSparse)
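
To make the value invariants concrete, here is a minimal sketch that checks invariants 3.8, 3.9 and 5.1 to 5.6 on plain Python lists. The lists stand in for the index and value tensors, and the helper name csr_violations is hypothetical; it is not part of PyTorch and does not reproduce the checks PyTorch itself performs.

```python
def csr_violations(crow_indices, col_indices, values, size):
    """Return labels of the violated invariants (an empty list means valid).

    Plain lists stand in for tensors; this is an illustrative sketch only.
    """
    nrows, ncols = size
    nnz = len(values)  # 3.10 holds by construction: nnz is taken from values
    found = []
    if len(crow_indices) != nrows + 1:
        found.append("3.8")
    if len(col_indices) != nnz:
        found.append("3.9")
    if found:
        return found  # the remaining checks index into the lists, so stop here
    if crow_indices[0] != 0:
        found.append("5.1")
    if crow_indices[nrows] != nnz:
        found.append("5.2")
    if any(c < 0 for c in col_indices):
        found.append("5.4")
    if any(c >= ncols for c in col_indices):
        found.append("5.5")
    for i in range(1, nrows + 1):
        start, end = crow_indices[i - 1], crow_indices[i]
        if not 0 <= end - start <= ncols:
            found.append("5.3")
            break
        row = col_indices[start:end]
        # 5.6: column indices within a row must be strictly increasing
        if any(a >= b for a, b in zip(row, row[1:])):
            found.append("5.6")
            break
    return found


# The 3 x 4 matrix [[1, 0, 2, 0], [0, 0, 0, 0], [0, 3, 0, 4]] in CSR form:
crow_indices = [0, 2, 2, 4]  # row i spans crow_indices[i]:crow_indices[i + 1]
col_indices = [0, 2, 1, 3]
values = [1, 2, 3, 4]

print(csr_violations(crow_indices, col_indices, values, (3, 4)))
print(csr_violations(crow_indices, [2, 0, 1, 3], values, (3, 4)))
```

The second call swaps the two column indices of the first row, so the row is no longer sorted and the sketch reports a 5.6 violation.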
According to PR 57274, creating a CSR tensor has the following function calling tree with the corresponding invariant checks:
- set_member_tensors(crow_indices, col_indices, values, size)
  - crow_indices_ = crow_indices
  - col_indices_ = col_indices
  - values_ = values
  - set_sizes(size)
  - refresh_numel(), 3.11
- _validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size)
- new_csr_tensor()
- _sparse_csr_tensor_unsafe(crow_indices, col_indices, values, size)
  - new_csr_tensor()
  - set_member_tensors(crow_indices, col_indices, values, size)
- sparse_csr_tensor(crow_indices, col_indices, values, size)
  - _validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size)
  - _sparse_csr_tensor_unsafe(crow_indices, col_indices, values, size)
- sparse_csr_tensor(crow_indices, col_indices, values)
  - size = (crow_indices.numel() - 1, col_indices.max() + 1)
  - _validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size)
  - _sparse_csr_tensor_unsafe(crow_indices, col_indices, values, size)