Developing an algebraic multigrid solver in JAX
Multigrid methods are the state of the art for solving the large, sparse linear systems that arise from discretized partial differential equations, achieving O(N) complexity in the number of unknowns for many problem classes.
Established implementations such as pyAMG and AMG.jl provide robust solvers but lack two capabilities that modern scientific machine learning workflows depend on: GPU acceleration and automatic differentiation compatibility. Both are essential when differentiable simulation components (e.g., neural networks embedded in physical models) require efficient iterative solves with gradient backpropagation for end-to-end optimization.
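As a concrete illustration of the second capability, JAX already lets gradients flow through matrix-free Krylov solves such as jax.scipy.sparse.linalg.cg. The sketch below is a toy example, not part of any existing AMG package: poisson_matvec and the theta parameter are hypothetical, and it assumes cg's built-in implicit differentiation handles the backward pass.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def poisson_matvec(theta, x):
    # Hypothetical parameterized operator: matrix-free tridiagonal stencil
    # theta * [-1, 2, -1] with Dirichlet boundaries (wrapped entries zeroed).
    left = jnp.roll(x, 1).at[0].set(0.0)
    right = jnp.roll(x, -1).at[-1].set(0.0)
    return theta * (2.0 * x - left - right)

def loss(theta, b):
    # cg accepts a matvec callable; gradients propagate through the solve
    # via implicit differentiation rather than unrolled iterations.
    x, _ = cg(lambda v: poisson_matvec(theta, v), b, tol=1e-10)
    return jnp.sum(x ** 2)

b = jnp.ones(64)
value, dtheta = jax.value_and_grad(loss)(1.5, b)
print(value, dtheta)
```

An AMG solver built on the same principles would extend this pattern from a single Krylov solve to a full multilevel hierarchy.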
This project aims to develop an algebraic multigrid (AMG) solver in JAX that natively supports automatic differentiation and GPU acceleration. The work involves analyzing existing Python and Julia implementations to design an architecture compatible with JAX’s functional programming paradigm and just-in-time compilation model. A successful implementation could have substantial impact across the JAX ecosystem, from accelerating finite element packages to speeding up ecological connectivity analysis tools.
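To make the target concrete, a minimal two-level cycle written in JAX’s functional style might look like the following sketch. All names (jacobi_smooth, two_level_cycle) and design choices (dense matrices, a piecewise-constant aggregation prolongation, a direct coarse solve) are illustrative placeholders, not the project’s actual design.

```python
import jax
import jax.numpy as jnp

def jacobi_smooth(A, x, b, num_iters=3, omega=0.67):
    # Weighted-Jacobi smoother written as a pure function so it composes
    # cleanly with jax.jit.
    d_inv = 1.0 / jnp.diag(A)
    def body(_, x):
        return x + omega * d_inv * (b - A @ x)
    return jax.lax.fori_loop(0, num_iters, body, x)

def two_level_cycle(A, b, x, P):
    # One two-level correction cycle; P is a prolongation operator and the
    # coarse operator is formed by the Galerkin product P.T @ A @ P.
    x = jacobi_smooth(A, x, b)              # pre-smooth
    r = b - A @ x                           # fine-grid residual
    A_c = P.T @ A @ P                       # Galerkin coarse operator
    e_c = jnp.linalg.solve(A_c, P.T @ r)    # direct coarse-grid solve
    x = x + P @ e_c                         # prolongate the correction
    return jacobi_smooth(A, x, b)           # post-smooth

cycle = jax.jit(two_level_cycle)

# Toy usage: 1D Poisson matrix with a piecewise-constant aggregation P.
n = 16
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
P = jnp.repeat(jnp.eye(n // 2), 2, axis=0)
b = jnp.ones(n)
x = jnp.zeros(n)
for _ in range(20):
    x = cycle(A, b, x, P)
```

Because the cycle is a pure function of its inputs, it composes directly with jax.jit, jax.grad, and jax.vmap; extending this pattern to sparse operators and automated algebraic coarsening is where the bulk of the design work lies.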
The project scope is flexible and can emphasize software engineering or algorithmic optimization depending on the student’s background and interests. Prior experience with JAX or advanced numerical linear algebra is beneficial but not required; students will gain deep expertise in iterative solvers, functional programming patterns, and best practices for scientific open-source software development.