Geopotential Heights Composites

I am trying to plot subplots for a composite analysis. I have over 100 case studies. Since I am new in writing python coding, can someone rewrite my code to simplify it as I am having trouble with a for loop of axs. Any help would be grateful.

 import xarray as xr
import numpy as np
from datetime import datetime
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.patches as mpatches
import pandas as pd
from geospatial_utils import area_of_interest

ds = xr.open_dataset('February_1981.nc')
dp = xr.open_dataset('Geopot_1982_2008.nc')
ds.merge(dp)
print(ds)

proj = ccrs.LambertConformal(central_longitude=-110.0, central_latitude=35.0)

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection=proj)

ax.stock_img()
ax.add_patch(mpatches.Rectangle(xy=[-170, 20], width=120, height=45, facecolor='red', alpha=0.2, transform=ccrs.PlateCarree()))
ax.set_extent([-180, -30, 10, 70], crs=ccrs.PlateCarree())
ax.gridlines()
ax.coastlines()

plt.show()

time = ds['valid_time'].values # numpy datetime64 object
lvl = ds['pressure_level'].values # units of hPa
lat = ds['latitude'].values # units of degrees (postitive is northern hemisphere)
lon = ds['longitude'].values # units of degrees (negative is western hemisphere)
time_ind = pd.date_range(start = '1981-02-09T00:00.00', end = '1981-02-10T00:00.00', freq = '1h')
time_ind_2 = pd.date_range(start = '1981-02-10T00:00.00', end = '1981-02-11T00:00.00', freq = '1h')
time_ind_3 = pd.date_range(start = '1981-02-11T00:00.00', end = '1981-02-11T18:00.00', freq = '1h')
time_ind_4 = pd.date_range(start = '1982-04-18T00:00.00', end = '1982-04-19T00:00.00', freq = '1h')
time_ind_5 = pd.date_range(start = '1982-04-19T00:00.00', end = '1982-04-20T00:00.00', freq = '1h')
time_ind_6 = pd.date_range(start = '1982-04-20T00:00.00', end = '1982-04-21T18:00.00', freq = '1h')

geopot = ds.sel(valid_time=time_ind)['z'].mean(dim='valid_time').values # units of m**2 / s**2
geopot_2 = ds.sel(valid_time=time_ind_2)['z'].mean(dim='valid_time').values # units of m**2 / s**2
geopot_3 = ds.sel(valid_time=time_ind_3)['z'].mean(dim='valid_time').values # units of m**2 / s**2
geopot_4 = ds.sel(valid_time=time_ind_4)['z'].mean(dim='valid_time').values
geopot_5 = ds.sel(valid_time=time_ind_5)['z'].mean(dim='valid_time').values
geopot_6 = ds.sel(valid_time=time_ind_6)['z'].mean(dim='valid_time').values

# Compute the geopotential height [m] from the geopotential variable
geopot_hght = geopot / 9.81
geopot_hght_2 = geopot_2 / 9.81
geopot_hght_3 = geopot_3 / 9.81
geopot_hght_4 = geopot_4 / 9.81
geopot_hght_5 = geopot_5 / 9.81
geopot_hght_6 = geopot_6 / 9.81

print(time.shape, lvl.shape, lat.shape, lon.shape, geopot_hght.shape)

# Obtain the time index that corresponds to xx UTC of a specific date
#time_ind = np.where(time==np.datetime64('1981-02-10T18:00:00.00'))[0][0]

# Obtain the pressure level index that corresponds to 500 hPa
lvl_ind = np.where(lvl==500.)[0][0]

geopot_hght_trimmed = geopot_hght[lvl_ind, :, :]
geopot_hght_trimmed_2 = geopot_hght_2[lvl_ind, :, :]
geopot_hght_trimmed_3 = geopot_hght_3[lvl_ind, :, :]
geopot_hght_trimmed_4 = geopot_hght_4[lvl_ind, :, :]
geopot_hght_trimmed_5 = geopot_hght_5[lvl_ind, :, :]
geopot_hght_trimmed_6 = geopot_hght_6[lvl_ind, :, :]

[X, Y] = np.meshgrid(lon, lat)
print(X.shape, Y.shape)

proj = ccrs.Mercator() # You can use other projections, just make sure that PlateCarree is used in any transform arguments below.

