PyTorch – argmin()
A tensor is a multidimensional array that is used to store data. To use a tensor, we have to import the torch module.
To create a tensor, we use the torch.tensor() method.
Syntax:
torch.tensor(data)
Where data is a multi-dimensional array, for example a nested Python list.
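Here is a minimal sketch of tensor() in use. The nested list and the name t are illustrative choices, not part of the original lesson. It shows how a nested Python list maps to a two-dimensional tensor and what shape and data type PyTorch assigns by default.

```python
import torch

# A nested Python list with 2 inner lists of 3 floats becomes a 2 x 3 tensor
data = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]
t = torch.tensor(data)

print(t)
print(t.shape)   # torch.Size([2, 3])
print(t.dtype)   # torch.float32 (the default for Python floats)
```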
argmin()
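argmin() in PyTorch returns the index of the minimum value in the input tensor. When no dim is given, it treats the tensor as flattened and returns a single index. When dim is given, it returns the index of the minimum value along that dimension.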
Syntax:
torch.argmin(tensor, dim, keepdim)
Where
- tensor is the input tensor.
- dim is the dimension to reduce. For a two-dimensional tensor, dim=0 compares the values down each column and returns, for every column, the row index of its minimum value. dim=1 compares the values across each row and returns, for every row, the column index of its minimum value.
- keepdim (False by default) specifies whether the output tensor keeps the reduced dimension dim as a dimension of size 1.
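To make the parameters concrete, here is a small sketch using a hand-picked 2 x 3 tensor instead of random values, so every index can be checked by eye. The tensor values and variable names are hypothetical; the calls are the standard torch.argmin(input, dim, keepdim) forms.

```python
import torch

# Hypothetical 2 x 3 tensor chosen so the answers are easy to verify by hand
t = torch.tensor([[4.0, 1.0, 7.0],
                  [2.0, 9.0, 0.0]])

# No dim: index into the flattened tensor [4, 1, 7, 2, 9, 0]
flat_idx = torch.argmin(t)          # tensor(5)

# dim=0: one result per column, giving the row that holds that column's minimum
col_idx = torch.argmin(t, dim=0)    # tensor([1, 0, 1])

# dim=1: one result per row, giving the column that holds that row's minimum
row_idx = torch.argmin(t, dim=1)    # tensor([1, 2])

# keepdim=True keeps the reduced dimension with size 1, so the shape is (2, 1)
row_idx_kept = torch.argmin(t, dim=1, keepdim=True)

print(flat_idx, col_idx, row_idx, row_idx_kept.shape)
```

The returned indices are integer (int64) tensors, which makes them directly usable for indexing back into the original tensor.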
Example 1:
In this example, we will create a tensor with two dimensions that has three rows and five columns and apply argmin() on the rows and columns.
import torch
#create a tensor with 2 dimensions (3 * 5)
#with random elements using randn() function
data = torch.randn(3,5)
#display
print(data)
#get minimum index along columns with argmin
print(torch.argmin(data, dim=0))
#get minimum index along rows with argmin
print(torch.argmin(data, dim=1))
Output:
[-1.2597, -0.3892, 0.2120, 0.1376, 0.6919],
[ 0.0449, -0.3545, -0.1914, 0.1969, -2.0053]])
tensor([1, 1, 2, 0, 2])
tensor([3, 0, 4])
As we can see, the minimum values along the columns (dim=0) and their row indices are:
- Min value – -1.2597. Its index is 1.
- Min value – -0.3892. Its index is 1.
- Min value – -0.1914. Its index is 2.
- Min value – -0.4714. Its index is 0.
- Min value – -2.0053. Its index is 2.
Similarly, the minimum values along the rows (dim=1) and their column indices are:
- Min value – -0.4714. Its index is 3.
- Min value – -1.2597. Its index is 0.
- Min value – -2.0053. Its index is 4.
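argmin() only returns positions. If you also need the minimum values themselves, torch.min() with a dim argument returns both the values and the indices. The sketch below uses a small hypothetical tensor without ties to show that the indices from torch.min() match argmin(). It also shows the documented tie rule: when several elements share the minimum, the index of the first one is returned.

```python
import torch

# Hypothetical data without ties in any row
t = torch.tensor([[3.0, -1.0, 2.0],
                  [0.5, 4.0, -2.0]])

# torch.min with dim returns the minimum values and their indices together
values, indices = torch.min(t, dim=1)

# argmin returns only the indices part of that result
idx = torch.argmin(t, dim=1)
print(values, indices, idx)

# Ties: per the PyTorch docs, the index of the first minimal value is returned
tie_idx = torch.argmin(torch.tensor([2.0, 0.0, 0.0]))
print(tie_idx)
```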
Example 2:
In this example, we will create a two-dimensional tensor with five rows and five columns and apply argmin() on the rows and columns.
import torch
#create a tensor with 2 dimensions (5 * 5)
#with random elements using randn() function
data = torch.randn(5,5)
#display
print(data)
#get minimum index along columns with argmin
print(torch.argmin(data, dim=0))
#get minimum index along rows with argmin
print(torch.argmin(data, dim=1))
Output:
[ 0.2564, -0.3471, 1.5256, -1.1608, 0.4367],
[ 1.4390, -0.5474, 0.5909, 0.0491, 0.4655],
[-0.7006, -0.0367, -0.9577, -0.0834, -0.7249],
[-1.9151, 2.3360, 1.1214, 0.4452, -1.1233]])
tensor([4, 0, 3, 1, 4])
tensor([0, 3, 1, 2, 0])
As we can see, the minimum values along the columns (dim=0) and their row indices are:
- Min value – -1.9151. Its index is 4.
- Min value – -0.7426. Its index is 0.
- Min value – -0.9577. Its index is 3.
- Min value – -1.1608. Its index is 1.
- Min value – -1.1233. Its index is 4.
Similarly, the minimum values along the rows (dim=1) and their column indices are:
- Min value – -1.7387. Its index is 0.
- Min value – -1.1608. Its index is 3.
- Min value – -0.5474. Its index is 1.
- Min value – -0.9577. Its index is 2.
- Min value – -1.9151. Its index is 0.
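The lists above pair each index with the value it points to. That lookup can also be done in code: with keepdim=True, the index tensor has the right shape to pass to Tensor.gather(). This is a sketch under one assumption: torch.manual_seed(0) is added only to make the random tensor repeatable and is not part of the original example.

```python
import torch

torch.manual_seed(0)   # assumption: fixed seed only so the run is repeatable
data = torch.randn(5, 5)

# Row-wise: keepdim=True gives a (5, 1) index tensor that gather() accepts along dim 1
row_idx = torch.argmin(data, dim=1, keepdim=True)
row_min = data.gather(1, row_idx).squeeze(1)

# Column-wise: the same idea along dim 0, giving a (1, 5) index tensor
col_idx = torch.argmin(data, dim=0, keepdim=True)
col_min = data.gather(0, col_idx).squeeze(0)

print(row_idx.squeeze(1), row_min)
print(col_idx.squeeze(0), col_min)
```

The gathered values are the same as the values returned by data.min(dim=1) and data.min(dim=0), so this mainly helps when you already have the indices from argmin().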
Work with CPU
Tensors are created on the CPU by default, so argmin() already runs on the CPU in the examples above. To make this explicit, or to bring a tensor that lives on another device (such as a GPU) back to the CPU, we can call the cpu() method on the tensor when we create it.
Syntax:
torch.tensor(data).cpu()
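A quick way to confirm where a tensor lives is its device attribute. The sketch below, using a small hypothetical tensor, checks that a newly created tensor is on the CPU and that calling cpu() on a tensor already in CPU memory does not copy the data. Per the PyTorch docs, the original object is returned in that case, so both names point at the same storage.

```python
import torch

data = torch.tensor([[1.0, -2.0],
                     [3.0, 0.5]])
print(data.device)   # tensors are created on the CPU by default

# cpu() on a tensor that is already on the CPU returns it without copying
same = data.cpu()
print(same.data_ptr() == data.data_ptr())

# argmin() on a CPU tensor produces a CPU result
idx = torch.argmin(data, dim=1)
print(idx, idx.device)
```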
Example 1:
In this example, we will create a tensor with two dimensions on the CPU that has three rows and five columns and apply argmin() on the rows and columns.
import torch
#create a tensor with 2 dimensions (3 * 5)
#with random elements using randn() with cpu() function
data = torch.randn(3,5).cpu()
#display
print(data)
#get minimum index along columns with argmin
print(torch.argmin(data, dim=0))
#get minimum index along rows with argmin
print(torch.argmin(data, dim=1))
Output:
[-1.2597, -0.3892, 0.2120, 0.1376, 0.6919],
[ 0.0449, -0.3545, -0.1914, 0.1969, -2.0053]])
tensor([1, 1, 2, 0, 2])
tensor([3, 0, 4])
As we can see, the minimum values along the columns (dim=0) and their row indices are:
- Min value – -1.2597. Its index is 1.
- Min value – -0.3892. Its index is 1.
- Min value – -0.1914. Its index is 2.
- Min value – -0.4714. Its index is 0.
- Min value – -2.0053. Its index is 2.
Similarly, the minimum values along the rows (dim=1) and their column indices are:
- Min value – -0.4714. Its index is 3.
- Min value – -1.2597. Its index is 0.
- Min value – -2.0053. Its index is 4.
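Because randn() draws random values, your printed tensor and indices will usually differ from the output shown here. If you need the same numbers on every run, you can seed PyTorch's global random generator with torch.manual_seed() before creating the tensor. In this sketch, run() is a hypothetical helper and the seed 42 is an arbitrary choice.

```python
import torch

def run(seed):
    # Hypothetical helper: re-seed, rebuild the 3 x 5 CPU tensor, apply argmin()
    torch.manual_seed(seed)
    data = torch.randn(3, 5).cpu()
    return data, torch.argmin(data, dim=0), torch.argmin(data, dim=1)

first = run(42)
second = run(42)

# With the same seed, the tensor and both index results match exactly
print(all(torch.equal(x, y) for x, y in zip(first, second)))   # True
```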
Example 2:
In this example, we will create a two-dimensional tensor with five rows and five columns on the CPU and apply argmin() on the rows and columns.
import torch
#create a tensor with 2 dimensions (5 * 5)
#with random elements using randn() with cpu() function
data = torch.randn(5,5).cpu()
#display
print(data)
#get minimum index along columns with argmin
print(torch.argmin(data, dim=0))
#get minimum index along rows with argmin
print(torch.argmin(data, dim=1))
Output:
[ 0.2564, -0.3471, 1.5256, -1.1608, 0.4367],
[ 1.4390, -0.5474, 0.5909, 0.0491, 0.4655],
[-0.7006, -0.0367, -0.9577, -0.0834, -0.7249],
[-1.9151, 2.3360, 1.1214, 0.4452, -1.1233]])
tensor([4, 0, 3, 1, 4])
tensor([0, 3, 1, 2, 0])
As we can see, the minimum values along the columns (dim=0) and their row indices are:
- Min value – -1.9151. Its index is 4.
- Min value – -0.7426. Its index is 0.
- Min value – -0.9577. Its index is 3.
- Min value – -1.1608. Its index is 1.
- Min value – -1.1233. Its index is 4.
Similarly, the minimum values along the rows (dim=1) and their column indices are:
- Min value – -1.7387. Its index is 0.
- Min value – -1.1608. Its index is 3.
- Min value – -0.5474. Its index is 1.
- Min value – -0.9577. Its index is 2.
- Min value – -1.9151. Its index is 0.
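In practice, cpu() is most often used at the end of a computation rather than at creation time. You compute argmin() on whatever device is available, then move the small index tensor back to the CPU for printing or further host-side work. This is a minimal device-agnostic sketch, assuming nothing beyond the standard torch.cuda.is_available() check. On a machine without a GPU, it simply runs on the CPU.

```python
import torch

# Use a GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

data = torch.randn(5, 5, device=device)
col_idx = torch.argmin(data, dim=0)   # computed on `device`
row_idx = torch.argmin(data, dim=1)

# cpu() brings the results back to host memory (a no-op if already on the CPU)
col_idx_cpu = col_idx.cpu()
row_idx_cpu = row_idx.cpu()
print(col_idx_cpu, row_idx_cpu)
```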
Conclusion
In this PyTorch lesson, we saw what argmin() is and how to apply it to a tensor to return the indices of the minimum values across columns and rows.
We also created a tensor with the cpu() method and returned the indices of its minimum values. The dim parameter controls the direction of the comparison: dim=0 returns the index of the minimum value in each column, and dim=1 returns the index of the minimum value in each row.
Source: linuxhint.com