Hello all,
I have a problem with my Python code: it takes forever to run (more than two weeks so far). The script takes a particle-tracking NetCDF file for a single year and a shapefile of coral reef polygons, then builds a “connectivity” transition matrix that counts how many particles travel from each source polygon to each sink polygon each day. It also applies a probabilistic “settlement competency” filter based on a Randall survival curve read from a CSV file.
The bottleneck is looping over more than a million particles at every timestep and checking whether each particle's position falls inside any of the 207 reef polygons defined by my shapefile. Any recommendations that could significantly speed this up would be super helpful. I have included the full code at the end of this post, and also as a .py file attachment if that helps. Thank you very much!
This is the portion of the code that slows everything down (and, just after it, a small self-contained toy version with fake data in case anyone wants something runnable):
# Loop over each timestep -- section of the code which takes the longest
for t in range(int(Tmin), num_days * 24):
    if t >= num_timesteps:
        break
    if t % 24 == 0:
        print(f"Processing day {t // 24}", flush=True)

    # Initialize a new transitions matrix for this timestep
    transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)

    # Loop over each particle
    for i in range(num_particles):
        if np.isnan(lon[i, t]) or np.isnan(lat[i, t]):  # skip beached / out-of-domain particles (NaN position)
            continue
        if initial_polygons[i] is None:  # skip particles with no source polygon at t=0
            continue
        probability = P_t_interpolator(t / 24)  # probability of settlement at this larval age (days)
        if np.random.rand() <= probability:
            point_current = Point(lon[i, t], lat[i, t])
            for k in idx.intersection(point_current.bounds):
                if point_current.within(gdf_all.geometry.iloc[k]):  # .iloc: positional lookup matching the R-tree ids
                    transitions_matrix[initial_polygons[i], k] += 1
                    break  # exit once the sink polygon is found
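In case it helps, here is a small self-contained toy version of that hot loop with completely fake data (random boxes instead of my reef shapefile, a constant settlement probability instead of the Randall curve, and far fewer particles). It only reproduces the access pattern; none of the numbers are from my real run:

import numpy as np
from shapely.geometry import Point, box
from rtree import index

rng = np.random.default_rng(0)
n_particles, n_polygons, n_timesteps = 5000, 20, 24   # real run: >1e6 particles, 207 polygons, 720 timesteps
p_settle = 0.5                                        # made-up stand-in for the Randall competency probability

# Fake reef polygons: small boxes scattered over a unit square
polys = [box(x, y, x + 0.05, y + 0.05) for x, y in rng.random((n_polygons, 2))]
idx = index.Index((j, p.bounds, None) for j, p in enumerate(polys))

# Fake particle tracks and fake source polygons
lon = rng.random((n_particles, n_timesteps))
lat = rng.random((n_particles, n_timesteps))
initial_polygons = rng.integers(0, n_polygons, n_particles)

transitions_matrix = np.zeros((n_polygons, n_polygons), dtype=int)
for t in range(n_timesteps):
    for i in range(n_particles):                      # this double loop is the part that kills me
        if np.isnan(lon[i, t]) or np.isnan(lat[i, t]):
            continue
        if rng.random() <= p_settle:
            point_current = Point(lon[i, t], lat[i, t])
            for k in idx.intersection(point_current.bounds):
                if point_current.within(polys[k]):
                    transitions_matrix[initial_polygons[i], k] += 1
                    break
print(transitions_matrix.sum(), "settlement events counted")

On my real data this same pattern is what takes weeks.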
Brief description of what the code does:
This script begins by loading:
- A NetCDF file containing hourly longitude and latitude for over a million particles across 30 days (with -999 marking beached or out-of-domain points).
- A shapefile containing 207 polygons (some of them multipart and fairly complex) defining coral reef patches. These polygons form the rows and columns of the output connectivity (transitions) matrix.
- A CSV of a Randall competency curve (probability of settlement as a function of larval age in days); there is a small interpolation sketch just after this list.
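For reference, this is how the competency curve becomes a continuous function in the script. The ages and probabilities below are made up; only the column names (LarvalAge, P_t) match my real CSV:

import pandas as pd
from scipy.interpolate import interp1d

# Made-up stand-in for Amil_cca_curve_SB.csv -- only the column names match my real file
curve_data = pd.DataFrame({"LarvalAge": [0, 3, 5, 10, 20, 30],            # larval age in days
                           "P_t":       [0.0, 0.1, 0.6, 0.5, 0.3, 0.1]})  # settlement probability
P_t_interpolator = interp1d(curve_data["LarvalAge"], curve_data["P_t"], fill_value="extrapolate")
print(P_t_interpolator(4.5))   # probability for a 4.5-day-old larva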
After interpolating that curve into a continuous function, the code reads the shapefile into a GeoDataFrame, builds an R-tree index for rapid point-in-polygon lookups, and, for each year's NetCDF, determines the last successfully processed day so it can resume without overwriting previous results.
Within the process_file function, the script opens the NetCDF, masks all -999 values to NaN (these are the beached particles), and discards particles whose initial positions are invalid.
There is also an optional debug block that reduces the particle sample size. The script then assigns each surviving particle to a “source” polygon by creating a Shapely Point at time 0 and querying the R-tree. Next, it steps through each hourly timestep from the resume point up to 30 days: for each particle that still has valid coordinates, it draws a random number and, if that number falls below the interpolated Randall probability for the particle's age in days, it queries the R-tree again and checks point-in-polygon membership to find the “sink” polygon.
Each source–sink transition is tallied in a 207×207 matrix for that hour, where rows are source reefs and columns are sink reefs. At the end of every 24 hours, the code stacks those hourly matrices into a 207×207×24 array and saves the result.
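For anyone looking at the outputs, each daily file loads back like this (the day and year in the filename are just an example following the naming pattern in the script):

import numpy as np

daily = np.load("transitions_matrix_day3_1996.npy")
print(daily.shape)        # (207, 207, 24): source reef x sink reef x hour of day
print(daily.sum(axis=2))  # summing over the 24 hours gives that day's source-to-sink counts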
Note: I have a separate script for each yearly run so I can run them individually instead of sequentially.
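One thing I have been wondering about regarding the per-hour settlement draw described above (not sure whether it is correct or would help much): since the probability only depends on t, the random draw could be done for all particles at once, so only the particles that “attempt” to settle would need the expensive polygon lookup. A rough, self-contained sketch of what I mean, with made-up numbers:

import numpy as np
rng = np.random.default_rng(0)

# Made-up stand-ins, just to show the idea (real run: >1e6 particles, probability from the Randall curve)
num_particles = 1_000_000
probability = 0.4                                  # would be P_t_interpolator(t / 24) in the real script
lon_t = rng.random(num_particles)                  # would be lon[:, t]
lat_t = rng.random(num_particles)                  # would be lat[:, t]
lon_t[rng.random(num_particles) < 0.3] = np.nan    # pretend ~30% are beached / out of domain

valid = ~np.isnan(lon_t) & ~np.isnan(lat_t)
settle_mask = valid & (rng.random(num_particles) <= probability)
candidates = np.where(settle_mask)[0]
print(f"{candidates.size} of {num_particles} particles would need the polygon lookup this hour")

Only the candidates would then go through the R-tree / point-in-polygon check, but I suspect the per-point Shapely calls are still the main cost.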
Here is the full code; thanks again for the help!
from rtree import index
import numpy as np
import time
import glob
from shapely.geometry import Polygon, Point
import csv
import pandas as pd
import xarray as xr
import geopandas as gpd
import os
from scipy.interpolate import interp1d
import random
# Load in Randall curve
curve_data = pd.read_csv('/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Randall_comp_curves/Amil_cca_curve_SB.csv')
larval_age = curve_data['LarvalAge'].values
P_t = curve_data['P_t'].values
P_t_interpolator = interp1d(larval_age, P_t, fill_value="extrapolate")
def process_file(file_path, gdf_all, idx, Tmin, num_days=30, num_hours_per_day=24):
    print(f"Starting from Tmin {Tmin}", flush=True)
    data_xarray = xr.open_dataset(file_path, mode='r')
    lon = data_xarray['lon'].values  # I have also tried .astype(np.float32) but it was not significantly faster
    lat = data_xarray['lat'].values
    lon[lon == -999] = np.nan  # set beached particles to NaN
    lat[lat == -999] = np.nan  # set beached particles to NaN
    print(f"NaN count at t=0: {np.isnan(lon[:, 0]).sum()}, {np.isnan(lat[:, 0]).sum()}", flush=True)
    print(f"-999 count at t=0: {(lon[:, 0] == -999).sum()}, {(lat[:, 0] == -999).sum()}", flush=True)  # always 0 here, since -999 was already replaced by NaN above
    num_particles, num_timesteps = lon.shape
    valid_particles = ~np.isnan(lon[:, 0]) & ~np.isnan(lat[:, 0]) & (lon[:, 0] != -999) & (lat[:, 0] != -999)
    lon = lon[valid_particles, :]
    lat = lat[valid_particles, :]
    num_particles, num_timesteps = lon.shape

    ##############################################################################
    # Select only 10000 particles for debugging
    num_particles_to_debug = 10000
    if num_particles > num_particles_to_debug:
        sampled_indices = np.random.choice(num_particles, num_particles_to_debug, replace=False)
        lon = lon[sampled_indices, :]
        lat = lat[sampled_indices, :]
        print(f"Debugging with {lon.shape[0]} particles", flush=True)
        num_particles, num_timesteps = lon.shape
        print(f"num of particles {num_particles}", flush=True)
    ###########################################################################

    num_polygons = len(gdf_all)
    transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)
    gdf_all = gdf_all.sort_values(by='id')

    print("Finding initial polys", flush=True)
    initial_polygons = []
    for i in range(num_particles):
        point_T0 = Point(lon[i, 0], lat[i, 0])
        # .iloc: positional lookup so the index matches the enumerate() order used to build the R-tree
        initial_polygon = next((j for j in idx.intersection(point_T0.bounds) if point_T0.intersects(gdf_all.geometry.iloc[j])), None)
        initial_polygons.append(initial_polygon)  # None if the particle does not start inside any reef polygon
    print(f"Valid particles at t=0 w initial polygons: {num_particles}", flush=True)

    # List holding the hourly transition matrices for the current day
    all_daily_matrices = []

    # Loop over each timestep -- section of the code which takes the longest
    for t in range(int(Tmin), num_days * 24):
        if t >= num_timesteps:
            break
        if t % 24 == 0:
            print(f"Processing day {t // 24}", flush=True)

        # Initialize a new transitions matrix for this timestep
        transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)

        # Loop over each particle
        for i in range(num_particles):
            if np.isnan(lon[i, t]) or np.isnan(lat[i, t]):  # skip beached / out-of-domain particles (NaN position)
                continue
            if initial_polygons[i] is None:  # skip particles with no source polygon at t=0
                continue
            probability = P_t_interpolator(t / 24)  # probability of settlement at this larval age (days)
            if np.random.rand() <= probability:
                point_current = Point(lon[i, t], lat[i, t])
                for k in idx.intersection(point_current.bounds):
                    if point_current.within(gdf_all.geometry.iloc[k]):
                        transitions_matrix[initial_polygons[i], k] += 1
                        break  # exit once the sink polygon is found

        # Check that the transitions_matrix has a consistent shape
        print(f"Shape of transitions_matrix for day {t//24}: {transitions_matrix.shape}")

        # Append the transitions matrix for this timestep to the daily list
        all_daily_matrices.append(transitions_matrix)

        # Once all 24 timesteps for a day are processed, combine into a 3D array and output
        if (t + 1) % 24 == 0:  # end of the day
            daily_array = np.stack(all_daily_matrices, axis=2)
            print(f"Shape of daily_array for day {t//24}: {daily_array.shape}")

            # Save the daily 3D numpy array (pathout and year come from the __main__ block)
            daily_filename = f"{pathout}/transitions_matrix_day{t//24}_{year}.npy"
            np.save(daily_filename, daily_array)
            print(f"Saved daily transitions matrix for day {t//24} to {daily_filename}", flush=True)

            # Reset for the next day
            all_daily_matrices = []

    return transitions_matrix, num_timesteps
# Get most recent Tmin so can restart the run from the last outputted day
def get_most_recent_Tmin(pathout, year):
    # Find all the transition matrix files for the given year
    file_pattern = os.path.join(pathout, f'transitions_matrix_day*_{year}.npy')
    files = glob.glob(file_pattern)
    if not files:
        return 0  # If no files are found, start from day 0

    # Sort the files based on the day extracted from the filename
    files.sort(key=lambda x: int(x.split('day')[1].split('_')[0]))  # Extract day number and sort

    # Get the most recent file and determine the day (Tmin)
    most_recent_file = files[-1]
    most_recent_day = int(most_recent_file.split('day')[1].split('_')[0])

    # Set Tmin as the most recent day + 1 (to restart from the next day)
    Tmin = (most_recent_day + 1) * 24
    print(f"Most recent file: {most_recent_file}, Tmin set to: {Tmin}")
    return Tmin
# Main to run the code
if __name__ == "__main__":
    # Paths to the NetCDF particle tracking files and the output directory
    base_path = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/SensitivityAnalysis/pout/'
    pathout = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/conmat/'

    # Just running a single year here (I have 30 years to run!)
    years = range(1996, 1997)

    # Reef polygon shapefile
    gdf_all = gpd.read_file("/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Coral_communities/Regions_w_geom_FINAL_crop.shp")
    gdf_all = gdf_all.sort_values(by='id')
    idx = index.Index((j, geom.bounds, None) for j, geom in enumerate(gdf_all.geometry))

    # Process each file sequentially
    for year in years:
        Tmin = get_most_recent_Tmin(pathout, year)
        print(f"Using Tmin: {Tmin} for year {year}", flush=True)
        file_path = f'{base_path}/ParcelsOut_Diffusion_{year}.nc'
        transitions_matrix, num_timesteps = process_file(file_path, gdf_all, idx, Tmin=Tmin)  # capture both outputs

    print(f"All transition matrices saved individually to {pathout}.")