Make code run faster: point-in-polygon lookups

Hello all,
I have a problem with my Python code: it takes forever to run (more than two weeks so far). The script takes a particle-tracking NetCDF file for a single year and a shapefile of coral reef polygons, then builds a “connectivity” transition matrix that counts how many particles travel from each source polygon to each sink polygon each day. It also applies a probabilistic “settlement competency” filter based on a Randall survival curve read from a CSV file.
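
In other words, at each hourly timestep every particle that is competent to settle (according to the interpolated Randall curve for its current age) adds one count to the cell for its (source reef, sink reef) pair. Conceptually, a single settlement test looks roughly like the sketch below (illustrative names and values only, not the actual script):

import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

# Minimal conceptual sketch of the competency filter and the matrix tally.
curve = pd.read_csv('Amil_cca_curve_SB.csv')                        # columns: LarvalAge, P_t
p_settle = interp1d(curve['LarvalAge'], curve['P_t'], fill_value="extrapolate")

n_reefs = 207
transitions = np.zeros((n_reefs, n_reefs), dtype=int)               # rows = source reefs, columns = sink reefs

age_days = 5.0                                                      # particle age at this timestep
src, sink = 12, 40                                                  # source reef at release, reef containing the particle now
if np.random.rand() <= p_settle(age_days):                          # Bernoulli test against the Randall curve
    transitions[src, sink] += 1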

The bottleneck is looping over the more than a million particles at each timestep and checking whether each one's position falls inside any of the 207 reef polygons defined by my shapefile. Any recommendations that could significantly speed up my code would be super helpful. I have included the full code at the end of this post, and also as a .py file attachment if that helps. Thank you very much :slight_smile:

This is the portion of the code that slows it down the most:

# Loop over each timestep -- Section of the code which takes the longest
    for t in range(int(Tmin), num_days * 24):
        if t >= num_timesteps:
            break
        
        if t % 24 == 0:
            print(f"Processing day {t // 24}", flush=True)

        # Initialize a new transitions matrix for this timestep
        transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)      

        # Loop over each particle        
        for i in range(num_particles):
            if np.isnan(lon[i, t]) or np.isnan(lat[i, t]):  # skip beached / out-of-domain particles (NaN positions)
                continue

            point_current = Point(lon[i, t], lat[i, t])

            probability = P_t_interpolator(t / 24) # Check probability of settlement
            
            if np.random.rand() <= probability:
                for k in idx.intersection(point_current.bounds):
                    if point_current.within(gdf_all.geometry[k]):
                        transitions_matrix[initial_polygons[i], k] += 1
                        break  # exit once the sink polygon is found
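
One direction I have been wondering about (but have not managed to test properly) is replacing the inner particle loop with a single vectorised query per timestep, roughly like the sketch below. It assumes Shapely 2.x, which provides shapely.points and array queries against an STRtree, and it assumes every particle has a valid integer entry in initial_polygons. I am not sure whether this is correct or actually faster, so corrections are very welcome:

import numpy as np
import shapely
from shapely import STRtree

tree = STRtree(gdf_all.geometry.values)              # built once, outside the time loop

for t in range(int(Tmin), min(num_days * 24, num_timesteps)):
    valid = ~np.isnan(lon[:, t]) & ~np.isnan(lat[:, t])
    probability = float(P_t_interpolator(t / 24))    # same value for every particle at this timestep
    settling = valid & (np.random.rand(num_particles) <= probability)

    pts = shapely.points(lon[settling, t], lat[settling, t])   # vectorised Point creation
    part_idx, poly_idx = tree.query(pts, predicate="within")   # pairs of (point index, polygon index)

    transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)
    sources = np.asarray(initial_polygons)[np.flatnonzero(settling)[part_idx]]
    np.add.at(transitions_matrix, (sources, poly_idx), 1)      # accumulate source -> sink counts
    # note: if reef polygons overlap, a point can be counted against more than one sink here

The daily stacking and saving would stay exactly as in the full code below; only the per-particle lookup would change.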

Brief description of what the code does:

This script begins by loading:

  1. A NetCDF file containing hourly longitude and latitude for over a million particles across 30 days (with -999 marking beached or out-of-domain positions).

  2. A shapefile containing 207 polygons (some of them multipart and fairly complex) defining coral reef patches. These polygons form the rows and columns of the output connectivity (transitions) matrix.

  3. A CSV of a Randall competency curve (probability of settlement as a function of larval age in days).

After interpolating that curve into a continuous function, the code reads the shapefile into a GeoDataFrame, builds an R-tree index for rapid point-in-polygon lookups, and, for each year’s NetCDF, determines the last successfully processed day so it can resume without overwriting previous results.

Within the process_file function, the script opens the NetCDF, masks all -999 values to NaN (these are the beached particles), and discards particles whose initial positions are invalid. It also includes an optional debug block that reduces the particle sample size. Each surviving particle is then assigned to a “source” polygon by creating a Shapely Point at time 0 and querying the R-tree. Next, the script steps through each hourly timestep from the resume point up to 30 days: for each particle that still has valid coordinates, it draws a random number and, if that number falls below the interpolated Randall probability for the particle’s age in days, it queries the R-tree again and checks point-in-polygon membership to find the “sink” polygon.
Each source–sink transition is tallied in a 207×207 matrix for that hour, where rows are source reefs and columns are sink reefs. After every 24 hours, the code stacks those hourly matrices into a 207×207×24 array and saves the result.

Note: I have a separate script for each yearly run so I can run each year individually instead of sequentially.
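
If the per-year scripts really differ only in the year value, I guess a single script could take the year as a command-line argument and still be launched once per year; a rough, untested sketch of what the main block might look like (reusing the functions from the full code below):

import argparse

# Sketch only: same main block as in the full code, but with the year passed on the
# command line so each year can still be submitted as an independent job.
parser = argparse.ArgumentParser()
parser.add_argument("year", type=int)
year = parser.parse_args().year

Tmin = get_most_recent_Tmin(pathout, year)
file_path = f'{base_path}/ParcelsOut_Diffusion_{year}.nc'
transitions_matrix, num_timesteps = process_file(file_path, gdf_all, idx, Tmin=Tmin)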

This is the full code, thanks for the help:

from rtree import index
import numpy as np
import time
import glob
from shapely.geometry import Polygon, Point
import csv
import pandas as pd
import xarray as xr
import geopandas as gpd
import os
from scipy.interpolate import interp1d
import random

# Load in Randall curve
curve_data = pd.read_csv('/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Randall_comp_curves/Amil_cca_curve_SB.csv')
larval_age = curve_data['LarvalAge'].values
P_t = curve_data['P_t'].values
P_t_interpolator = interp1d(larval_age, P_t, fill_value="extrapolate")


def process_file(file_path, gdf_all, idx, Tmin, num_days=30, num_hours_per_day=24):
    print(f"Starting from Tmin Yobama {Tmin}", flush=True)
    data_xarray = xr.open_dataset(file_path, mode='r')
    lon = data_xarray['lon'].values # i have also tried .astype(np.float32) but not significantly faster
    lat = data_xarray['lat'].values

    lon[lon == -999] = np.nan # set beached particles to NAN
    lat[lat == -999] = np.nan # set beached particles to NAN
    
    print(f"NaN count at t=0: {np.isnan(lon[:, 0]).sum()}, {np.isnan(lat[:, 0]).sum()}", flush=True)
    print(f"-999 count at t=0: {(lon[:, 0] == -999).sum()}, {(lat[:, 0] == -999).sum()}", flush=True)

    num_particles, num_timesteps = lon.shape
        
    valid_particles = ~np.isnan(lon[:, 0]) & ~np.isnan(lat[:, 0]) & (lon[:, 0] != -999) & (lat[:, 0] != -999)
    lon = lon[valid_particles, :]
    lat = lat[valid_particles, :]
    num_particles, num_timesteps = lon.shape

    

    ##############################################################################
    # Select only 10000 particles for debugging
    num_particles_to_debug = 10000
    if num_particles > num_particles_to_debug:
       sampled_indices = np.random.choice(num_particles, num_particles_to_debug, replace=False)
       lon = lon[sampled_indices, :]
       lat = lat[sampled_indices, :]
     
    print(f"Debugging with {lon.shape[0]} particles", flush=True)

    num_particles, num_timesteps = lon.shape
    print(f"num of particles {num_particles}", flush=True)
    ###########################################################################

    num_polygons = len(gdf_all)
    transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)
    gdf_all = gdf_all.sort_values(by='id')

    print(f"Finding initial polys", flush=True)
    initial_polygons = []
    for i in range(num_particles):
        point_T0 = Point(lon[i, 0], lat[i, 0])
        initial_polygon = next((j for j in idx.intersection(point_T0.bounds) if point_T0.intersects(gdf_all.geometry[j])), None)
        initial_polygons.append(initial_polygon)
    
    print(f"Valid particles at t=0 w initial polygons: {num_particles}", flush=True)

    # Initialize the 3D numpy array to hold the transition matrices for each day
    all_daily_matrices = []

    # Loop over each timestep -- Section of the code which takes the longest
    for t in range(int(Tmin), num_days * 24):
        if t >= num_timesteps:
            break
        
        if t % 24 == 0:
            print(f"Processing day {t // 24}", flush=True)

        # Initialize a new transitions matrix for this timestep
        transitions_matrix = np.zeros((num_polygons, num_polygons), dtype=int)      

        # Loop over each particle        
        for i in range(num_particles):
            if np.isnan(lon[i, t]) or np.isnan(lat[i, t]):  # skip beached / out-of-domain particles (NaN positions)
                continue

            point_current = Point(lon[i, t], lat[i, t])

            probability = P_t_interpolator(t / 24) # Check probability of settlement
            
            if np.random.rand() <= probability:
                for k in idx.intersection(point_current.bounds):
                    if point_current.within(gdf_all.geometry[k]):
                        transitions_matrix[initial_polygons[i], k] += 1
                        break  # exit once the sink polygon is found

        
        
        # Check if the transitions_matrix has a consistent shape
        print(f"Shape of transitions_matrix for day {t//24}: {transitions_matrix.shape}")
    
        # Append the transitions matrix for this timestep to the daily 3D array
        all_daily_matrices.append(transitions_matrix)

        # Once all 24 timesteps for a day are processed, combine into a 3D array and output
        if (t + 1) % 24 == 0:  # End of the day
            daily_array = np.stack(all_daily_matrices, axis=2)
            print(f"Shape of daily_array for day {t//24}: {daily_array.shape}")

            # Save the daily 3D numpy array
            daily_filename = f"{pathout}/transitions_matrix_day{t//24}_{year}.npy"
            np.save(daily_filename, daily_array)
            print(f"Saved daily transitions matrix for day {t//24} to {daily_filename}", flush=True)

           
            # Reset for the next day
            all_daily_matrices = []

    return transitions_matrix, num_timesteps



# Get most recent Tmin so can restart the run from the last outputted day
def get_most_recent_Tmin(pathout, year):
    # Find all the transition matrix files for the given year
    file_pattern = os.path.join(pathout, f'transitions_matrix_day*_{year}.npy')
    files = glob.glob(file_pattern)

    if not files:
        return 0  # If no files are found, start from day 0

    # Sort the files based on the day extracted from the filename
    files.sort(key=lambda x: int(x.split('day')[1].split('_')[0]))  # Extract day number and sort

    # Get the most recent file and determine the day (Tmin)
    most_recent_file = files[-1]
    most_recent_day = int(most_recent_file.split('day')[1].split('_')[0])

    # Set Tmin as the most recent day + 1 (to restart from the next day)
    Tmin = (most_recent_day + 1)*24
    print(f"Most recent file: {most_recent_file}, Tmin set to: {Tmin}")

    return Tmin


# Main to run the code
if __name__ == "__main__":
    # Path to netcdf particle tracking files
    base_path = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/SensitivityAnalysis/pout/'
    pathout = '/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/conmat/'

    # just running for a single year (i have 30 years to run!)
    years = range(1996,1997)
    
    # Reef polygon shapefile
    gdf_all = gpd.read_file("/scratch/pawsey0106/sbensadon/OceanParcels/CoralBay/Coral_communities/Regions_w_geom_FINAL_crop.shp")
    gdf_all = gdf_all.sort_values(by='id')
    idx = index.Index((j, geom.bounds, None) for j, geom in enumerate(gdf_all.geometry))

    # Process each file sequentially
    for year in years:
        Tmin = get_most_recent_Tmin(pathout, year)
        print(f"Using Tmin: {Tmin} for year {year}", flush = True)
        file_path = f'{base_path}/ParcelsOut_Diffusion_{year}.nc'
        transitions_matrix, num_timesteps = process_file(file_path, gdf_all, idx, Tmin=Tmin)  # Capture both outputs

    print(f"All transition matrices saved individually to {pathout}.")