Cookies on this website

We use cookies to ensure that we give you the best experience on our website. If you click 'Accept all cookies' we'll assume that you are happy to receive all cookies and you won't see this message again. If you click 'Reject all non-essential cookies' only necessary cookies providing core functionality such as security, network management, and accessibility will be enabled. Click 'Find out more' for information on how to change your cookie settings.

Benchmarks play an important role in the development of machine learning algorithms, with reinforcement learning (RL) research having been heavily influenced by the available environments. However, RL environments are traditionally run on the CPU, limiting their scalability with typical academic compute. Recent advancements in JAX have enabled the wider use of hardware acceleration to overcome these computational hurdles, enabling massively parallel RL training pipelines and environments. This is particularly useful for multi-agent reinforcement learning (MARL) research. First of all, multiple agents must be considered at each environment step, adding computational burden, and secondly, the sample complexity is increased due to non-stationarity, decentralised partial observability, or other MARL challenges. In this paper, we present JaxMARL, the first open-source code base that combines ease-of-use with GPU enabled efficiency, and supports a large number of commonly used MARL environments as well as popular baseline algorithms. When considering wall clock time, our experiments show that per-run our JAX-based training pipeline is up to 12500x faster than existing approaches. We also introduce and benchmark SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. We provide code at https://github.com/flairox/jaxmarl.

Type

Publisher

Association for Computing Machinery

Publication Date

06/05/2024

Pages

2444 - 2446