
Tensor creation operations in PyTorch

   
Author Pearu Peterson
Created 2021-03-22

The aim of this blog post is to propose a classification of PyTorch tensor creation operations and to identify the corresponding testing patterns.

This blog post is inspired by the new PyTorch testing framework that introduces the OpInfo pattern to simplify writing tests for PyTorch. However, it does not yet provide a solution for testing tensor creation operations, as would be required in PR 54187, for instance.

Introduction

In general, Tensor instances can be created from other Tensor instances as a result of tensor operations. But that is not the only way. Here we consider tensor creation operations whose inputs can be arbitrary Python objects from which new Tensor instances can be constructed. For instance, tensors can be constructed from objects that implement the Array Interface protocol, from Python sequences that represent array-like structures, or from Python integers as in the case of torch.zeros, torch.arange, etc. The created tensors may or may not share memory with the input objects, depending on the particular operation as well as on the given device or dtype parameter values.
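
For example, a small illustration of these kinds of inputs:

    import numpy
    import torch

    z = torch.zeros(2, 3)                  # from Python integers specifying the shape
    r = torch.arange(5)                    # from Python integers specifying a range
    s = torch.tensor([[1, 2], [3, 4]])     # from a nested Python sequence
    a = torch.as_tensor(numpy.arange(3))   # from an object implementing the Array
                                           # Interface protocol; may share memory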

Tensor creation operations

To distinguish tensor creation operations from other operations, we define tensor creation operations as operations that result in a Tensor instance with user-specified properties such as dtype, device, layout, and content.

According to this definition, PyTorch implements a number of tensor creation operations, including torch.tensor, torch.as_tensor, torch.zeros, torch.ones, torch.empty, torch.arange, torch.eye, torch.full, their *_like variants, and random tensor creation operations such as torch.rand and torch.randn.

Testing tensor creation operations

In general, testing an operation for correct behaviour (as specified in its documentation, defined by mathematics, or determined by some supported protocol, etc.) involves checking that a given input produces the expected result. In the case of tensor creation operations, the input parameters define not only the content of the tensor data buffers (dtype, values, shape) but also how and where the data is stored in memory (layout, strides, device, whether input memory is shared or not, etc.). In contrast to the OpInfo framework's goals, testing with respect to Autograd support is mostly irrelevant here, as the inputs to tensor creation operations are not PyTorch Tensor objects (except for a few cases such as torch.sparse_coo_tensor, torch.complex, etc., and for view operations).

Let Op(...) -> Tensor be a tensor creation operation. A number of patterns can be defined for testing the correctness of Op:

  1. Specifying a tensor property in an argument list must result in a tensor that has this property:

    Op(..., dtype=dtype).dtype == dtype
    Op(..., device=device).device == device
    Op(..., layout=layout).layout == layout
    Op_like(input, ...).dtype == input.dtype
    Op_like(input, ...).device == input.device
    Op_like(input, ...).layout == input.layout
    

    where an explicit specification of a property in the Op_like argument list overrides the corresponding property of the input tensor.
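
    For example, a minimal sketch of this pattern using torch.zeros and torch.zeros_like (the test name is hypothetical):

    import torch

    def test_property_arguments():
        t = torch.zeros(2, 3, dtype=torch.float64, device='cpu', layout=torch.strided)
        assert t.dtype == torch.float64
        assert t.device.type == 'cpu'
        assert t.layout == torch.strided

        # the *_like variant inherits the properties of its input tensor ...
        u = torch.zeros_like(t)
        assert u.dtype == t.dtype and u.device == t.device and u.layout == t.layout

        # ... unless a property is specified explicitly, which overrides inheritance
        v = torch.zeros_like(t, dtype=torch.int32)
        assert v.dtype == torch.int32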

  2. Specifying the out parameter must lead to the same result as leaving it unspecified, unless the out argument specification is erroneous (e.g. it has a different dtype from the dtype of the expected result):

    Op(*args, out=a, **kwargs) == a == Op(*args, **kwargs)
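
    For instance, a minimal sketch of this pattern for torch.arange (the test name is hypothetical):

    import torch

    def test_out_argument():
        expected = torch.arange(5, dtype=torch.float32)
        out = torch.empty(5, dtype=torch.float32)
        result = torch.arange(5, dtype=torch.float32, out=out)
        assert result is out                # the out tensor is returned as the result
        assert (result == expected).all()   # and it holds the same content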
    
  3. For tensor creation operations that have NumPy analogues, such as zeros, ones, arange, etc., use NumPy functions as reference implementations:

    torch.Op(...).numpy() == numpy.Op(...)
    

    Warning: numpy.Op may have a different user interface from the corresponding torch.Op. Using NumPy functions as references is based on the assumption that the NumPy functions behave correctly; occasionally, the definitions of correctness may differ between the two projects.
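
    As a sketch, assuming that the default behaviours of the two arange implementations agree:

    import numpy
    import torch

    def test_arange_against_numpy():
        t = torch.arange(0, 10, 2)
        a = numpy.arange(0, 10, 2)
        assert (t.numpy() == a).all()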

  4. The content of a tensor must not depend on the storage device, the data storage layout, or the memory format:

    torch.Op(..., device=device).to(device='cpu') == torch.Op(..., device='cpu')
    torch.Op(..., layout=layout).to_dense() == torch.Op(..., layout=torch.strided)
    torch.Op(..., memory_format=memory_format).to(memory_format=torch.contiguous_format) == torch.Op(..., memory_format=torch.contiguous_format)
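
    For example, a sketch of this pattern using torch.zeros (the device part runs only when CUDA is available):

    import torch

    def test_content_independent_of_storage():
        dense = torch.zeros(2, 3, layout=torch.strided)

        # the same content stored using a sparse layout
        sparse = torch.zeros(2, 3, layout=torch.sparse_coo)
        assert (sparse.to_dense() == dense).all()

        # the same content stored on a CUDA device
        if torch.cuda.is_available():
            cuda = torch.zeros(2, 3, device='cuda')
            assert (cuda.to(device='cpu') == dense).all()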
    
  5. Tensors created with pin_memory=True must be accessible from a CUDA device:

    class W:
        def __init__(self, tensor):
            # expose the CPU tensor's data buffer via the CUDA Array Interface protocol
            self.__cuda_array_interface__ = tensor.__array__().__array_interface__

    x = torch.Op(..., pin_memory=True)
    assert x.device.type == 'cpu'
    t = torch.as_tensor(W(x), device='cuda')
    assert t.device.type == 'cuda'
    x += 1                                   # modify x in-place
    assert (t.to(device='cpu') == x).all()   # changes in x are reflected in t
    
  6. When Op represents a random tensor creation operation, its correctness must be verified using statistical methods, e.g. by computing statistical moments of the results and comparing them, approximately, with the expected values.
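
    For example, a rough sketch for torch.rand, which samples uniformly from [0, 1) (the sample size and tolerances are arbitrary choices):

    import torch

    def test_rand_moments():
        x = torch.rand(100000)
        assert x.min() >= 0 and x.max() < 1
        # the uniform distribution on [0, 1) has mean 1/2 and variance 1/12
        assert abs(x.mean().item() - 0.5) < 0.01
        assert abs(x.var().item() - 1 / 12) < 0.01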

  7. Tensor constructor torch.tensor must be able to construct a tensor from the following objects:

    • NumPy ndarray objects
    • nested sequences of numbers
    • objects implementing CPU/CUDA Array Interface protocols (as in PR 54187)
    • objects implementing PEP 3118 Buffer protocol

    where the resulting tensor must not share memory with the input data buffer.
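
    A sketch of the no-sharing requirement using a NumPy ndarray input (the test name is hypothetical):

    import numpy
    import torch

    def test_tensor_copies_input():
        a = numpy.zeros(5)
        t = torch.tensor(a)      # torch.tensor must always copy the input buffer
        a += 1                   # modify the input in-place
        assert (t == 0).all()    # the tensor content is unaffected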

  8. Tensor constructor torch.as_tensor must be able to construct a tensor from the following objects:

    • PyTorch Tensor instances
    • NumPy ndarray objects
    • nested sequences of numbers
    • objects implementing CPU/CUDA Array Interface protocols (as in PR 54187)
    • objects implementing PEP 3118 Buffer protocol

    where the resulting tensor may share memory with the input data buffer.
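
    A sketch of the memory-sharing behaviour using a NumPy ndarray input (sharing is expected here because the dtype and device of the result match the input):

    import numpy
    import torch

    def test_as_tensor_shares_input():
        a = numpy.zeros(5)
        t = torch.as_tensor(a)   # no copy: the dtype and device are compatible
        a += 1                   # modify the input in-place
        assert (t == 1).all()    # the change is visible through the tensor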

  9. All view operations must be tested by modifying the view result and then checking that the corresponding changes appear in the original tensor:

    a = Op(x)                  # create a view of x
    a += 1                     # modify the view in-place
    assert (a == Op(x)).all()  # recreating the view gives the modified values
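
    For instance, with Op = torch.diagonal, which returns a view of its input:

    import torch

    def test_diagonal_is_a_view():
        x = torch.zeros(3, 3)
        a = torch.diagonal(x)    # create a view of x
        a += 1                   # modify the view in-place
        assert (x == torch.eye(3)).all()       # the change appears in the original tensor
        assert (a == torch.diagonal(x)).all()  # recreating the view gives the modified values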
    

The current state of testing tensor creation operations in PyTorch

Clearly, many of the tensor creation operations are heavily used in the PyTorch test suite for creating inputs to various tests of PyTorch functionality. However, when extending the functionality of tensor creation operations, it is not always obvious where the corresponding tests should be implemented. Also, not all tensor creation operations are tested systematically.

For instance, let us consider the torch.as_tensor operation. It has unit tests implemented in test/test_tensor_creation_ops.py:TestTensorCreation.test_as_tensor() that cover a number of parameter cases. In addition:

test/test_tensor_creation_ops.py:TestTensorCreation.test_tensor_ctor_device_inference() covers torch.as_tensor(..., device=device).device == device for CPU and CUDA devices and float32/float64 dtypes.

test/test_tensor_creation_ops.py:TestTensorCreation.test_tensor_factories_empty() covers creating an empty tensor from a list via torch.as_tensor().

test/test_numba_integration.py:TestNumbaIntegration.test_cuda_array_interface covers the CUDA Array Interface via torch.as_tensor().

Finally, when implementing PR 54187, I noticed that torch.tensor created a non-copying Tensor instance from an object implementing the Array Interface. The operation torch.tensor should always copy the input data buffers, but the current test suite did not catch the non-copying behavior, although ideally it should have. This means that not all requirements specified in the torch.tensor documentation are covered by tests.