This project aims to develop a high-performance algebraic multigrid (AMG) solver in JAX that supports automatic differentiation and GPU acceleration, a combination not offered by existing Python and Julia implementations. The work will enable efficient gradient-based optimization in scientific machine learning applications that require large-scale linear system solves, with the potential to significantly advance the JAX ecosystem for differentiable physics simulations and inverse problems.
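
To illustrate the kind of capability the project targets, the following minimal sketch differentiates through a linear solve in JAX using the built-in conjugate-gradient routine `jax.scipy.sparse.linalg.cg`, which supports implicit differentiation. This is not the proposed AMG solver; the parameterized 1-D Laplacian, the parameter `theta`, and the `loss` function are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def laplacian_1d(theta, n):
    # Tridiagonal SPD matrix scaled by a conductivity-like parameter `theta`
    # (hypothetical stand-in for a PDE discretization).
    main = 2.0 * theta * jnp.ones(n)
    off = -theta * jnp.ones(n - 1)
    return jnp.diag(main) + jnp.diag(off, k=1) + jnp.diag(off, k=-1)

def loss(theta, b):
    A = laplacian_1d(theta, b.shape[0])
    # cg is differentiable: the gradient is computed via an adjoint solve
    # rather than by unrolling the iteration.
    x, _ = cg(A, b, tol=1e-8)
    return jnp.sum(x ** 2)

b = jnp.ones(64)
dloss_dtheta = jax.grad(loss)(1.5, b)  # gradient through the linear solve
print(dloss_dtheta)
```

At the problem sizes the project targets, the dense `jnp.diag` matrix above would presumably be replaced by matrix-free operators or sparse formats (e.g., `jax.experimental.sparse`), with AMG serving as the solver or preconditioner in place of plain CG.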