NumPy - expand_dims() function
The NumPy expand_dims() function expands the shape of an array. It inserts a new axis of length one at the position given by axis in the expanded array shape.
Syntax
numpy.expand_dims(a, axis)
Parameters
a | Required. Specify the input array (array_like).
axis | Required. Specify the position in the expanded axes where the new axis (or axes) is placed. It can be an int or a tuple of ints; a tuple is supported in NumPy 1.18 and later.
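The short sketch below illustrates how the axis parameter is interpreted. The variable names are only for illustration. It assumes NumPy 1.18 or later for the tuple form, and it shows that indexing with np.newaxis is a common alternative that produces the same shape.

```python
import numpy as np

x = np.array([1, 2, 3])

# A negative axis counts from the end of the expanded shape.
a = np.expand_dims(x, axis=-1)       # shape (3, 1)

# Indexing with np.newaxis gives the same result as axis=-1 here.
b = x[:, np.newaxis]                 # shape (3, 1)

# A tuple inserts several axes at once (NumPy 1.18 or later).
c = np.expand_dims(x, axis=(0, 2))   # shape (1, 3, 1)

print(a.shape, b.shape, c.shape)
print(np.array_equal(a, b))
```

Each position in axis refers to the shape of the result, not of the input, which is why axis=(0, 2) is valid for a one-dimensional input.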
Return Value
Returns a view of a with the number of dimensions increased.
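Because the result is a view, it shares its data with the input when the input is already a NumPy array. The following minimal sketch, using illustrative variable names, checks this with np.shares_memory() and shows that a change made through the view is visible in the original array.

```python
import numpy as np

x = np.array([1, 2, 3])
y = np.expand_dims(x, axis=0)   # shape (1, 3)

# The expanded array does not own a separate copy of the data.
print(np.shares_memory(x, y))

# Writing through the view also changes the original array.
y[0, 1] = 99
print(x.tolist())
```

If the input is a list or another array_like object rather than an array, NumPy first converts it to an array, so the original object is not modified. Use y.copy() when an independent array is needed.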
Example:
In the example below, an array is expanded on a given axis.
import numpy as np

x = np.array([1, 2, 3])

#expanding the dimension of x on axis=0
x1 = np.expand_dims(x, axis=0)

#expanding the dimension of x on axis=1
x2 = np.expand_dims(x, axis=1)

#expanding the dimension of x on axis=(0,1)
x3 = np.expand_dims(x, axis=(0,1))

#displaying results
print("shape of x:", x.shape)
print("x contains:")
print(x)
print("\nshape of x1:", x1.shape)
print("x1 contains:")
print(x1)
print("\nshape of x2:", x2.shape)
print("x2 contains:")
print(x2)
print("\nshape of x3:", x3.shape)
print("x3 contains:")
print(x3)
The output of the above code will be:
shape of x: (3,)
x contains:
[1 2 3]

shape of x1: (1, 3)
x1 contains:
[[1 2 3]]

shape of x2: (3, 1)
x2 contains:
[[1]
 [2]
 [3]]

shape of x3: (1, 1, 3)
x3 contains:
[[[1 2 3]]]