Advanced: custom window operations
jaxscape.utils.padding(raster: Array, buffer_size: int, window_size: int) -> Array
Pad raster to ensure dimensions are compatible with WindowOperation.
Ensures (raster.shape[i] - 2 * buffer_size) % window_size == 0.
Example
import jax.numpy as jnp
from jaxscape.utils import padding
raster = jnp.ones((100, 100))
padded = padding(raster, buffer_size=10, window_size=25)
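The arithmetic behind this padding can be sketched as follows. `pad_for_windows` is a hypothetical helper, not jaxscape's implementation. It assumes zero padding is appended after the existing data (bottom and right edges), which may differ from what `padding` actually does.

```python
import jax.numpy as jnp

def pad_for_windows(raster, buffer_size, window_size):
    """Hypothetical sketch: pad so (shape[i] - 2*buffer_size) % window_size == 0."""
    pad_widths = []
    for n in raster.shape:
        interior = max(n - 2 * buffer_size, 0)
        # Round the interior up to the next multiple of window_size.
        n_windows = -(-interior // window_size)  # ceiling division
        target = n_windows * window_size + 2 * buffer_size
        # Assumption: all padding goes after the existing data.
        pad_widths.append((0, target - n))
    return jnp.pad(raster, pad_widths)

raster = jnp.ones((100, 100))
padded = pad_for_windows(raster, buffer_size=10, window_size=25)
print(padded.shape)  # (120, 120)
```

Here the interior is 100 - 2 * 10 = 80 pixels, which rounds up to 4 * 25 = 100, so each dimension becomes 100 + 20 = 120.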
jaxscape.window_operation.WindowOperation
Manages window-based operations on raster data with buffering.
Used for processing large rasters by dividing them into smaller windows with overlapping buffer regions. Ensures each window has sufficient context for operations that depend on neighboring pixels.
Attributes:
shape: Raster dimensions (height, width).
window_size: Core window size in pixels.
buffer_size: Buffer region size around each core window, in pixels.
total_window_size: Total window size including buffers (window_size + 2 * buffer_size).
x_steps, y_steps: Number of windows in each dimension.
Example
import jax.numpy as jnp
from jaxscape import WindowOperation
raster = jnp.ones((100, 100))
window_op = WindowOperation(
    shape=raster.shape,
    window_size=20,
    buffer_size=10,
)
Warning
You must ensure that (shape[i] - 2 * buffer_size) is divisible by
window_size for both dimensions i = 0, 1. Consider using jaxscape.utils.padding
to pad your raster data automatically.
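You can check the divisibility rule and the derived attributes with plain arithmetic. The helper `window_geometry` below is hypothetical and not taken from jaxscape's source. It assumes that x_steps and y_steps equal (shape[i] - 2 * buffer_size) // window_size, which is what the divisibility requirement implies.

```python
def window_geometry(shape, window_size, buffer_size):
    """Hypothetical sketch of the derived WindowOperation attributes."""
    for i, n in enumerate(shape):
        if (n - 2 * buffer_size) % window_size != 0:
            raise ValueError(
                f"dimension {i}: ({n} - 2 * {buffer_size}) is not divisible by {window_size}"
            )
    total_window_size = window_size + 2 * buffer_size
    x_steps = (shape[0] - 2 * buffer_size) // window_size
    y_steps = (shape[1] - 2 * buffer_size) // window_size
    return total_window_size, x_steps, y_steps

print(window_geometry((100, 100), window_size=20, buffer_size=10))  # (40, 4, 4)
```

With window_size=25 and buffer_size=10, a 100-pixel dimension leaves an interior of 80. That is not a multiple of 25, so the check fails, and padding first is the fix.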
nb_steps (property)
Total number of windows in the raster.
Example
window_op = WindowOperation(shape=(100, 100), window_size=20, buffer_size=10)
print(window_op.nb_steps)  # 16: (100 - 2 * 10) / 20 = 4 windows per dimension, a 4x4 grid
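The count is x_steps * y_steps. The sketch below also enumerates the start coordinate of each total window. `window_starts` is hypothetical. It assumes windows are laid out in row-major order and that consecutive windows start window_size pixels apart. Under that layout, the core regions tile the interior while the buffers of neighbouring windows overlap.

```python
def window_starts(shape, window_size, buffer_size):
    """Hypothetical sketch: start coordinates of every total window."""
    x_steps = (shape[0] - 2 * buffer_size) // window_size
    y_steps = (shape[1] - 2 * buffer_size) // window_size
    # Consecutive total windows are offset by window_size, so their core
    # regions are contiguous while their buffers overlap.
    return [(i * window_size, j * window_size)
            for i in range(x_steps) for j in range(y_steps)]

starts = window_starts((100, 100), window_size=20, buffer_size=10)
print(len(starts))  # 16
print(starts[:3])  # [(0, 0), (0, 20), (0, 40)]
```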
extract_total_window(xy: Array, raster: Array) -> Array
Extract a window including buffer regions from the raster.
Arguments:
xy: Start coordinates [x, y] of the window.
raster: 2D raster array.
Returns:
Window of shape (total_window_size, total_window_size).
Example
window_op = WindowOperation(shape=(100, 100), window_size=20, buffer_size=10)
raster = jnp.ones((100, 100))
window = window_op.extract_total_window(jnp.array([0, 0]), raster)
# window.shape = (40, 40)
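Extracting a buffered window amounts to a fixed-size slice at a dynamic offset. The standalone sketch below uses `jax.lax.dynamic_slice`, which accepts traced start indices and therefore works under `jit` and `vmap`. `extract_total_window` here is a hypothetical free function, not jaxscape's method. It assumes x indexes the first (row) axis.

```python
import jax
import jax.numpy as jnp

def extract_total_window(xy, raster, total_window_size):
    """Hypothetical sketch: slice a (T, T) window starting at xy = [x, y]."""
    # dynamic_slice keeps the output shape static even if xy is traced.
    return jax.lax.dynamic_slice(
        raster, (xy[0], xy[1]), (total_window_size, total_window_size)
    )

raster = jnp.arange(100 * 100, dtype=jnp.float32).reshape(100, 100)
window = extract_total_window(jnp.array([20, 40]), raster, total_window_size=40)
print(window.shape)  # (40, 40)
```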
extract_core_window(xy: Array, raster: Array) -> Array
Extract the core window without buffers from the raster.
Arguments:
xy: Start coordinates [x, y] of the total window.
raster: 2D raster array.
Returns:
Core window of shape (window_size, window_size).
Example
window_op = WindowOperation(shape=(100, 100), window_size=20, buffer_size=10)
raster = jnp.ones((100, 100))
core = window_op.extract_core_window(jnp.array([0, 0]), raster)
# core.shape = (20, 20)
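Because xy refers to the total window, the core starts buffer_size pixels further in on each axis. `extract_core_window` below is a hypothetical standalone sketch under the same row-axis assumption as above, not jaxscape's method.

```python
import jax
import jax.numpy as jnp

def extract_core_window(xy, raster, window_size, buffer_size):
    """Hypothetical sketch: core region of the total window starting at xy."""
    # Skip the buffer ring on both axes before slicing the core.
    start = (xy[0] + buffer_size, xy[1] + buffer_size)
    return jax.lax.dynamic_slice(raster, start, (window_size, window_size))

raster = jnp.arange(100 * 100, dtype=jnp.float32).reshape(100, 100)
core = extract_core_window(jnp.array([0, 0]), raster, window_size=20, buffer_size=10)
print(core.shape)  # (20, 20)
```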
update_raster_with_core_window(xy: Array, raster: Array, raster_window: Array, fun: collections.abc.Callable[[jax.Array, jax.Array], jax.Array] = <lambda>) -> Array
Update raster by merging the core region of a processed window.
Extracts the core (non-buffer) region from raster_window and updates the
corresponding region in raster using the provided function.
Arguments:
xy: Start coordinates [x, y] of the total window.
raster: Full raster array to update.
raster_window: Processed window including buffers.
fun: Function to combine current and new values. Defaults to replacement.
Returns:
Updated raster array.
Example
window_op = WindowOperation(shape=(100, 100), window_size=20, buffer_size=10)
raster = jnp.ones((100, 100))
output = jnp.zeros_like(raster)
for xy, window in window_op.lazy_iterator(raster):
    # compute_distance is a placeholder for any function returning a (40, 40) array
    processed = compute_distance(window)
    output = window_op.update_raster_with_core_window(xy, output, processed)
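The merge step can be sketched with `jax.lax.dynamic_slice` and `jax.lax.dynamic_update_slice`. `update_with_core` is a hypothetical standalone function, not jaxscape's method. It assumes the core sits buffer_size pixels inside the total window and that `fun(current, new)` returns an array of the core's shape.

```python
import jax
import jax.numpy as jnp

def update_with_core(xy, raster, raster_window, window_size, buffer_size,
                     fun=lambda current, new: new):
    """Hypothetical sketch: merge the core of raster_window into raster."""
    b, w = buffer_size, window_size
    # Drop the buffer ring from the processed (total) window.
    core = raster_window[b:b + w, b:b + w]
    start = (xy[0] + b, xy[1] + b)
    current = jax.lax.dynamic_slice(raster, start, (w, w))
    # fun(current, new) decides how old and new values are combined.
    return jax.lax.dynamic_update_slice(raster, fun(current, core), start)

raster = jnp.zeros((100, 100))
processed = jnp.full((40, 40), 3.0)
out = update_with_core(jnp.array([0, 0]), raster, processed,
                       window_size=20, buffer_size=10)
print(float(out.sum()))  # 1200.0
```

Only the 20 x 20 core (rows and columns 10 to 29) is written. Buffer pixels, which are less reliable near window edges, never reach the output.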
update_raster_with_window(xy: Array, raster: Array, raster_window: Array, fun: collections.abc.Callable[[jax.Array, jax.Array], jax.Array] = <lambda>) -> Array
Update raster with the entire window including buffers.
Arguments:
xy: Start coordinates [x, y] of the window.
raster: Full raster array to update.
raster_window: Processed window to merge.
fun: Function to combine current and new values. Defaults to replacement.
Returns:
Updated raster array.
Example
window_op = WindowOperation(shape=(100, 100), window_size=20, buffer_size=10)
raster = jnp.zeros((100, 100))
window_data = jnp.ones((40, 40))
# Replace window region
raster = window_op.update_raster_with_window(
jnp.array([0, 0]), raster, window_data
)
# Accumulate with custom function
raster = window_op.update_raster_with_window(
jnp.array([0, 0]), raster, window_data, fun=jnp.add
)
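With `fun=jnp.add`, overlapping buffered windows accumulate. The sketch below uses a hypothetical standalone `update_with_window`, not jaxscape's method. It writes two windows whose starts are 20 rows apart, assuming x indexes rows, so the 20 overlapping rows receive both contributions.

```python
import jax
import jax.numpy as jnp

def update_with_window(xy, raster, raster_window, fun=lambda current, new: new):
    """Hypothetical sketch: merge a full (buffered) window into raster."""
    start = (xy[0], xy[1])
    current = jax.lax.dynamic_slice(raster, start, raster_window.shape)
    return jax.lax.dynamic_update_slice(raster, fun(current, raster_window), start)

raster = jnp.zeros((100, 100))
window_data = jnp.ones((40, 40))
# Windows starting 20 rows apart overlap by 2 * buffer_size = 20 rows.
raster = update_with_window(jnp.array([0, 0]), raster, window_data, fun=jnp.add)
raster = update_with_window(jnp.array([20, 0]), raster, window_data, fun=jnp.add)
print(float(raster[10, 10]), float(raster[30, 10]))  # 1.0 2.0
```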
lazy_iterator(raster: Array) -> collections.abc.Generator[tuple[jax.Array, jax.Array], None, None]
Iterate over windows one at a time.
Memory-efficient iteration that yields windows sequentially without pre-computing all windows.
Arguments:
raster: 2D raster array to iterate over.
Yields:
Tuples of (xy, window) where xy are start coordinates and window
is the extracted window with buffers.
Example
window_op = WindowOperation(shape=(100, 100), window_size=20, buffer_size=10)
raster = jnp.ones((100, 100))
for xy, window in window_op.lazy_iterator(raster):
    # Process each window sequentially; window.shape = (40, 40)
    result = compute(window)  # compute is a placeholder per-window function
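A Python generator gives this one-window-at-a-time behaviour. `lazy_windows` is a hypothetical sketch, not jaxscape's implementation. It assumes the row-major layout with windows starting window_size pixels apart, and it slices each window only when the loop asks for it.

```python
import jax
import jax.numpy as jnp

def lazy_windows(raster, window_size, buffer_size):
    """Hypothetical sketch: yield (xy, window) pairs one at a time."""
    total = window_size + 2 * buffer_size
    x_steps = (raster.shape[0] - 2 * buffer_size) // window_size
    y_steps = (raster.shape[1] - 2 * buffer_size) // window_size
    for i in range(x_steps):
        for j in range(y_steps):
            xy = jnp.array([i * window_size, j * window_size])
            # Only the current window is materialized at each step.
            yield xy, jax.lax.dynamic_slice(raster, (xy[0], xy[1]), (total, total))

raster = jnp.ones((100, 100))
count = 0
for xy, window in lazy_windows(raster, window_size=20, buffer_size=10):
    count += 1
print(count, window.shape)  # 16 (40, 40)
```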
eager_iterator(matrix: Array) -> tuple[jax.Array, jax.Array]
Extract all windows at once for parallel processing.
Pre-computes all windows in a single operation using vmap, enabling
efficient batch processing and parallelization.
Arguments:
matrix: 2D input raster array.
Returns:
Tuple (xy, windows) where xy has shape (nb_steps, 2) and contains the start
coordinates, and windows has shape (nb_steps, total_window_size, total_window_size).
Example
import jax
window_op = WindowOperation(shape=(100, 100), window_size=20, buffer_size=10)
raster = jnp.ones((100, 100))
xy, windows = window_op.eager_iterator(raster)
# xy.shape = (16, 2), windows.shape = (16, 40, 40)
# Process all windows in parallel; compute is a placeholder per-window function
results = jax.vmap(compute)(windows)
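Eager extraction can be sketched by building the grid of start coordinates up front and mapping a slice over it with `jax.vmap`. `eager_windows` is hypothetical, not jaxscape's implementation. It assumes the same row-major layout as the lazy sketch. The trade-off is memory: all nb_steps windows exist at once, in exchange for a single batched operation.

```python
import jax
import jax.numpy as jnp

def eager_windows(raster, window_size, buffer_size):
    """Hypothetical sketch: extract every window at once with jax.vmap."""
    total = window_size + 2 * buffer_size
    x_steps = (raster.shape[0] - 2 * buffer_size) // window_size
    y_steps = (raster.shape[1] - 2 * buffer_size) // window_size
    # Row-major grid of start coordinates, shape (x_steps * y_steps, 2).
    ii, jj = jnp.meshgrid(jnp.arange(x_steps), jnp.arange(y_steps), indexing="ij")
    xy = jnp.stack([ii.ravel(), jj.ravel()], axis=1) * window_size
    extract = lambda s: jax.lax.dynamic_slice(raster, (s[0], s[1]), (total, total))
    return xy, jax.vmap(extract)(xy)

raster = jnp.ones((100, 100))
xy, windows = eager_windows(raster, window_size=20, buffer_size=10)
print(xy.shape, windows.shape)  # (16, 2) (16, 40, 40)
```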
jaxscape.windowed_analysis.WindowedAnalysis
Base class for windowed connectivity analyses on large rasters.
Processes landscapes through hierarchical decomposition: batches (processed sequentially) contain windows (processed in parallel) with buffer zones for spatial dependencies.
Parameters:
quality_raster: Habitat quality values.
permeability_raster: Movement permeability, or a function quality -> permeability.
distance: Distance metric (e.g., LCPDistance(), ResistanceDistance()).
proximity: Distance-to-proximity transform (e.g., lambda d: jnp.exp(-d/D)).
coarsening_factor: Spatial coarsening in [0, 1]; 0 = finest resolution, higher = faster.
dependency_range: Spatial dependency range in pixels (buffer size).
batch_size: Number of coarsened windows per batch (higher = more memory).
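The batch/window hierarchy described above can be illustrated without jaxscape: a Python loop visits batches in sequence, and `jax.vmap` runs the windows inside each batch in parallel. `process_in_batches` and the per-window function are hypothetical. The sketch does not reproduce WindowedAnalysis's coarsening, buffering, or connectivity computation.

```python
import jax
import jax.numpy as jnp

def process_in_batches(windows, per_window, batch_size):
    """Hypothetical sketch: loop over batches, vmap over windows in a batch."""
    total = 0.0
    for start in range(0, windows.shape[0], batch_size):
        batch = windows[start:start + batch_size]  # last batch may be smaller
        # Windows within one batch are evaluated in parallel.
        total = total + jax.vmap(per_window)(batch).sum()
    return total

windows = jnp.ones((16, 40, 40))
result = process_in_batches(windows, jnp.sum, batch_size=5)
print(float(result))  # 25600.0
```

Only one batch of per-window results is in memory at a time. This is why a larger batch_size buys more parallelism at the cost of more memory.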
Example
from jaxscape.windowed_analysis import WindowedAnalysis
class CustomConnectivity(WindowedAnalysis):
    def run(self, **kwargs):
        result = 0.0
        # Batches are visited one at a time; process_batch is a placeholder
        for xy_batch, quality_batch in self.batch_op.lazy_iterator(self.quality_raster):
            result += process_batch(quality_batch)
        return result