PyTorch is a deep learning (DL) framework that stores data in tensors. A tensor is a multi-dimensional array whose elements all share one data type. The data held in tensors is used to train deep learning models, and the shape of a tensor can be changed after it is created, as long as the new shape is compatible with the number of elements it holds.

How to Reshape Tensors in PyTorch?

PyTorch offers several methods for reshaping a tensor. Each one works differently, but for a compatible target shape they change only the shape, not the values stored in the tensor.

To learn the process of reshaping tensors in PyTorch, simply go through the list of steps mentioned below:

Prerequisites

Complete the following steps before reshaping tensors in PyTorch with Python:

  • Get Environment for Python
  • Install Modules
  • Import Libraries

If any of these steps are already done, simply move to the next one:

Get Environment for Python

Writing the script in Python programming language requires the creation of a Notebook from the official website:

Install Modules

On the Python Notebook, start the process by installing the torch framework using the following code:

pip install torch

Import Libraries

Import the library from the torch dependency to get its methods to complete the process:

import torch

Display the installed version of the torch library using the “torch.__version__” attribute (two underscores on each side):

print(torch.__version__)

Method 1: Reshape PyTorch Tensor Using the reshape() Function

The torch library offers the reshape() method to change the shape of an existing tensor. The following examples start from a one-dimensional tensor and reshape it into tensors with different numbers of dimensions:

  1. Reshaping One-Dimensional (1D) Tensor
  2. Reshaping Two-Dimensional (2D) Tensor
  3. Reshaping Multi-Dimensional Tensor

Example 1: Reshaping One-Dimensional (1D) Tensor

Start the process by creating a tensor with only one dimension using the torch.tensor() method and store it in the tensor_a variable:

tensor_a = torch.tensor([1,2,3,4,5,6,7,8])

The shape attribute describes how a tensor is structured: it returns the size of each dimension, not the stored values. Because shape is an attribute and not a method, it is written without parentheses. Display the current shape of the tensor by appending .shape to its name:

print(tensor_a.shape)
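As a small sketch of the ways a tensor's structure can be inspected, the following compares the shape attribute with the size() method and two related properties. It reuses the tensor_a values from this example.

```python
import torch

tensor_a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

print(tensor_a.shape)    # attribute, written without parentheses
print(tensor_a.size())   # method form, returns the same torch.Size object type
print(tensor_a.ndim)     # number of dimensions
print(tensor_a.numel())  # total number of elements
```

The element count reported by numel() is the value that any new shape has to multiply out to.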

In a notebook, display the values stored in tensor_a by writing its name on the last line of a cell. Passing it to the print() function works as well:

tensor_a

Call reshape() with the target shape as its argument and print the result. Here the 8 values are arranged as 8 rows and 1 column:

print(tensor_a.reshape([8, 1]))
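The target shape does not have to be spelled out in full. As a minimal sketch, the following uses -1 for one dimension so that PyTorch infers it from the element count, and shows what happens when the requested shape does not match the number of elements. The variable names column and pairs are illustrative only.

```python
import torch

tensor_a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

# -1 asks PyTorch to infer that dimension from the element count (8 here).
column = tensor_a.reshape(-1, 1)  # inferred shape: 8 rows, 1 column
pairs = tensor_a.reshape(4, -1)   # inferred shape: 4 rows, 2 columns

print(column.shape)
print(pairs.shape)

# A shape whose dimensions do not multiply to 8 raises a RuntimeError.
try:
    tensor_a.reshape(3, 3)
except RuntimeError as err:
    print("reshape failed:", err)
```

Only one dimension can be -1 in a single call, since PyTorch needs the others to work out the missing size.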

Example 2: Reshaping Two-Dimensional (2D) Tensor

The reshape() method can also change the number of dimensions. This example converts the tensor from 1D to 2D. Start by displaying the original tensor:

tensor_a

Call reshape() on the variable with the target shape to arrange the values in 4 rows and 2 columns:

print(tensor_a.reshape([4, 2]))

The reshape() method returns a new tensor and leaves tensor_a itself unchanged, so printing the shape of tensor_a still shows the original one-dimensional shape with 8 elements:

print(tensor_a.shape)

Example 3: Reshaping Multi-Dimensional Tensor

Convert the one-dimensional tensor into a tensor with more than two dimensions using the reshape() method. Start the example by displaying the original tensor:

tensor_a

Change the shape and convert it into a multi-dimensional tensor by giving values for each dimension in its argument:

print(tensor_a.reshape([2, 2, 2]))

Confirm that the original tensor has not been changed by the call, since reshape() returns a new tensor:

print(tensor_a.shape)
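To keep a reshaped tensor, assign the return value to a name. The following sketch, using the illustrative name cube, shows that the original tensor keeps its shape while the returned tensor has the new one, and that both hold the same number of elements.

```python
import torch

tensor_a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

# reshape() returns a new tensor object; tensor_a keeps its own shape.
cube = tensor_a.reshape(2, 2, 2)

print(tensor_a.shape)                    # torch.Size([8])
print(cube.shape)                        # torch.Size([2, 2, 2])
print(cube.numel() == tensor_a.numel())  # True
```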

Method 2: Reshape PyTorch Tensor Using the flatten() Function

Another function offered by the torch library is flatten(). By default, it collapses a tensor with any number of dimensions into a single dimension while keeping the values in their original order. The following examples demonstrate the method:

  1. Reshaping Tensor From 2D to 1D 
  2. Reshaping Tensor From Multi-Dimensional to 1D

Example 1: Reshaping Tensor From 2D to 1D 

The following example creates the two-dimensional tensor and prints its values on the screen:

tensor_b = torch.tensor([[1,2,3,4,5,6,7,8], [1,2,3,4,5,6,7,8]])
print(tensor_b)

Pass the variable to the flatten() function to turn the two-dimensional tensor into a one-dimensional tensor with 16 values, and print the result:

print(torch.flatten(tensor_b))

Example 2: Reshaping Tensor From Multi-Dimensional to 1D 

Create a three-dimensional tensor (its shape is 1×4×8) and print its values on the screen:

tensor_c = torch.tensor([[[8,7,6,5,4,3,2,1], [11,12,13,14,15,16,17,18], [28,27,26,25,24,23,22,21], [31,32,33,34,35,36,37,38]]])

print(tensor_c)

Pass the three-dimensional tensor to the flatten() function. The result is a one-dimensional tensor with all 32 values; print it to verify the conversion:

print(torch.flatten(tensor_c))
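flatten() can also collapse only some of the dimensions through its start_dim and end_dim parameters. The following is a minimal sketch under the assumption that the first dimension is a batch dimension that should be kept; the tensor named batch and its values are made up for illustration.

```python
import torch

# Hypothetical batch: 2 samples, each a 3x4 grid of values.
batch = torch.arange(24).reshape(2, 3, 4)

# Default call: every dimension is collapsed into one.
flat_all = torch.flatten(batch)

# start_dim=1 keeps dimension 0 and flattens everything after it.
flat_per_sample = torch.flatten(batch, start_dim=1)

print(flat_all.shape)
print(flat_per_sample.shape)
```

This pattern is common before a fully connected layer, where each sample needs to become a single row of features.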

Method 3: Reshape PyTorch Tensor Using view() Function

Another method used to reshape tensors in PyTorch is view(). It returns a tensor with the requested shape that shares its data with the original, so no values are copied. The target shape can have any number of dimensions, provided that the sizes multiply to the number of elements in the tensor. The following examples use two dimensions, and set either the rows or the columns to 1 to get a single column or a single row:

  1. Reshaping Tensor in 4×3
  2. Reshaping Tensor in 1D

Example 1: Reshaping Tensor in 4×3

Create a tensor with 12 values, so that a 4×3 shape holds every value exactly:

tensor_d = torch.tensor([24, 56, 10, 20, 30, 40, 50, 1, 2, 3, 4, 5], dtype=torch.float32)
print(tensor_d)

Change the dimensions of the tensor by giving the values to the rows and columns as mentioned in the following code:

print(tensor_d.view(4, 3))

Executing the above code prints the tensor as a 4×3 matrix.

Example 2: Reshaping Tensor in 1D 

Setting one of the two dimensions to 1 produces a single column or a single row. Both results are still two-dimensional tensors; for a true one-dimensional tensor, pass a single size such as view(12) or view(-1). The following code displays the tensor as 12 rows and 1 column:

print(tensor_d.view(12, 1))

The following code displays the tensor as a single row with 12 columns:

print(tensor_d.view(1, 12))
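Two properties of view() are easy to miss: the result shares memory with the original tensor, and the call fails on a tensor whose memory layout is not compatible with the requested shape. The sketch below illustrates both under simple assumptions. The names matrix, cube, flat, transposed, and fixed are illustrative only.

```python
import torch

tensor_d = torch.tensor([24., 56., 10., 20., 30., 40., 50., 1., 2., 3., 4., 5.])

matrix = tensor_d.view(4, 3)   # two dimensions
cube = tensor_d.view(2, 2, 3)  # three dimensions work as well
flat = matrix.view(-1)         # true 1D result, size inferred as 12

# A view shares storage with its source: writing through the view
# also changes the original tensor.
matrix[0, 0] = -1.0
print(tensor_d[0])

# Transposing yields a non-contiguous tensor, which view() cannot flatten.
transposed = matrix.t()
try:
    transposed.view(12)
except RuntimeError as err:
    print("view failed:", err)

# reshape() handles this case by copying the data when it has to.
fixed = transposed.reshape(12)
print(fixed.shape)
```

When it is unclear whether a tensor is contiguous, reshape() is the safer choice; view() is useful when a copy must be avoided.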

Method 4: Reshape PyTorch Tensor Using resize() Function

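The fourth method is the resize_() function. The trailing underscore marks it as an in-place operation, which means that it modifies the tensor itself instead of returning a new one. Unlike the previous methods, it can change the number of elements: shrinking keeps the leading elements, and growing adds new elements whose values are uninitialized. The resulting three-dimensional tensor can be read as a stack of two-dimensional tensors, as the following examples show: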

  1. Creating Four Multi-Dimensional Tensors From One Tensor 
  2. Creating 2 Tensors With 4 Rows and 2 Columns

Example 1: Creating Four Multi-Dimensional Tensors From One Tensor

Create a one-dimensional tensor and print its values on the screen:

tensor_e= torch.tensor([10, 20, 30, 40, 50, 1, 2, 3, 4, 5])
tensor_e

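Use the resize_() method to change the tensor in place to the shape 4×4×5, which can be read as four matrices of 4 rows and 5 columns each. This shape needs 80 elements while the tensor holds only 10, so the original values come first and the remaining 70 elements are uninitialized: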

print(tensor_e.resize_(4, 4, 5))

Executing the above code prints the resized tensor with the shape 4×4×5.

Example 2: Creating 2 Tensors With 4 Rows and 2 Columns 

Because resize_() works in place, tensor_e no longer has its original one-dimensional shape after the previous example; it now has the shape 4×4×5. Display it again before resizing:

tensor_e

Resize the tensor to the shape 2×4×2, which can be read as two matrices of 4 rows and 2 columns each. This shape needs 16 elements, and they are taken from the start of the tensor's storage. When a tensor holds fewer elements than the requested shape needs, the missing elements are uninitialized; they are not guaranteed to be zeros:

print(tensor_e.resize_(2, 4, 2))
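The following is a minimal sketch of how resize_() behaves when the element count grows and when it shrinks, starting from a fresh copy of the 10 original values. The name first_ten is illustrative only. The values of the newly added elements are deliberately not printed, because they are uninitialized and can differ between runs.

```python
import torch

tensor_e = torch.tensor([10, 20, 30, 40, 50, 1, 2, 3, 4, 5])

# Growing: 2 * 4 * 2 = 16 elements are requested but only 10 exist.
tensor_e.resize_(2, 4, 2)
print(tensor_e.shape)

# The 10 original values are preserved at the start of the storage;
# the 6 new elements hold whatever happened to be in memory.
first_ten = tensor_e.flatten()[:10].clone()
print(first_ten)

# Shrinking: only the first 2 * 3 = 6 elements remain visible.
tensor_e.resize_(2, 3)
print(tensor_e)  # rows hold 10, 20, 30 and 40, 50, 1
```

The PyTorch documentation describes resize_() as a low-level method, so reshape() or view() is usually the better fit when the element count should stay the same.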

Method 5: Reshape PyTorch Tensor Using unsqueeze() Function

The last method to reshape tensors in PyTorch is the unsqueeze() function, which inserts a new dimension of size 1 at a given index. The following example implements the method:

Example: Adding Dimensions to the Tensor

Create a one-dimensional tensor with the values stored in it and then print the shape of the tensor that shows its dimensions:

a = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)

print(a.shape)

Add a dimension at index 0. The new dimension of size 1 is placed before the 5, so the result is a two-dimensional tensor with the shape 1×5:

added = a.unsqueeze(0)
print(added.shape)

Now, add a dimension at index 1. The 5 stays at index 0 and the new dimension of size 1 is placed after it, so the shape becomes 5×1:

added = a.unsqueeze(1)
print(added.shape)
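As a short sketch of a few related idioms: a negative index counts from the end, indexing with None inserts a dimension in the same way, and squeeze() removes a dimension of size 1 again. The names row, col, same_col, back, and table are illustrative only.

```python
import torch

a = torch.tensor([1., 2., 3., 4., 5.])

row = a.unsqueeze(0)     # shape 1x5
col = a.unsqueeze(-1)    # shape 5x1; -1 means "after the last dimension"
same_col = a[:, None]    # indexing with None has the same effect

# squeeze() is the inverse: it removes a dimension of size 1.
back = col.squeeze(-1)

# The added dimensions let the two tensors broadcast against each other.
table = row + col        # shape 5x5

print(row.shape, col.shape, back.shape, table.shape)
```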

That’s all about the process of reshaping tensors in PyTorch.

Conclusion

To reshape tensors in PyTorch, install torch and import it to use the methods offered by the framework. These include reshape(), flatten(), view(), resize_(), and unsqueeze(). The first three keep the number of elements unchanged, resize_() modifies the tensor in place and can change its element count, and unsqueeze() adds a dimension of size 1. This guide has demonstrated each of these methods with examples.