April 30, 2026
MUHAMMAD GHIFARY
3D Gaussian Splatting (3DGS) has emerged as a state-of-the-art technique for real-time radiance field rendering, offering a compelling alternative to neural volumetric rendering methods such as NeRF. In a previous article, I implemented the 3DGS algorithm purely in JAX. However, that version was a naive implementation that did not fully exploit JAX’s performance on hardware accelerators, which leaves plenty of room to improve runtime.
This article discusses how to further optimize 3DGS across multiple Tensor Processing Units (TPUs). The overall strategy is to restructure the rasterization implementation and to exploit batched data parallelism. The jax-gs project does this by reformulating 3DGS within the JAX framework and leveraging XLA (Accelerated Linear Algebra) to compile the entire training and rendering pipeline into highly optimized machine code.
This transition from a dynamic, CUDA-centric model to a static-shape, JIT-compiled architecture allows jax-gs to exploit the massive parallel processing power of TPUs while maintaining numerical stability and structural consistency. The codebase is also research-friendly, benefiting from JAX’s composable transformations (e.g., vmap, pmap, grad).
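To make “static-shape” concrete, here is a sketch that assumes each pixel receives a fixed number K of depth-sorted Gaussian slots, with unused slots padded and masked out rather than removed. Because no array shape depends on the data, `jax.jit` compiles the function once, and `vmap` and `value_and_grad` compose on top of it. `K`, `composite_pixel`, `render_pixels`, and `loss_and_grad` are hypothetical names, not jax-gs APIs.

```python
import jax
import jax.numpy as jnp

K = 4  # hypothetical fixed number of Gaussian slots per pixel (static shape)

def composite_pixel(alphas, colors, valid):
    """Front-to-back alpha compositing over K depth-sorted slots.

    alphas: (K,)   opacity of each Gaussian at this pixel
    colors: (K, 3) RGB color of each Gaussian
    valid:  (K,)   False for padded slots
    """
    # Padded slots get zero opacity, so they add no color and no gradient.
    a = jnp.where(valid, alphas, 0.0)
    # Exclusive transmittance: T_i = prod_{j<i} (1 - a_j), with T_0 = 1.
    trans = jnp.cumprod(jnp.concatenate([jnp.ones((1,)), 1.0 - a[:-1]]))
    return (a * trans) @ colors  # (3,)

# vmap lifts the per-pixel function to a batch of pixels.
render_pixels = jax.vmap(composite_pixel)

@jax.jit
def loss_and_grad(alphas, colors, valid, target):
    def loss(a):
        return jnp.mean((render_pixels(a, colors, valid) - target) ** 2)
    return jax.value_and_grad(loss)(alphas)

# Two pixels; slot i has a one-hot RGB color (slot 3 is black).
alphas = jnp.array([[0.5, 0.5, 0.9, 0.9],
                    [1.0, 0.3, 0.0, 0.0]])
colors = jnp.tile(jnp.eye(K, 3), (2, 1, 1))
valid = jnp.array([[True, True, False, False],
                   [True, True, True, False]])

pixels = render_pixels(alphas, colors, valid)
loss, grads = loss_and_grad(alphas, colors, valid, jnp.zeros((2, 3)))
```

Padding spends some compute on masked slots, but it avoids value-dependent shapes (for example, boolean-mask indexing), which `jax.jit` cannot trace.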
Before diving into the optimization strategy, let’s briefly discuss the hardware accelerator itself. A basic understanding of TPU architecture will help us plan the optimization more effectively.
TPUs are Google’s custom application-specific integrated circuits (ASICs), designed to accelerate machine learning workloads. Unlike general-purpose GPUs, TPUs are built around the requirements of deep learning, prioritizing high-throughput matrix multiplication and low-latency interconnects.
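Because TPUs favor large dense matrix multiplications, it helps to phrase per-Gaussian work as batched matmuls instead of per-point loops. The sketch below assumes a combined 4×4 view-projection matrix per camera (a convention chosen here for illustration, not taken from jax-gs) and projects every Gaussian center for every camera with a single `jnp.einsum`, which XLA lowers to one dot operation. `project_means` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def project_means(means, view_proj):
    """Project all Gaussian centers for all cameras in one contraction.

    means:     (N, 3)    world-space Gaussian centers
    view_proj: (V, 4, 4) hypothetical combined view-projection matrix per camera
    returns:   (V, N, 2) normalized device coordinates (x, y)
    """
    ones = jnp.ones((means.shape[0], 1), dtype=means.dtype)
    homo = jnp.concatenate([means, ones], axis=1)      # (N, 4) homogeneous
    clip = jnp.einsum("vij,nj->vni", view_proj, homo)  # (V, N, 4) batched matmul
    return clip[..., :2] / clip[..., 3:4]              # perspective divide

means = jnp.array([[2.0, 4.0, 2.0],
                   [1.0, -1.0, 1.0]])
# Toy perspective matrix whose last row copies z into w.
persp = jnp.array([[[1.0, 0.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0]]])
ndc = project_means(means, persp)  # shape (1, 2, 2)
```

A full 3DGS projection would also cull Gaussians behind the camera and project their 3D covariances into 2D; both steps are omitted from this sketch.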
Since their debut in 2015, TPUs have evolved from specialized inference engines into a core platform for large-scale AI training and inference.

Here are the key architectural innovations in TPUs.