import numpy as np
# Demonstrate adding and removing length-1 axes on a NumPy array.
#
# `arr` must be an ndarray: a plain nested list has no `.shape` attribute.
# The axis passed to `np.squeeze` must have length 1, otherwise NumPy raises
# ValueError -- hence a (3, 1) column array, whose axis 1 has length 1.
arr = np.array([[1], [2], [3]])             # shape (3, 1)

# np.expand_dims inserts a new length-1 axis at position `axis`.
expansion = np.expand_dims(arr, axis=0)     # shape (1, 3, 1)

# np.squeeze(..., axis=1) removes the length-1 column axis -> 1-D array.
reduction = np.squeeze(arr, axis=1)         # shape (3,)

print(arr.shape, expansion.shape, reduction.shape)
# Output:
# (3, 1) (1, 3, 1) (3,)

print(arr, expansion, reduction, sep='\n\n')
# Output:
# [[1]
#  [2]
#  [3]]
#
# [[[1]
#   [2]
#   [3]]]
#
# [1 2 3]