Skip to content

Moving window operations¤

Process large rasters efficiently using spatial decomposition. JAXScape's WindowOperation divides landscapes into overlapping windows with buffer zones, enabling memory-efficient sequential processing or GPU-accelerated batch operations.

Prerequisites¤

pip install jaxscape rasterio matplotlib
import jax.numpy as jnp
import matplotlib.pyplot as plt
import rioxarray
from jaxscape import padding, WindowOperation

Load habitat quality raster for demonstration.

# Load habitat quality raster
raster = rioxarray.open_rasterio("../data/suitability.tif", masked=True)
quality = jnp.array(
    raster.sel(band=1).fillna(0), dtype="float32"
)  # Replace no-data values with 0

print(f"Raster shape: {quality.shape}")
print(f"Value range: [{quality.min():.2f}, {quality.max():.2f}]")
Raster shape: (107, 182)
Value range: [0.00, 100.00]
/Users/victorboussange/projects/jaxscape/.venv/lib/python3.12/site-packages/pyproj/crs/_cf1x8.py:515: UserWarning: angle from rectified to skew grid parameter lost in conversion to CF
  warnings.warn(
# Visualize the full raster
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(quality, cmap="viridis")
ax.set_title("Habitat quality raster", fontsize=14, pad=20)
ax.axis("off")
plt.colorbar(im, ax=ax, label="Quality", shrink=0.7)
plt.tight_layout()
plt.show()

img

Initialize window operation¤

Window parameters define the tile size and overlap. Before feeding the raster to the WindowOperation class, we can use jaxscape.padding to ensure that windows have complete neighborhoods; WindowOperation will error otherwise.

buffer_size = 10  # Overlap zone to handle edge effects
window_size = 50  # Core area of each window (in pixels)

quality_padded = padding(quality, buffer_size, window_size)

print(f"\nOriginal raster shape: {quality.shape}")
print(f"Padded raster shape: {quality_padded.shape}")

# Initialize window operation manager
window_op = WindowOperation(
    shape=quality_padded.shape, window_size=window_size, buffer_size=buffer_size
)

print(f"Total number of windows: {window_op.nb_steps}")
Original raster shape: (107, 182)
Padded raster shape: (120, 220)
Total number of windows: 8

Lazy iterator: memory-efficient processing¤

The lazy iterator yields windows one at a time (via Python generator), ideal for memory-constrained environments or sequential processing. Each iteration provides window coordinates (x, y) and the data array. This enables processing arbitrarily large rasters as long as individual windows fit in memory.

# Visualize all windows using lazy iteration
n_windows = window_op.nb_steps
n_cols = 4
n_rows = int(jnp.ceil(n_windows / n_cols))

fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows))
axs = axs.flatten() if n_windows > 1 else [axs]

for i, (xy, window) in enumerate(window_op.lazy_iterator(quality_padded)):
    ax = axs[i]
    im = ax.imshow(window, cmap="viridis")
    ax.set_title(f"Window {i+1}\nPosition: ({xy[0]}, {xy[1]})", fontsize=10)
    ax.axis("off")

    # Draw rectangle showing core area (excluding buffer)
    from matplotlib.patches import Rectangle

    rect = Rectangle(
        (buffer_size, buffer_size),
        window_size,
        window_size,
        linewidth=2,
        edgecolor="red",
        facecolor="none",
    )
    ax.add_patch(rect)

# Hide unused subplots
for j in range(i + 1, len(axs)):
    axs[j].axis("off")

plt.tight_layout()
plt.show()

print(f"Processed {i+1} windows (red rectangles = core areas)")

img

Processed 8 windows (red rectangles = core areas)

Eager iterator: vectorized batch processing¤

The eager iterator loads all windows into memory at once, returning coordinate and data arrays with shapes (n_windows, 2) and (n_windows, height, width). This enables vectorized operations across all windows simultaneously via jax.vmap, dramatically accelerating GPU/TPU processing. The trade-off is higher memory consumption.

# Load all windows at once for batch processing
xy_coords, windows = window_op.eager_iterator(quality_padded)

print(f"Coordinates shape: {xy_coords.shape}")
print(f"Windows shape: {windows.shape}")
Coordinates shape: (8, 2)
Windows shape: (8, 70, 70)

Update raster with processed windows¤

The update_raster_with_focal_window method replaces a window's core area (excluding buffers) with new data. It automatically handles buffer zones by writing only core regions, maintaining spatial continuity across window boundaries.

# Example: Replace window 3's core area with ones (demonstrating update)
window_index = 2  # Zero-indexed: window 3
new_window = jnp.ones(windows[window_index].shape, dtype="float32")

# Update the padded raster
updated_raster = window_op.update_raster_with_core_window(
    xy_coords[window_index], quality_padded, new_window
)

# Visualize the updated raster
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(updated_raster, cmap="viridis")
ax.set_title(f"Updated raster (window {window_index + 1} modified)", fontsize=12)
ax.axis("off")
plt.tight_layout()
plt.show()

img