
Distance metrics

jaxscape.distance.AbstractDistance

Abstract base class for distance computations on graphs.

Provides a unified interface for computing distances. Nodes can be given either as vertex indices or, for GridGraph, as grid coordinates; both forms are handled automatically.

Arguments:

  • graph: Graph on which to compute distances.
  • sources: Source nodes as vertex indices (1D array) or coordinates ((N, 2) array, GridGraph only).
  • targets: Target nodes as vertex indices (1D array) or coordinates ((N, 2) array, GridGraph only).
  • nodes: Nodes for pairwise distances as vertex indices (1D array) or coordinates ((N, 2) array, GridGraph only).

Specify either nodes alone, sources and/or targets, or none of these (all-pairs distances).

Returns:

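Distance array whose shape depends on the inputs, e.g. (n_nodes, n_nodes) for all-pairs, (len(sources), len(targets)) when sources and targets are given, and (len(nodes), len(nodes)) when nodes is given.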
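The input modes and return shapes above can be illustrated by slicing a precomputed all-pairs matrix. This is a minimal sketch, not jaxscape's implementation: the helpers `coords_to_indices` and `select_distances` are hypothetical, the row-major mapping from (row, col) coordinates to vertex indices is an assumption about GridGraph, and returning shape (len(sources), n_nodes) when only sources are given is likewise assumed.

```python
import jax.numpy as jnp


def coords_to_indices(coords, ncols):
    """Map an (N, 2) array of grid coordinates to flat vertex indices.

    Hypothetical helper: assumes row-major ordering, index = row * ncols + col.
    """
    coords = jnp.asarray(coords)
    return coords[:, 0] * ncols + coords[:, 1]


def select_distances(D, sources=None, targets=None, nodes=None, ncols=None):
    """Hypothetical helper: slice a full (n_nodes, n_nodes) distance matrix D."""
    if nodes is not None and (sources is not None or targets is not None):
        raise ValueError("Pass either `nodes` or `sources`/`targets`, not both.")

    def to_idx(x):
        x = jnp.asarray(x)
        # 2D inputs are treated as coordinates, 1D inputs as vertex indices
        return coords_to_indices(x, ncols) if x.ndim == 2 else x

    if nodes is not None:
        idx = to_idx(nodes)
        return D[jnp.ix_(idx, idx)]  # shape (len(nodes), len(nodes))
    n = D.shape[0]
    rows = to_idx(sources) if sources is not None else jnp.arange(n)
    cols = to_idx(targets) if targets is not None else jnp.arange(n)
    return D[jnp.ix_(rows, cols)]  # shape (len(rows), len(cols))
```

Under this assumed ordering, the coordinate (1, 0) on a grid with two columns maps to vertex index 2.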

Example

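from jaxscape import LCPDistance, GridGraph
import jax.numpy as jnp

# Illustrative permeability raster; replace with your own 2D array
permeability = jnp.ones((20, 20))

distance = LCPDistance()
grid = GridGraph(permeability, fun=lambda x, y: (x + y) / 2)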

# All-pairs distance
D = distance(grid)  # Shape: (n_nodes, n_nodes)

# Using vertex indices
D = distance(grid, sources=jnp.array([0, 1]), targets=jnp.array([10, 20]))  # Shape: (2, 2)

# Using coordinates (for GridGraph)
D = distance(grid, sources=jnp.array([[0, 0], [1, 1]]), targets=jnp.array([[10, 10]]))  # Shape: (2, 1)

# Pairwise among subset
D = distance(grid, nodes=jnp.array([0, 5, 10]))  # Shape: (3, 3)

jaxscape.euclidean_distance.EuclideanDistance

Straight-line distance in grid coordinates. Only works with GridGraph.
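For reference, straight-line distances between two sets of grid coordinates can be computed with broadcasting. This is an illustrative sketch with a hypothetical helper `pairwise_euclidean`; it assumes coordinates are given as (N, 2) arrays and is not necessarily how EuclideanDistance is implemented.

```python
import jax.numpy as jnp


def pairwise_euclidean(sources, targets):
    """Straight-line distances between (N, 2) and (M, 2) coordinate arrays."""
    sources = jnp.asarray(sources, dtype=jnp.float32)
    targets = jnp.asarray(targets, dtype=jnp.float32)
    # Broadcast to (N, M, 2) coordinate differences
    diff = sources[:, None, :] - targets[None, :, :]
    # Norm over the last axis gives an (N, M) distance matrix
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))
```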

Example

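from jaxscape import EuclideanDistance, GridGraph
import jax.numpy as jnp

grid = GridGraph(jnp.ones((20, 20)), fun=lambda x, y: (x + y) / 2)
source_coords = jnp.array([[0, 0], [5, 5]])  # (N, 2) grid coordinates
target_coords = jnp.array([[10, 10]])
distance = EuclideanDistance()
dist = distance(grid, sources=source_coords, targets=target_coords)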

__call__(graph: AbstractGraph, sources: typing.Optional[jax.Array] = None, targets: typing.Optional[jax.Array] = None, nodes: typing.Optional[jax.Array] = None) -> Array

jaxscape.lcp_distance.LCPDistance

Compute least-cost path distances using shortest path algorithms.

Currently supports two algorithms:

  • Bellman-Ford (default): efficient for sparse graphs and few sources. Complexity O(V × E × S), where S is the number of sources.
  • Floyd-Warshall: efficient for all-pairs distances on small, dense graphs. Complexity O(V³); the graph is converted to a dense V × V matrix, so memory grows as O(V²).
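Both algorithms can be sketched in a few lines of JAX. These are textbook versions written for illustration; `bellman_ford`, `multi_source_bellman_ford`, `floyd_warshall` and `dense_costs` are hypothetical helpers, not jaxscape internals. Bellman-Ford relaxes every edge V - 1 times per source, which gives the O(V × E × S) cost above, while Floyd-Warshall updates a dense V × V matrix once per intermediate node.

```python
import jax
import jax.numpy as jnp


def bellman_ford(n, senders, receivers, weights, source):
    """Shortest-path costs from one source over an edge list (hypothetical)."""
    weights = jnp.asarray(weights, dtype=jnp.float32)
    dist = jnp.full(n, jnp.inf, dtype=jnp.float32).at[source].set(0.0)

    def relax(_, d):
        # Relax all E edges at once; the scatter-min keeps the best candidate
        # when several edges point to the same receiver
        return d.at[receivers].min(d[senders] + weights)

    # V - 1 rounds suffice when the graph has no negative cycles
    return jax.lax.fori_loop(0, n - 1, relax, dist)


def multi_source_bellman_ford(n, senders, receivers, weights, sources):
    """One Bellman-Ford run per source, vectorised with vmap: shape (S, V)."""
    return jax.vmap(lambda s: bellman_ford(n, senders, receivers, weights, s))(sources)


def dense_costs(n, senders, receivers, weights):
    """Dense (V, V) cost matrix: inf where there is no edge, 0 on the diagonal."""
    W = jnp.full((n, n), jnp.inf, dtype=jnp.float32)
    W = W.at[senders, receivers].min(jnp.asarray(weights, dtype=jnp.float32))
    return W.at[jnp.arange(n), jnp.arange(n)].set(0.0)


def floyd_warshall(W):
    """All-pairs shortest-path costs from a dense cost matrix (hypothetical)."""
    def step(k, D):
        # Allow paths that pass through intermediate node k
        return jnp.minimum(D, D[:, k][:, None] + D[k, :][None, :])

    return jax.lax.fori_loop(0, W.shape[0], step, W)
```

Because Floyd-Warshall materialises the full V × V matrix, its memory use grows quadratically even when the graph itself is sparse, which is why it only suits small graphs.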

Parameters:

  • algorithm: Either "bellman-ford" (default) or "floyd-warshall".

Example

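from jaxscape import LCPDistance, GridGraph
import jax.numpy as jnp

# Illustrative permeability raster; replace with your own 2D array
permeability = jnp.ones((20, 20))
grid = GridGraph(permeability, fun=lambda x, y: (x + y) / 2)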

# Default: Bellman-Ford (efficient for sparse graphs)
distance = LCPDistance()
D = distance(grid, sources=jnp.array([0, 1]), targets=jnp.array([10, 20]))

# Floyd-Warshall (efficient for small all-pairs)
distance_fw = LCPDistance(algorithm="floyd-warshall")
D_all = distance_fw(grid)  # All-pairs distance

__call__(graph: AbstractGraph, sources: typing.Optional[jax.Array] = None, targets: typing.Optional[jax.Array] = None, nodes: typing.Optional[jax.Array] = None) -> Array

jaxscape.resistance_distance.ResistanceDistance

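Compute resistance distances, i.e. the effective resistance between nodes when the graph is treated as an electrical network whose edge weights are conductances.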

Attributes:

  • solver: Optional lineax.AbstractLinearSolver, which must be compatible with BCOO matrices. Currently supported: jaxscape.solvers.CholmodSolver and jaxscape.solvers.PyAMGSolver. If None, the pseudo-inverse of the Laplacian is used; this densifies the Laplacian matrix and is therefore very memory-intensive for large graphs.
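Without a solver, resistance distances follow from the Moore-Penrose pseudo-inverse L+ of the graph Laplacian as R_ij = L+_ii + L+_jj - 2 L+_ij. The sketch below (a hypothetical helper `resistance_distance_pinv`, dense arrays only, edge weights treated as conductances) shows why this route is memory-hungry: L+ is a full V × V matrix even when the graph is sparse.

```python
import jax.numpy as jnp


def resistance_distance_pinv(A):
    """All-pairs resistance distances from a dense symmetric adjacency matrix.

    Hypothetical helper: edge weights in A are treated as conductances.
    """
    L = jnp.diag(A.sum(axis=1)) - A  # graph Laplacian, dense (V, V)
    L_pinv = jnp.linalg.pinv(L)      # Moore-Penrose pseudo-inverse, also dense
    d = jnp.diag(L_pinv)
    # R_ij = L+_ii + L+_jj - 2 L+_ij, computed for all pairs by broadcasting
    return d[:, None] + d[None, :] - 2.0 * L_pinv
```

On a three-node path with unit weights this gives a resistance of 2 between the end nodes, as expected for two unit resistors in series.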

Example

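from jaxscape import GridGraph, ResistanceDistance
from jaxscape.solvers import PyAMGSolver
import jax.numpy as jnp

# Illustrative grid; replace with your own permeability raster
grid = GridGraph(jnp.ones((20, 20)), fun=lambda x, y: (x + y) / 2)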

# Default: pseudo-inverse (small graphs)
distance = ResistanceDistance()

# With solver (large graphs)
distance = ResistanceDistance(solver=PyAMGSolver())

dist = distance(grid)
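With a linear solver, the Laplacian never has to be inverted: the resistance between a pair i, j can be read off one sparse solve, since R_ij = x_i - x_j where L x = e_i - e_j. The sketch below uses JAX's built-in conjugate gradient on a BCOO Laplacian purely as a stand-in (CholmodSolver and PyAMGSolver are the solvers ResistanceDistance supports); `resistance_pair_cg` is a hypothetical helper, and the Laplacian is built from a dense matrix only for brevity.

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg


def resistance_pair_cg(A, i, j):
    """Resistance between nodes i and j via one linear solve (hypothetical).

    A: dense symmetric (V, V) adjacency matrix with conductances as weights.
    """
    # Sparse Laplacian (converted from dense here only to keep the sketch short)
    L = sparse.BCOO.fromdense(jnp.diag(A.sum(axis=1)) - A)
    b = jnp.zeros(A.shape[0]).at[i].set(1.0).at[j].set(-1.0)
    # L is singular, but b sums to zero, so L x = b is consistent and
    # conjugate gradient started from zero converges to a valid solution
    x, _ = cg(lambda v: L @ v, b, tol=1e-6)
    # Any solution differs from L+ b by a constant, which cancels here
    return x[i] - x[j]
```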

Warning

The graph must be undirected for resistance distance to be well-defined.

__call__(graph: AbstractGraph, sources: typing.Optional[jax.Array] = None, targets: typing.Optional[jax.Array] = None, nodes: typing.Optional[jax.Array] = None) -> Array

jaxscape.rsp_distance.RSPDistance

Randomized shortest path (RSP) distance. Requires the temperature parameter theta and a cost, which can be either a jax.experimental.sparse.BCOO matrix or a function applied to every non-zero element of the adjacency matrix to build the cost matrix. cost defaults to the well-suited movement cost function lambda x: -jnp.log(x).
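For intuition, the RSP quantity can be computed densely on small graphs. The sketch below follows the standard RSP expected-cost formulation: a reference random walk P (the row-normalised adjacency matrix) is penalised by exp(-theta * C), Z = (I - W)^-1 sums the weights of all paths, and the expected cost of i -> j paths is divided by Z_ij. It is an illustration under stated assumptions, not RSPDistance's implementation: `rsp_distance_dense` is hypothetical, the graph is assumed connected, and edge costs are assumed strictly positive so that I - W is invertible (with the default -log(x) cost, this means edge weights in (0, 1)).

```python
import jax.numpy as jnp


def rsp_distance_dense(A, theta, cost_fn=lambda x: -jnp.log(x)):
    """Dense RSP dissimilarities for a small graph (hypothetical helper).

    A: (V, V) adjacency matrix, positive weights on edges, zeros elsewhere.
    Assumes every row has at least one edge and all edge costs are positive.
    """
    edges = A > 0
    # Map only the non-zero entries of A to costs; non-edges get cost 0
    C = jnp.where(edges, cost_fn(jnp.where(edges, A, 1.0)), 0.0)
    P = A / A.sum(axis=1, keepdims=True)  # reference random-walk transitions
    W = P * jnp.exp(-theta * C)           # transitions penalised by cost
    # Z = sum_k W^k: total weight of all paths between each pair of nodes
    Z = jnp.linalg.inv(jnp.eye(A.shape[0]) - W)
    # Expected cost of i -> j paths under the RSP (Boltzmann) distribution
    C_bar = (Z @ (C * W) @ Z) / Z
    # Subtract each target's diagonal entry so that d(j, j) = 0
    return C_bar - jnp.diag(C_bar)[None, :]
```

As theta grows, the path distribution concentrates on least-cost paths and the values approach least-cost distances; as theta approaches 0, the walk follows the reference random walk P.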

Warning

This distance metric is experimental and may change in future releases.

Example

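from jaxscape import GridGraph, RSPDistance
import jax.numpy as jnp

# Illustrative permeability in (0, 1), so that -log(x) costs are positive
grid = GridGraph(jnp.full((20, 20), 0.5), fun=lambda x, y: (x + y) / 2)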

distance = RSPDistance(theta=0.01, cost=lambda x: -jnp.log(x))
dist = distance(grid)

__call__(graph: AbstractGraph, sources: typing.Optional[jax.Array] = None, targets: typing.Optional[jax.Array] = None, nodes: typing.Optional[jax.Array] = None) -> Array