Cannot understand how numpy.argsort organizes its indices for a 4-dimensional array

Hello everyone!

I’m a beginner with NumPy and I’m struggling to understand how to correctly interpret the result of numpy.argsort when dealing with a 4D array.

Here is a test case that reproduces my problem:


import numpy as np
a = np.array([[[[ -28,   83,   17,  132,   37],
         [ -65,  180,  132,   88,  -71],
         [  48,  128, -101,   50,  103],
         [ 125,   24, -191,   -9,  -53]],

        [[  69, -171,  158,  -88, -166],
         [  12,   68,  193,    6,  180],
         [  85,  -93,  -81,  -97,  -99],
         [ -87, -137,   68,  116, -121]]],


       [[[ 147, -141,  160, -156,   27],
         [ 133,  144,  119,  129,  -28],
         [ 150,  -50,  180, -176,  -50],
         [   4,   53,  154,   42,   77]],

        [[  78,    5,  177,  -53, -177],
         [-117,  -74,  -89,   29,  158],
         [-179, -165,  143,   42,  -89],
         [  11, -195, -151,  174,   71]]],


       [[[  56,  152,  -12,  170, -155],
         [-127,  163,  176,    6,  165],
         [  50,   15,  -28,   16, -150],
         [ 117,  162,  187,  -97, -131]],

        [[-156,  135,   37,  -11,   80],
         [  72,   63, -126,  -75,  111],
         [ -59,  174,  -58,   44, -193],
         [-166,  170,  -84, -149,  123]]]])

I checked the results of np.sort(a, axis=0) and np.argsort(a, axis=0) and put them side by side in the following table, so that I could compare them and better understand how each of them works:


-------------------------------------+-----------------------+-------------------------------------
                  a                  | np.argsort(a, axis=0) |        np.sort(a, axis=0)
-------------------------------------+-----------------------+-------------------------------------
[[[[ -28,   83,   17,  132,   37],   | [[[[0, 1, 2, 1, 2],   | [[[[ -28, -141,  -12, -156, -155],   
   [ -65,  180,  132,   88,  -71],   |    [2, 1, 1, 2, 0],   |    [-127,  144,  119,    6,  -71],   
   [  48,  128, -101,   50,  103],   |    [0, 1, 0, 1, 2],   |    [  48,  -50, -101, -176, -150],   
   [ 125,   24, -191,   -9,  -53]],  |    [1, 0, 0, 2, 2]],  |    [   4,   24, -191,  -97, -131]],  
                                     |                       |                                      
  [[  69, -171,  158,  -88, -166],   |   [[2, 0, 2, 0, 1],   |   [[-156, -171,   37,  -88, -177],   
   [  12,   68,  193,    6,  180],   |    [1, 1, 2, 2, 2],   |    [-117,  -74, -126,  -75,  111],   
   [  85,  -93,  -81,  -97,  -99],   |    [1, 1, 0, 0, 2],   |    [-179, -165,  -81,  -97, -193],   
   [ -87, -137,   68,  116, -121]]], |    [2, 1, 1, 2, 0]]], |    [-166, -195, -151, -149, -121]]], 
                                     |                       |                                      
                                     |                       |                                      
 [[[ 147, -141,  160, -156,   27],   |  [[[2, 0, 0, 0, 1],   |  [[[  56,   83,   17,  132,   27],   
   [ 133,  144,  119,  129,  -28],   |    [0, 2, 0, 0, 1],   |    [ -65,  163,  132,   88,  -28],   
   [ 150,  -50,  180, -176,  -50],   |    [2, 2, 2, 2, 1],   |    [  50,   15,  -28,   16,  -50],   
   [   4,   53,  154,   42,   77]],  |    [2, 1, 1, 0, 0]],  |    [ 117,   53,  154,   -9,  -53]],  
                                     |                       |                                      
  [[  78,    5,  177,  -53, -177],   |   [[0, 1, 0, 1, 0],   |   [[  69,    5,  158,  -53, -166],   
   [-117,  -74,  -89,   29,  158],   |    [0, 2, 1, 0, 1],   |    [  12,   63,  -89,    6,  158],   
   [-179, -165,  143,   42,  -89],   |    [2, 0, 2, 1, 0],   |    [ -59,  -93,  -58,   42,  -99],   
   [  11, -195, -151,  174,   71]]], |    [0, 0, 2, 0, 1]]], |    [ -87, -137,  -84,  116,   71]]], 
                                     |                       |                                      
                                     |                       |                                      
 [[[  56,  152,  -12,  170, -155],   |  [[[1, 2, 1, 2, 0],   |  [[[ 147,  152,  160,  170,   37],   
   [-127,  163,  176,    6,  165],   |    [1, 0, 2, 1, 2],   |    [ 133,  180,  176,  129,  165],   
   [  50,   15,  -28,   16, -150],   |    [1, 0, 1, 0, 0],   |    [ 150,  128,  180,   50,  103],   
   [ 117,  162,  187,  -97, -131]],  |    [0, 2, 2, 1, 1]],  |    [ 125,  162,  187,   42,   77]],  
                                     |                       |                                      
  [[-156,  135,   37,  -11,   80],   |   [[1, 2, 1, 2, 2],   |   [[  78,  135,  177,  -11,   80],   
   [  72,   63, -126,  -75,  111],   |    [2, 0, 0, 1, 0],   |    [  72,   68,  193,   29,  180],   
   [ -59,  174,  -58,   44, -193],   |    [0, 2, 1, 2, 1],   |    [  85,  174,  143,   44,  -89],   
   [-166,  170,  -84, -149,  123]]]] |    [1, 2, 0, 1, 2]]]] |    [  11,  170,   68,  174,  123]]]]

Unfortunately, I have not been able to find the mapping between the indices returned by np.argsort and the original array. In other words, when I look at the indices in the 2nd column of the table above, I don’t see how using any of them as an index into the original array 'a' can produce the sorted result in the 3rd column.

I know that the table is long, so even an explanation of just the first portion, at the top of the table, would help:


-------------------------------------+-----------------------+-------------------------------------
                  a                  | np.argsort(a, axis=0) |        np.sort(a, axis=0)
-------------------------------------+-----------------------+-------------------------------------
[[[[ -28,   83,   17,  132,   37],   | [[[[0, 1, 2, 1, 2],   | [[[[ -28, -141,  -12, -156, -155],   
   [ -65,  180,  132,   88,  -71],   |    [2, 1, 1, 2, 0],   |    [-127,  144,  119,    6,  -71],   
   [  48,  128, -101,   50,  103],   |    [0, 1, 0, 1, 2],   |    [  48,  -50, -101, -176, -150],   
   [ 125,   24, -191,   -9,  -53]],  |    [1, 0, 0, 2, 2]],  |    [   4,   24, -191,  -97, -131]],  
                  .                              .                               .
                  .                              .                               .
                  .                              .                               .

Thanks in advance

Based on this Stack Overflow discussion (python - How to use numpy.argsort() as indices in more than 2 dimensions?), maybe you are looking for numpy.take_along_axis (see numpy.take_along_axis in the NumPy v2.2 manual)?

import numpy as np

if __name__ == "__main__":
    a = np.array(
        [
            [
                [
                    [-28, 83, 17, 132, 37],
                    [-65, 180, 132, 88, -71],
                    [48, 128, -101, 50, 103],
                    [125, 24, -191, -9, -53],
                ],
                [
                    [69, -171, 158, -88, -166],
                    [12, 68, 193, 6, 180],
                    [85, -93, -81, -97, -99],
                    [-87, -137, 68, 116, -121],
                ],
            ],
            [
                [
                    [147, -141, 160, -156, 27],
                    [133, 144, 119, 129, -28],
                    [150, -50, 180, -176, -50],
                    [4, 53, 154, 42, 77],
                ],
                [
                    [78, 5, 177, -53, -177],
                    [-117, -74, -89, 29, 158],
                    [-179, -165, 143, 42, -89],
                    [11, -195, -151, 174, 71],
                ],
            ],
            [
                [
                    [56, 152, -12, 170, -155],
                    [-127, 163, 176, 6, 165],
                    [50, 15, -28, 16, -150],
                    [117, 162, 187, -97, -131],
                ],
                [
                    [-156, 135, 37, -11, 80],
                    [72, 63, -126, -75, 111],
                    [-59, 174, -58, 44, -193],
                    [-166, 170, -84, -149, 123],
                ],
            ],
        ]
    )

    idx = np.argsort(a, axis=0)
    print(np.all(np.take_along_axis(a, idx, axis=0) == np.sort(a, axis=0)))

As to how numpy.take_along_axis works internally, I am not sure 😅, but at least this shows how to get the same array as with numpy.sort (the snippet prints True).
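For axis=0 specifically, what `take_along_axis` produces can be written with plain fancy indexing: every output element `[i, j, k, l]` is `a[idx[i, j, k, l], j, k, l]`. The sketch below builds index arrays for the three non-sorting axes with `np.ogrid` and lets broadcasting fill in the rest. It is an equivalent formulation for illustration, not necessarily how NumPy implements the function internally, and the random array is an assumption standing in for the `a` from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(-200, 200, size=(3, 2, 4, 5))  # same shape as `a` in the thread

idx = np.argsort(a, axis=0)

# Open grids of shapes (2, 1, 1), (1, 4, 1) and (1, 1, 5) for the axes we are NOT sorting.
j_idx, k_idx, l_idx = np.ogrid[:a.shape[1], :a.shape[2], :a.shape[3]]

# idx has shape (3, 2, 4, 5); the grids broadcast against it, so
# manual[i, j, k, l] == a[idx[i, j, k, l], j, k, l]
manual = a[idx, j_idx, k_idx, l_idx]

print(np.array_equal(manual, np.take_along_axis(a, idx, axis=0)))  # True
print(np.array_equal(manual, np.sort(a, axis=0)))                   # True
```

In other words, the argsort indices only replace the coordinate on the sorting axis; every other coordinate is simply the element's own position.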


Thank you very much for your help. After a lot of Googling and some discussions with ChatGPT, my understanding is that one has to look at this organization of elements like a coordinate system with N dimensions (rather than the 2 dimensions of a Cartesian plane), where a point (= an element of the array) is given by its indices (x_0, x_1, x_2, …, x_i, …, x_{N-1}), with i in {0, 1, …, N-1}. In that case, sorting along a specific axis means varying only the index along that axis, and the elements in corresponding positions (= the elements that are compared with each other during the sort):

  1. Have exactly the same index on each of the non-sorting axes (this does not mean that the indices on different non-sorting axes are equal to each other; what matters is that, axis by axis, the compared elements share the same index. For example, with axis=0, a[0, 1, 2, 3], a[1, 1, 2, 3] and a[2, 1, 2, 3] are compared with each other.)

  2. Differ only in their index on the sorting axis.

By this logic, it makes perfect sense. If we read the argsort indices along the sorting axis (for axis=0 in my table, the same position in each of the three outer blocks, from the top block to the bottom one), the first index is the sorting-axis index of the smallest element (again, very important to note: smallest only among the elements in corresponding positions), the second one is the index of the second smallest, and so on.
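To make this concrete for the excerpt in the original question (axis=0): the elements compared with the first row of the first block are the first rows of the other two outer blocks, so the excerpt only shows one third of each comparison. A minimal sketch, keeping only those three rows copied from `a` (the name `rows` is illustrative):

```python
import numpy as np

# a[:, 0, 0, :]: the first row of the first sub-block in each of the three
# outer blocks of `a`. Along axis=0 these are the values compared with each other.
rows = np.array([
    [-28, 83, 17, 132, 37],      # a[0, 0, 0, :]
    [147, -141, 160, -156, 27],  # a[1, 0, 0, :]
    [56, 152, -12, 170, -155],   # a[2, 0, 0, :]
])

idx = np.argsort(rows, axis=0)
print(idx)
# [[0 1 2 1 2]
#  [2 0 0 0 1]
#  [1 2 1 2 0]]

# Column 0 read top to bottom: -28 (row 0) < 56 (row 2) < 147 (row 1) -> 0, 2, 1.
col = 0
print(rows[idx[:, col], col])  # [-28  56 147]

# Doing this for every column reproduces np.sort(rows, axis=0).
sorted_rows = np.take_along_axis(rows, idx, axis=0)
print(sorted_rows)
# [[ -28 -141  -12 -156 -155]
#  [  56   83   17  132   27]
#  [ 147  152  160  170   37]]
```

The three rows of `idx` are exactly the first row of each of the three outer blocks in the argsort column of the table, and the three rows of `sorted_rows` match the first row of each block in the np.sort column.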

Just as an example, looking at the array in my first post and assuming that we sort along axis=1, we can say that -28 and 69 are in corresponding positions (= they will be compared with each other during the sort):

>>> a[0, :, 0, 0]    
array([-28,  69])
>>>
>>>
>>> # We can also look at them separately
>>> a[0, 0, 0, 0]
np.int64(-28)
>>>
>>>
>>> a[0, 1, 0, 0]
np.int64(69)
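The same idea can be checked for every lane at once. A sketch that rebuilds `a` from the question and confirms that each 1-D lane of `np.argsort(a, axis=1)` is just the argsort of the slice `a[i, :, k, l]`. It passes `kind="stable"` to both calls so that ties, if any, are broken the same way; that choice is an assumption added here, not something from the thread.

```python
import numpy as np

a = np.array([
    [[[-28, 83, 17, 132, 37], [-65, 180, 132, 88, -71],
      [48, 128, -101, 50, 103], [125, 24, -191, -9, -53]],
     [[69, -171, 158, -88, -166], [12, 68, 193, 6, 180],
      [85, -93, -81, -97, -99], [-87, -137, 68, 116, -121]]],
    [[[147, -141, 160, -156, 27], [133, 144, 119, 129, -28],
      [150, -50, 180, -176, -50], [4, 53, 154, 42, 77]],
     [[78, 5, 177, -53, -177], [-117, -74, -89, 29, 158],
      [-179, -165, 143, 42, -89], [11, -195, -151, 174, 71]]],
    [[[56, 152, -12, 170, -155], [-127, 163, 176, 6, 165],
      [50, 15, -28, 16, -150], [117, 162, 187, -97, -131]],
     [[-156, 135, 37, -11, 80], [72, 63, -126, -75, 111],
      [-59, 174, -58, 44, -193], [-166, 170, -84, -149, 123]]],
])

idx1 = np.argsort(a, axis=1, kind="stable")

# The lane from the example above: a[0, :, 0, 0] == [-28, 69].
print(a[0, :, 0, 0], idx1[0, :, 0, 0])  # [-28  69] [0 1]

# Each lane along axis 1 is selected by fixing the indices on axes 0, 2 and 3.
for i, k, m in np.ndindex(a.shape[0], a.shape[2], a.shape[3]):
    lane = a[i, :, k, m]
    assert np.array_equal(idx1[i, :, k, m], np.argsort(lane, kind="stable"))

print("every axis-1 lane matches")
```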

Now, how exactly take_along_axis figures out the right indices from the original shape, I have no idea! I tried to look at NumPy’s source code on GitHub, but it seems to me far from being a simple small function, and a lot of imported modules are involved. So far, what I have understood is that unless the array is 1-D (in which case a[np.argsort(a)] works directly), one has to use take_along_axis to correctly rebuild the sorted array with its original structure. Thanks again for your help.
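The 1-D vs N-D difference can be seen directly: in 1-D, `x[np.argsort(x)]` already works, but in 2-D `m[idx]` treats every entry of `idx` as a row number and returns whole rows, so the result does not even have the right shape. Below is a minimal generic sketch of the gather for any axis; the helper name `gather_along_axis` is hypothetical, and it assumes `indices` has the same shape as the array, which holds for argsort output.

```python
import numpy as np

x = np.array([3, -1, 2])
print(x[np.argsort(x)])  # [-1  2  3]  (1-D: plain indexing is enough)

m = np.array([[5, -2, 7],
              [1, 4, -3]])
idx = np.argsort(m, axis=0)
print(m[idx].shape)  # (2, 3, 3): each index picks a whole row, not one element


def gather_along_axis(arr, indices, axis):
    """Hypothetical helper: out[p] = arr[p with coordinate `axis` replaced by indices[p]]."""
    # One open grid per axis, e.g. shapes (2, 1) and (1, 3) for a 2-D array.
    grids = list(np.indices(indices.shape, sparse=True))
    grids[axis] = indices  # argsort output replaces only the sorting-axis coordinate
    return arr[tuple(grids)]


for ax in range(m.ndim):
    order = np.argsort(m, axis=ax)
    assert np.array_equal(gather_along_axis(m, order, ax), np.sort(m, axis=ax))
print("matches np.sort on every axis")
```

The helper gives the same result as np.take_along_axis for these inputs, which is why the extra function is only needed once the array has more than one dimension.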