fig, axs = plt.subplots(3,1, figsize=(16, 12), subplot_kw = {'projection': ccrs.AlbersEqualArea(central_longitude=-154, central_latitude=50, standard_parallels=(55, 65))})

for i, ax in enumerate(axs.flat):
    

   plt = ax[i].contourf(X, Y, geopot_hght_trimmed/10., levels=np.linspace(480, 600, 41), cmap='nipy_spectral', transform=ccrs.PlateCarree())
   plt1 = ax[i+1].contourf(X, Y, geopot_hght_trimmed/10., levels=np.linspace(480, 600, 41), cmap='nipy_spectral', transform=ccrs.PlateCarree())
   plt2 = ax[i+2].contourf(X, Y, geopot_hght_trimmed/10., levels=np.linspace(480, 600, 41), cmap='nipy_spectral', transform=ccrs.PlateCarree())
   ax[i].set_boundary(area_of_interest(axs[i], west = -170, east = -120, north = 75, south = 45))
   ax[i+1].set_boundary(area_of_interest(axs[i+1], west = -170, east = -120, north = 75, south = 45))
   ax[i+2].set_boundary(area_of_interest(axs[i+2], west = -170, east = -120, north = 75, south = 45))
   ax[i].clabel(plt, np.linspace(480, 600, 11), inline=True, fmt='%d', fontsize=14)
   axs[i+1].clabel(plt1, np.linspace(480, 600, 11), inline=True, fmt='%d', fontsize=14)
   axs[i+2].clabel(plt2, np.linspace(480, 600, 11), inline=True, fmt='%d', fontsize=14)
   axs[i].set_title('500-hPa Composite Geopotential Height [dm]\n 9 February 1981', fontsize=16)
   axs[i+1].set_title('500-hPa Composite Geopotential Height [dm]\n 10 February 1981', fontsize=16)
   axs[i+2].set_title('500-hPa Composite Geopotential Height [dm]\n 10 February 1981', fontsize=16)
   states = cfeature.NaturalEarthFeature(category='cultural', name='admin_1_states_provinces_lines', scale='50m', facecolor='none')
   axs[i].add_feature(cfeature.LAND)
   axs[i+1].add_feature(cfeature.LAND)
   axs[i+2].add_feature(cfeature.LAND)
   axs[i].add_feature(cfeature.LAKES, facecolor='none', edgecolor='black')
   axs[i+1].add_feature(cfeature.LAKES, facecolor='none', edgecolor='black')
   axs[i+2].add_feature(cfeature.LAKES, facecolor='none', edgecolor='black')
   axs[i].add_feature(cfeature.COASTLINE)
   axs[i+1].add_feature(cfeature.COASTLINE)
   axs[i+2].add_feature(cfeature.COASTLINE)
   axs[i].add_feature(cfeature.BORDERS)
   axs[i+1].add_feature(cfeature.BORDERS)
   axs[i+2].add_feature(cfeature.BORDERS)
   axs[i].add_feature(states, edgecolor='black')
   axs[i+1].add_feature(states, edgecolor='black')
   axs[i+2].add_feature(states, edgecolor='black')

   # Format the gridlines (optional)
   gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, linewidth=2, color='gray', alpha=0.5, linestyle='--')
   gl.xlabels_top = False
   gl.ylabels_right = False
   gl.xlines = True; gl.xlocator = mticker.FixedLocator(np.linspace(-170, -50, 13)); gl.xformatter = LONGITUDE_FORMATTER; gl.xlabel_style = {'size':16, 'color':'black'}
   gl.ylines = True; gl.ylocator = mticker.FixedLocator(np.linspace(20, 65, 10)); gl.yformatter = LATITUDE_FORMATTER; gl.ylabel_style = {'size':16, 'color':'black'}

# Plot the colorbar
   cbar_ax = fig.add_axes([0, 0, 0.1, 0.1]) # Dummy values prior to finetuning the cbar position
   pos = ax.get_position() # Get the axes position
   cbar_ax.set_position([pos.x0 + pos.width + 0.01, pos.y0, 0.04, pos.height])
   cbar = plt.colorbar(plt1, cax=cbar_ax)
   cbar_ax.tick_params(labelsize=16)

   
   ax.set_title(f'500-hPa Geopotential Height {i+1}')

plt.tight_layout()
plt.show()