I am trying to plot subplots for a composite analysis. I have over 100 case studies. Since I am new in writing python coding, can someone rewrite my code to simplify it as I am having trouble with a for loop of axs. Any help would be grateful.
import xarray as xr
import numpy as np
from datetime import datetime
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.patches as mpatches
import pandas as pd
from geospatial_utils import area_of_interest
ds = xr.open_dataset('February_1981.nc')
dp = xr.open_dataset('Geopot_1982_2008.nc')
ds.merge(dp)
print(ds)
proj = ccrs.LambertConformal(central_longitude=-110.0, central_latitude=35.0)
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection=proj)
ax.stock_img()
ax.add_patch(mpatches.Rectangle(xy=[-170, 20], width=120, height=45, facecolor='red', alpha=0.2, transform=ccrs.PlateCarree()))
ax.set_extent([-180, -30, 10, 70], crs=ccrs.PlateCarree())
ax.gridlines()
ax.coastlines()
plt.show()
time = ds['valid_time'].values # numpy datetime64 object
lvl = ds['pressure_level'].values # units of hPa
lat = ds['latitude'].values # units of degrees (postitive is northern hemisphere)
lon = ds['longitude'].values # units of degrees (negative is western hemisphere)
time_ind = pd.date_range(start = '1981-02-09T00:00.00', end = '1981-02-10T00:00.00', freq = '1h')
time_ind_2 = pd.date_range(start = '1981-02-10T00:00.00', end = '1981-02-11T00:00.00', freq = '1h')
time_ind_3 = pd.date_range(start = '1981-02-11T00:00.00', end = '1981-02-11T18:00.00', freq = '1h')
time_ind_4 = pd.date_range(start = '1982-04-18T00:00.00', end = '1982-04-19T00:00.00', freq = '1h')
time_ind_5 = pd.date_range(start = '1982-04-19T00:00.00', end = '1982-04-20T00:00.00', freq = '1h')
time_ind_6 = pd.date_range(start = '1982-04-20T00:00.00', end = '1982-04-21T18:00.00', freq = '1h')
geopot = ds.sel(valid_time=time_ind)['z'].mean(dim='valid_time').values # units of m**2 / s**2
geopot_2 = ds.sel(valid_time=time_ind_2)['z'].mean(dim='valid_time').values # units of m**2 / s**2
geopot_3 = ds.sel(valid_time=time_ind_3)['z'].mean(dim='valid_time').values # units of m**2 / s**2
geopot_4 = ds.sel(valid_time=time_ind_4)['z'].mean(dim='valid_time').values
geopot_5 = ds.sel(valid_time=time_ind_5)['z'].mean(dim='valid_time').values
geopot_6 = ds.sel(valid_time=time_ind_6)['z'].mean(dim='valid_time').values
# Compute the geopotential height [m] from the geopotential variable
geopot_hght = geopot / 9.81
geopot_hght_2 = geopot_2 / 9.81
geopot_hght_3 = geopot_3 / 9.81
geopot_hght_4 = geopot_4 / 9.81
geopot_hght_5 = geopot_5 / 9.81
geopot_hght_6 = geopot_6 / 9.81
print(time.shape, lvl.shape, lat.shape, lon.shape, geopot_hght.shape)
# Obtain the time index that corresponds to xx UTC of a specific date
#time_ind = np.where(time==np.datetime64('1981-02-10T18:00:00.00'))[0][0]
# Obtain the pressure level index that corresponds to 500 hPa
lvl_ind = np.where(lvl==500.)[0][0]
geopot_hght_trimmed = geopot_hght[lvl_ind, :, :]
geopot_hght_trimmed_2 = geopot_hght_2[lvl_ind, :, :]
geopot_hght_trimmed_3 = geopot_hght_3[lvl_ind, :, :]
geopot_hght_trimmed_4 = geopot_hght_4[lvl_ind, :, :]
geopot_hght_trimmed_5 = geopot_hght_5[lvl_ind, :, :]
geopot_hght_trimmed_6 = geopot_hght_6[lvl_ind, :, :]
[X, Y] = np.meshgrid(lon, lat)
print(X.shape, Y.shape)
proj = ccrs.Mercator() # You can use other projections, just make sure that PlateCarree is used in any transform arguments below.
fig, axs = plt.subplots(3,1, figsize=(16, 12), subplot_kw = {'projection': ccrs.AlbersEqualArea(central_longitude=-154, central_latitude=50, standard_parallels=(55, 65))})
for i, ax in enumerate(axs.flat):
plt = ax[i].contourf(X, Y, geopot_hght_trimmed/10., levels=np.linspace(480, 600, 41), cmap='nipy_spectral', transform=ccrs.PlateCarree())
plt1 = ax[i+1].contourf(X, Y, geopot_hght_trimmed/10., levels=np.linspace(480, 600, 41), cmap='nipy_spectral', transform=ccrs.PlateCarree())
plt2 = ax[i+2].contourf(X, Y, geopot_hght_trimmed/10., levels=np.linspace(480, 600, 41), cmap='nipy_spectral', transform=ccrs.PlateCarree())
ax[i].set_boundary(area_of_interest(axs[i], west = -170, east = -120, north = 75, south = 45))
ax[i+1].set_boundary(area_of_interest(axs[i+1], west = -170, east = -120, north = 75, south = 45))
ax[i+2].set_boundary(area_of_interest(axs[i+2], west = -170, east = -120, north = 75, south = 45))
ax[i].clabel(plt, np.linspace(480, 600, 11), inline=True, fmt='%d', fontsize=14)
axs[i+1].clabel(plt1, np.linspace(480, 600, 11), inline=True, fmt='%d', fontsize=14)
axs[i+2].clabel(plt2, np.linspace(480, 600, 11), inline=True, fmt='%d', fontsize=14)
axs[i].set_title('500-hPa Composite Geopotential Height [dm]\n 9 February 1981', fontsize=16)
axs[i+1].set_title('500-hPa Composite Geopotential Height [dm]\n 10 February 1981', fontsize=16)
axs[i+2].set_title('500-hPa Composite Geopotential Height [dm]\n 10 February 1981', fontsize=16)
states = cfeature.NaturalEarthFeature(category='cultural', name='admin_1_states_provinces_lines', scale='50m', facecolor='none')
axs[i].add_feature(cfeature.LAND)
axs[i+1].add_feature(cfeature.LAND)
axs[i+2].add_feature(cfeature.LAND)
axs[i].add_feature(cfeature.LAKES, facecolor='none', edgecolor='black')
axs[i+1].add_feature(cfeature.LAKES, facecolor='none', edgecolor='black')
axs[i+2].add_feature(cfeature.LAKES, facecolor='none', edgecolor='black')
axs[i].add_feature(cfeature.COASTLINE)
axs[i+1].add_feature(cfeature.COASTLINE)
axs[i+2].add_feature(cfeature.COASTLINE)
axs[i].add_feature(cfeature.BORDERS)
axs[i+1].add_feature(cfeature.BORDERS)
axs[i+2].add_feature(cfeature.BORDERS)
axs[i].add_feature(states, edgecolor='black')
axs[i+1].add_feature(states, edgecolor='black')
axs[i+2].add_feature(states, edgecolor='black')
# Format the gridlines (optional)
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, linewidth=2, color='gray', alpha=0.5, linestyle='--')
gl.xlabels_top = False
gl.ylabels_right = False
gl.xlines = True; gl.xlocator = mticker.FixedLocator(np.linspace(-170, -50, 13)); gl.xformatter = LONGITUDE_FORMATTER; gl.xlabel_style = {'size':16, 'color':'black'}
gl.ylines = True; gl.ylocator = mticker.FixedLocator(np.linspace(20, 65, 10)); gl.yformatter = LATITUDE_FORMATTER; gl.ylabel_style = {'size':16, 'color':'black'}
# Plot the colorbar
cbar_ax = fig.add_axes([0, 0, 0.1, 0.1]) # Dummy values prior to finetuning the cbar position
pos = ax.get_position() # Get the axes position
cbar_ax.set_position([pos.x0 + pos.width + 0.01, pos.y0, 0.04, pos.height])
cbar = plt.colorbar(plt1, cax=cbar_ax)
cbar_ax.tick_params(labelsize=16)
ax.set_title(f'500-hPa Geopotential Height {i+1}')
plt.tight_layout()
plt.show()