In PyTorch, tensors are the fundamental data structure and can have multiple dimensions. In various machine learning, data analysis, and deep learning tasks, users may want to find the k-th element or the top “k” (largest) elements of a specific tensor. PyTorch provides the “torch.kthvalue()” and “torch.topk()” methods for these operations.

This blog will illustrate:

  • How to Find Tensor’s k-th Element in PyTorch?
  • How to Find Tensor’s Top “k” Elements in PyTorch?

How to Find Tensor’s k-th Element in PyTorch?

The “k-th” element of a tensor is the element located at the k-th position after the tensor is sorted in ascending order, i.e., the k-th smallest element. To determine the k-th element of a specific tensor in PyTorch, use the “torch.kthvalue()” method. Conceptually, it sorts the tensor in ascending order and returns the k-th value (with k counted from 1) along a given dimension, which is the last dimension by default, together with the index of that value in the original tensor.
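The sort-based description above can be checked with a short sketch that compares “torch.kthvalue()” against “torch.sort()”. It reuses the “T1” values from this section and assumes the values are distinct, so the indices from both approaches match unambiguously:

```python
import torch

# Same values as the "T1" tensor used in this section
T1 = torch.Tensor([5, 0, -3, 7, 9, 2, 13])
k = 3  # k is 1-based: k=1 is the smallest element

# torch.kthvalue returns the k-th smallest value and its index in T1
val, ind = torch.kthvalue(T1, k)

# Same result by sorting in ascending order and taking position k-1
sorted_vals, sorted_idx = torch.sort(T1)
print(val.item(), sorted_vals[k - 1].item())  # 2.0 2.0
print(ind.item(), sorted_idx[k - 1].item())   # 5 5
```

Note that k is 1-based in “torch.kthvalue()”, while the sorted tensor is indexed from 0, which is why the sketch reads position k - 1.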

Follow the provided steps for its practical implementation:

Step 1: Import PyTorch Library

First, import the “torch” library to work with tensors:

import torch

Step 2: Define a Tensor

Then, define a desired tensor and display its elements. Here, we are defining the following 1D “T1” tensor from a list using the “torch.Tensor()” function:

T1 = torch.Tensor([5, 0, -3, 7, 9, 2, 13])

print(T1)

This has created a “T1” tensor as seen below:

Step 3: Find k-th Element of the Tensor 

Now, use the “torch.kthvalue()” method and pass it the input tensor and the value of k to find the element at that position in the sorted tensor. For instance, we are finding the 3rd element of the sorted “T1” tensor:

val, ind = torch.kthvalue(T1, 3)

Here,

  • “val” is a variable that contains the k-th element value from the sorted tensor.
  • “ind” is a variable that contains the element’s index in the original tensor.
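The result of “torch.kthvalue()” is a named tuple, so the two fields can also be read by name. The sketch below shows this and applies the method to a small 2D tensor; the “M” tensor is a hypothetical example, not part of the original steps:

```python
import torch

T1 = torch.Tensor([5, 0, -3, 7, 9, 2, 13])

# kthvalue returns a named tuple with "values" and "indices" fields
result = torch.kthvalue(T1, 3)
print(result.values, result.indices)  # tensor(2.) tensor(5)

# Hypothetical 2D tensor: with dim=1, kthvalue works row by row,
# giving one k-th smallest value (and its column index) per row
M = torch.Tensor([[4, 1, 3],
                  [9, 7, 8]])
row_vals, row_idx = torch.kthvalue(M, 2, dim=1)
print(row_vals)  # tensor([3., 8.])
print(row_idx)   # tensor([2, 2])
```

For a 2D tensor, dim=1 is also the last dimension, so omitting it here would give the same result.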

Step 4: Print Value and Index of the k-th Element 

Finally, display the index and value of the k-th element of the “T1” tensor:

print("\nIndex:", ind, "\nValue:", val)

The below output shows that the k-th (3rd) element of the sorted “T1” tensor has the value “2” and is located at index 5 in the original tensor:

Similarly, to find the 5th element of the sorted “T1” tensor, pass the input tensor and k = 5 to the “torch.kthvalue()” method:

val, ind = torch.kthvalue(T1, 5)

print("\nIndex:", ind, "\nValue:", val)

In the below output, the 5th element of the sorted tensor can be seen: its value is 7 and its index in the original tensor is 3.
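As a quick sanity check, calling “torch.kthvalue()” for every valid k (from 1 up to the number of elements) should rebuild the ascending order of “T1” one element at a time. This is only an illustration of what k means, not an efficient way to sort:

```python
import torch

T1 = torch.Tensor([5, 0, -3, 7, 9, 2, 13])

# One kthvalue call per valid k rebuilds the ascending order of T1
ordered = [torch.kthvalue(T1, k).values.item() for k in range(1, T1.numel() + 1)]
print(ordered)  # [-3.0, 0.0, 2.0, 5.0, 7.0, 9.0, 13.0]

# k = 5 reproduces the value and index discussed above
val, ind = torch.kthvalue(T1, 5)
print(val.item(), ind.item())  # 7.0 3
```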

How to Find Tensor’s Top “k” Elements in PyTorch?

The top “k” elements are the k largest elements of a particular PyTorch tensor. To get the top “k” elements of a tensor in PyTorch, use the “torch.topk()” method. It returns the k largest values along a given dimension (the last dimension by default), sorted in descending order by default, together with their indices in the input tensor.
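To make the ordering behavior concrete, the sketch below compares “torch.topk()” with a descending “torch.sort()” on the “T2” values used later in this section, and also shows the “largest=False” option, which returns the k smallest elements instead:

```python
import torch

T2 = torch.Tensor([7.144, 3.543, -9.398, -0.665, 4, 10.921])

# By default, topk returns the k largest values in descending order
top_vals, top_idx = torch.topk(T2, 3)

# Same result: sort in descending order and keep the first k entries
desc_vals, desc_idx = torch.sort(T2, descending=True)
print(torch.equal(top_vals, desc_vals[:3]))  # True
print(torch.equal(top_idx, desc_idx[:3]))    # True

# largest=False returns the k smallest values (in ascending order)
low_vals, low_idx = torch.topk(T2, 2, largest=False)
print(low_idx)  # tensor([2, 3])
```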

Step 1: Import PyTorch Library

First, import the “torch” library:

import torch

Step 2: Define a Tensor

Then, define a desired tensor and print its elements. Here, we are defining the following 1D “T2” tensor from a list using the “torch.Tensor()” function:

T2 = torch.Tensor([7.144, 3.543, -9.398, -0.665, 4, 10.921])

print(T2)

This has created the “T2” tensor:

Step 3: Find Top “k” Elements of Tensor

Now, utilize the “torch.topk()” method and pass it the input tensor and the value of k to find the largest (top) elements of the input tensor. For instance, we are finding the top 3 (largest) elements of the “T2” tensor and their indices:

val, ind = torch.topk(T2, 3)

Here,

  • “val” is a variable that contains the values of the top “k” (3) elements of the “T2” tensor.
  • “ind” is a variable that contains the elements’ indices in the original tensor.
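In deep learning, “torch.topk()” is often applied row by row to a 2D tensor, for example to pick the highest-scoring classes for each sample. The sketch below assumes a small hypothetical “scores” tensor with 2 samples (rows) and 4 classes (columns):

```python
import torch

# Hypothetical scores: 2 samples (rows) x 4 classes (columns)
scores = torch.Tensor([[0.1, 0.7, 0.05, 0.15],
                       [0.3, 0.2, 0.4, 0.1]])

# Top 2 classes per row; dim=1 is the class dimension (also the last dim)
top_vals, top_idx = torch.topk(scores, 2, dim=1)
print(top_idx.tolist())  # [[1, 3], [2, 0]]
```

Each row of “top_idx” lists the column indices of that row’s two largest scores, largest first.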

Step 4: Print Values and Indices of the Top “k” Elements

Finally, display the indices and values of the top “k” elements of the “T2” tensor:

print("Top 3 element values:", val)

print("Top 3 element indices:", ind)

The below output displays the top 3 (largest) elements, 10.921, 7.144, and 4, along with their indices 5, 0, and 4:
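The returned indices point back into the original tensor, so indexing “T2” with them should reproduce the returned values. A short check of this, using the same “T2” tensor:

```python
import torch

T2 = torch.Tensor([7.144, 3.543, -9.398, -0.665, 4, 10.921])
val, ind = torch.topk(T2, 3)

# Indexing the original tensor with the returned indices gives back the values
print(torch.equal(T2[ind], val))            # True
print(ind.tolist())                         # [5, 0, 4]
print([round(v, 3) for v in val.tolist()])  # [10.921, 7.144, 4.0]
```

The values are rounded before printing because “torch.Tensor()” stores them as 32-bit floats, so the raw Python floats carry small representation errors.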

We have explained the methods of finding the k-th and the top “k” elements of a tensor in PyTorch.


Conclusion

To find the k-th and top (largest) “k” elements of a specific PyTorch tensor, first, import the “torch” library. Then, create a desired tensor and print its elements. Next, use the “torch.kthvalue()” method with the input tensor and the value of k to find the k-th element of the sorted tensor, or the “torch.topk()” method with the input tensor and the value of k to find its k largest elements. Finally, display the value and index of the k-th element or the values and indices of the top “k” elements. This blog has illustrated the methods to find a tensor’s k-th and top “k” elements in PyTorch.