Starting: 1 January 2025 at the earliest
Supervisor: Joanna Sliwa
The Hessian of a model stores the second-order partial derivatives of a loss function with respect to the model's weights. Computing and storing it exactly is infeasible for modern networks, since its size grows quadratically with the number of parameters. This project aims to create a library for efficiently approximating Hessian matrices in JAX. The library will implement key approximation techniques such as the Generalized Gauss-Newton (GGN) matrix, Kronecker-Factored Approximate Curvature (K-FAC), and the diagonal of the Fisher information matrix. The primary goals are computational and memory efficiency combined with a user-friendly design. By providing practical and accessible implementations within the JAX framework, this project seeks to offer a valuable tool for researchers and practitioners in various applications, including continual learning and model training.
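For a sense of the kind of primitive such approximations build on, here is a minimal sketch (assumed function names and toy model, not the project's API) of a matrix-free Hessian-vector product in JAX; curvature approximations like GGN or K-FAC exist precisely to avoid ever materializing the full Hessian that this operation implicitly probes:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Hypothetical linear model with squared-error loss, for illustration only.
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

def hvp(params, x, y, v):
    # Hessian-vector product via forward-over-reverse differentiation;
    # costs roughly two gradient evaluations and never forms the full
    # (d x d) Hessian.
    grad_fn = lambda p: jax.grad(loss)(p, x, y)
    return jax.jvp(grad_fn, (params,), (v,))[1]

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 5))
params = jnp.zeros(5)
y = x @ jnp.ones(5)
v = jnp.ones(5)
print(hvp(params, x, y, v))  # the Hessian applied to v, without building the Hessian
```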
Prerequisites:
- experience with JAX
- knowledge of deep learning
- knowledge of linear algebra