Paxml (aka Pax) is a framework for training large language models (LLMs). It supports highly configurable experimentation and parallelization, and is built on JAX and Praxis.
For information about running Paxml models and experiments on GPUs, refer to Rosetta PAXML, NVIDIA's project that enables seamless training of LLMs, computer vision (CV) models, and multimodal models in JAX.