Distance metrics
jaxscape.distance.AbstractDistance
Abstract base class for distance computations on graphs.
Provides a unified interface for computing distances with automatic handling of
coordinate-based (for GridGraph) or index-based node specification.
Arguments:
graph: Graph on which to compute distances.
sources: Source nodes as vertex indices (1D array) or coordinates (Nx2 array for GridGraph).
targets: Target nodes as vertex indices (1D array) or coordinates (Nx2 array for GridGraph).
nodes: Nodes for pairwise distances as vertex indices (1D array) or coordinates (Nx2 array for GridGraph).
Specify one of: nodes alone, sources and/or targets, or neither (for all-pairs distances).
Returns:
Distance array with shape depending on the inputs.
Example
from jaxscape import LCPDistance, GridGraph
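import jax.numpy as jnp
# Assumed input: a hypothetical 20 x 20 permeability raster (any 2D array of positive values)
permeability = jnp.ones((20, 20))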
distance = LCPDistance()
grid = GridGraph(permeability, fun=lambda x, y: (x + y) / 2)
# All-pairs distance
D = distance(grid) # Shape: (n_nodes, n_nodes)
# Using vertex indices
D = distance(grid, sources=jnp.array([0, 1]), targets=jnp.array([10, 20])) # Shape: (2, 2)
# Using coordinates (for GridGraph)
D = distance(grid, sources=jnp.array([[0, 0], [1, 1]]), targets=jnp.array([[10, 10]])) # Shape: (2, 1)
# Pairwise among subset
D = distance(grid, nodes=jnp.array([0, 5, 10])) # Shape: (3, 3)
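Conceptually, each call mode selects a block of the all-pairs distance matrix. The sketch below illustrates the shape rules with plain JAX indexing. select_distances is a hypothetical helper, not part of jaxscape, and treating a missing sources or targets argument as "all nodes" is an assumption; the library may compute only the requested rows rather than slicing a full matrix.

```python
import jax.numpy as jnp


def select_distances(D_all, sources=None, targets=None, nodes=None):
    """Hypothetical helper: pick the block of an all-pairs matrix that each call mode returns."""
    if nodes is not None:
        # Pairwise among a subset: rows and columns are the same nodes
        return D_all[jnp.ix_(nodes, nodes)]
    n = D_all.shape[0]
    # Assumption: a missing sources/targets argument means "all nodes"
    rows = sources if sources is not None else jnp.arange(n)
    cols = targets if targets is not None else jnp.arange(n)
    return D_all[jnp.ix_(rows, cols)]


D_all = jnp.arange(16.0).reshape(4, 4)  # stand-in for an all-pairs distance matrix
print(select_distances(D_all).shape)  # (4, 4)
print(select_distances(D_all, sources=jnp.array([0, 1]), targets=jnp.array([3])).shape)  # (2, 1)
print(select_distances(D_all, nodes=jnp.array([0, 2])).shape)  # (2, 2)
```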
jaxscape.euclidean_distance.EuclideanDistance
Straight-line distance in grid coordinates. Only works with GridGraph.
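The straight-line distance between node coordinates can be sketched with broadcasting in plain JAX. This is an illustration, not the library's implementation: euclidean_pairwise is a hypothetical name, and it assumes coordinates are in grid-cell units with no cell-size scaling.

```python
import jax.numpy as jnp


def euclidean_pairwise(sources, targets):
    """Straight-line distances between every row of `sources` (N, 2) and `targets` (M, 2)."""
    diff = (sources[:, None, :] - targets[None, :, :]).astype(jnp.float32)  # (N, M, 2)
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))  # (N, M)


src = jnp.array([[0, 0], [1, 1]])
tgt = jnp.array([[3, 4]])
d = euclidean_pairwise(src, tgt)  # d[0, 0] is 5.0: distance from (0, 0) to (3, 4)
```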
Example
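from jaxscape import EuclideanDistance
import jax.numpy as jnp
distance = EuclideanDistance()
# grid: a GridGraph, e.g. the one built in the AbstractDistance example; coordinates are Nx2
dist = distance(grid, sources=jnp.array([[0, 0], [1, 1]]), targets=jnp.array([[3, 4]]))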
__call__(graph: AbstractGraph, sources: typing.Optional[jax.Array] = None, targets: typing.Optional[jax.Array] = None, nodes: typing.Optional[jax.Array] = None) -> Array
jaxscape.lcp_distance.LCPDistance
Compute least-cost path distances using shortest-path algorithms.
Currently supports two algorithms:
- Bellman-Ford (default): Efficient for sparse graphs with few sources. Complexity O(V × E × S), where S is the number of sources.
- Floyd-Warshall: Efficient for all-pairs distances on small, dense graphs. Complexity O(V³); converts the graph to a dense matrix, so memory grows as O(V²).
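Both algorithms can be sketched in plain JAX on a dense cost matrix W, where W[i, j] is the cost of edge i -> j and jnp.inf marks a missing edge. These are illustrative sketches with hypothetical names, not jaxscape's implementation; in particular, this dense Bellman-Ford costs O(S × V²) per sweep rather than the O(S × E) that iterating over a sparse edge list gives.

```python
import jax
import jax.numpy as jnp


def bellman_ford_dense(W, sources):
    """Shortest-path costs from each node in `sources` to all nodes; returns shape (S, V)."""
    V = W.shape[0]
    S = sources.shape[0]
    dist = jnp.full((S, V), jnp.inf, dtype=W.dtype).at[jnp.arange(S), sources].set(0.0)

    def sweep(_, d):
        # Best cost to reach j via any predecessor i: min_i d[s, i] + W[i, j]
        via = jnp.min(d[:, :, None] + W[None, :, :], axis=1)
        return jnp.minimum(d, via)

    # V - 1 sweeps suffice when there are no negative cycles
    return jax.lax.fori_loop(0, V - 1, sweep, dist)


def floyd_warshall_dense(W):
    """All-pairs shortest-path costs: O(V^3) time, O(V^2) memory."""
    V = W.shape[0]
    D = W.at[jnp.diag_indices(V)].set(0.0)

    def relax_through(k, D):
        # Allow paths that pass through intermediate node k
        return jnp.minimum(D, D[:, k][:, None] + D[k, :][None, :])

    return jax.lax.fori_loop(0, V, relax_through, D)


inf = jnp.inf
W = jnp.array([[inf, 1.0, 5.0],
               [inf, inf, 2.0],
               [inf, inf, inf]])
bf = bellman_ford_dense(W, jnp.array([0]))  # costs from node 0: 0, 1, 3
fw = floyd_warshall_dense(W)                # first row matches bf[0]
```

Bellman-Ford only ever materializes S rows, which is why it suits few sources; Floyd-Warshall always builds the full V × V matrix.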
Parameters:
algorithm: Algorithm choice: "bellman-ford" (default) or "floyd-warshall".
Example
from jaxscape import LCPDistance, GridGraph
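import jax.numpy as jnp
# Assumed input: a hypothetical 20 x 20 permeability raster (any 2D array of positive values)
permeability = jnp.ones((20, 20))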
grid = GridGraph(permeability, fun=lambda x, y: (x + y) / 2)
# Default: Bellman-Ford (efficient for sparse graphs)
distance = LCPDistance()
D = distance(grid, sources=jnp.array([0, 1]), targets=jnp.array([10, 20]))
# Floyd-Warshall (efficient for small all-pairs)
distance_fw = LCPDistance(algorithm="floyd-warshall")
D_all = distance_fw(grid) # All-pairs distance
__call__(graph: AbstractGraph, sources: typing.Optional[jax.Array] = None, targets: typing.Optional[jax.Array] = None, nodes: typing.Optional[jax.Array] = None) -> Array
jaxscape.resistance_distance.ResistanceDistance
Compute resistance distances between nodes.
Attributes:
solver: Optional lineax.AbstractLinearSolver. Must be compatible with BCOO matrices. We currently support jaxscape.solvers.CholmodSolver and jaxscape.solvers.PyAMGSolver. If None, the pseudo-inverse method is used, which is very memory-intensive for large graphs because it densifies the Laplacian matrix.
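The pseudo-inverse method can be sketched with the standard identity R_ij = L+_ii + L+_jj - 2 L+_ij, where L+ is the Moore-Penrose pseudo-inverse of the graph Laplacian. This is a minimal dense sketch, not jaxscape's code: resistance_distance_pinv is a hypothetical name, and it assumes a dense, symmetric adjacency matrix whose weights act as conductances. It shows where the memory cost comes from: L+ is a full V × V matrix even when the graph is sparse.

```python
import jax.numpy as jnp


def resistance_distance_pinv(A):
    """Dense resistance distances from a symmetric weighted adjacency matrix A."""
    L = jnp.diag(A.sum(axis=1)) - A  # graph Laplacian, weights treated as conductances
    Lp = jnp.linalg.pinv(L)          # dense V x V pseudo-inverse
    d = jnp.diag(Lp)
    return d[:, None] + d[None, :] - 2.0 * Lp


A = jnp.array([[0.0, 1.0, 0.0],
               [1.0, 0.0, 1.0],
               [0.0, 1.0, 0.0]])  # path graph 0 - 1 - 2 with unit conductances
R = resistance_distance_pinv(A)  # R[0, 2] is about 2.0: two unit resistors in series
```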
Example
from jaxscape import ResistanceDistance
from jaxscape.solvers import PyAMGSolver
# Default: pseudo-inverse (small graphs)
distance = ResistanceDistance()
# With solver (large graphs)
distance = ResistanceDistance(solver=PyAMGSolver())
# grid: a GridGraph, e.g. the one built in the LCPDistance example
dist = distance(grid)
Warning
The graph must be undirected for resistance distance to be well-defined.
__call__(graph: AbstractGraph, sources: typing.Optional[jax.Array] = None, targets: typing.Optional[jax.Array] = None, nodes: typing.Optional[jax.Array] = None) -> Array
jaxscape.rsp_distance.RSPDistance
Randomized shortest path distance. Requires the temperature parameter
theta; also accepts cost, which can be either a jax.experimental.sparse.BCOO
matrix or a function applied to every nonzero element of the adjacency
matrix to create the cost matrix. cost defaults to the well-adapted
movement cost function lambda x: -jnp.log(x).
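The cost-mapping step described above can be sketched with jax.experimental.sparse: the function is applied to the stored entries (A.data) only, leaving the sparsity pattern (A.indices) unchanged. cost_from_adjacency is a hypothetical helper written for illustration; it is not necessarily how RSPDistance builds its cost matrix internally.

```python
import jax.numpy as jnp
from jax.experimental import sparse


def cost_from_adjacency(A, fun=lambda x: -jnp.log(x)):
    """Apply `fun` to the stored (nonzero) entries of a BCOO adjacency matrix."""
    # Only stored entries are transformed, so absent edges never hit -log(0) = inf
    return sparse.BCOO((fun(A.data), A.indices), shape=A.shape)


A = sparse.BCOO.fromdense(jnp.array([[0.0, 0.5],
                                     [0.25, 0.0]]))
C = cost_from_adjacency(A)  # stored entries become -log(0.5) and -log(0.25)
```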
Warning
This distance metric is experimental and may change in future releases.
Example
from jaxscape import RSPDistance
import jax.numpy as jnp
distance = RSPDistance(theta=0.01, cost=lambda x: -jnp.log(x))
# grid: a GridGraph, e.g. the one built in the LCPDistance example
dist = distance(grid)