Google speeds up Gemma 4 threefold with multi-token prediction

Google DeepMind has introduced multi-token prediction for its Gemma models, a technique that speeds up inference by up to three times on select models without compromising output quality. The approach, detailed in a recent research paper, builds on the open-source Gemma family and aims to make high-performance AI more accessible, particularly in resource-constrained environments.

At the core of autoregressive language models like those in the Gemma series lies a sequential generation process: the model predicts one token at a time, conditioning each subsequent prediction on all prior tokens. This inherently serial nature limits throughput, especially during inference when speed is paramount for real-world applications. Multi-token prediction addresses this bottleneck by training the model to forecast multiple future tokens simultaneously from a given prefix.
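To make the bottleneck concrete, here is a minimal sketch of the serial loop. The `toy_next_token` function is a hypothetical stand-in for a real model, not part of any Gemma API; the point is only that every generated token costs one full forward pass.

```python
from typing import Callable, List, Tuple


def generate_serial(
    next_token: Callable[[List[int]], int],
    prefix: List[int],
    n_new: int,
) -> Tuple[List[int], int]:
    """Generate n_new tokens one at a time.

    Returns the full sequence and the number of forward passes used.
    """
    seq = list(prefix)
    passes = 0
    for _ in range(n_new):
        # Each call stands in for one full forward pass over the prefix.
        seq.append(next_token(seq))
        passes += 1
    return seq, passes


def toy_next_token(seq: List[int]) -> int:
    """Hypothetical toy model: the next token is the last token plus one."""
    return (seq[-1] + 1) % 100


serial_seq, serial_passes = generate_serial(toy_next_token, [5], 8)
print(serial_seq, serial_passes)
```

Eight new tokens require eight passes here, and that one-to-one ratio is what multi-token prediction tries to break.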

The method involves augmenting the standard next-token prediction objective with auxiliary losses that encourage accurate multi-step predictions. Specifically, during training, the model learns to predict not only the immediate next token but also tokens two, four, and eight positions ahead. These predictions are computed in parallel using the same transformer layers, leveraging shared computations to minimize overhead. A key insight is the use of a “consistency loss,” which enforces alignment between successively predicted tokens, ensuring that predictions remain coherent even when generated out of strict sequential order.
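The combined objective can be sketched roughly as follows. This is an illustration under stated assumptions, not the paper's implementation: the offsets 1, 2, 4, and 8 come from the description above, while the `aux_weight` value, the function names, and the per-head averaging are hypothetical choices. The consistency loss is left out because the article does not give its form.

```python
import math
from typing import Dict, List, Sequence, Tuple

OFFSETS = (1, 2, 4, 8)  # next token plus tokens two, four, and eight ahead


def targets_for_offsets(
    tokens: Sequence[int], offsets: Sequence[int] = OFFSETS
) -> Dict[int, List[Tuple[int, int]]]:
    """For each offset k, pair position i with its target token tokens[i + k]."""
    return {
        k: [(i, tokens[i + k]) for i in range(len(tokens) - k)]
        for k in offsets
    }


def cross_entropy(probs: Sequence[float], target: int) -> float:
    """Negative log-likelihood of the target under one predicted distribution."""
    return -math.log(probs[target])


def multi_token_loss(
    head_probs: Dict[int, List[Sequence[float]]],
    tokens: Sequence[int],
    aux_weight: float = 0.5,  # assumed weight, not taken from the article
) -> float:
    """Next-token loss plus weighted auxiliary losses for the further offsets.

    head_probs[k][i] is the distribution that head k predicts at position i.
    """
    total = 0.0
    pairs_by_offset = targets_for_offsets(tokens, tuple(head_probs))
    for k, pairs in pairs_by_offset.items():
        if not pairs:  # sequence too short for this offset
            continue
        mean_ce = sum(cross_entropy(head_probs[k][i], t) for i, t in pairs) / len(pairs)
        total += (1.0 if k == 1 else aux_weight) * mean_ce
    return total


# Toy usage: 4 tokens, vocabulary of 4, uniform predictions from two heads.
toy_tokens = [0, 1, 2, 3]
uniform = [[0.25] * 4 for _ in toy_tokens]
toy_loss = multi_token_loss({1: uniform, 2: uniform}, toy_tokens)
print(round(toy_loss, 4))
```

In a real training setup all heads would read the same hidden states from a single forward pass, which is why the auxiliary terms add relatively little cost.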

Researchers applied this approach to two variants of the Gemma-2 model: the lightweight 2-billion-parameter Gemma-2-2B and the more capable 9-billion-parameter Gemma-2-9B. Both were fine-tuned from their pretrained base checkpoints on a dataset of 5 trillion tokens, primarily sourced from web crawls and synthetic data. The resulting models, dubbed Gemma-2-2B-MTP and Gemma-2-9B-MTP, keep the original architecture and add the multi-token prediction heads.

Inference with these models employs a simple decoding strategy: at each step, the model generates a bundle of candidate tokens (up to eight) in a single forward pass. The highest-probability token from this bundle is selected and appended to the sequence, while the remaining predictions serve to “speculatively” forecast the next bundle. This process repeats, effectively amortizing the cost of transformer evaluations across multiple tokens. Unlike traditional speculative decoding, which relies on a separate draft model, multi-token prediction integrates the capability directly into the main model, eliminating the need for additional components and reducing latency.
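One common way to use such a bundle safely is a draft-and-verify loop, sketched below with toy stand-ins. The acceptance rule (keep guesses only while they match what the next-token head would have produced) is an assumption borrowed from speculative decoding in general; the article does not spell out the exact rule used for the MTP models. The functions `toy_next_token` and `toy_draft` are hypothetical.

```python
from typing import Callable, List, Tuple


def decode_with_drafts(
    next_token: Callable[[List[int]], int],
    draft: Callable[[List[int], int], List[int]],
    prefix: List[int],
    n_new: int,
    k: int = 4,
) -> Tuple[List[int], int]:
    """Greedy decoding that accepts up to k tokens per verification pass.

    Assumes the extra heads produce the guesses in the same pass that
    verifies them, so each loop iteration counts as one forward pass.
    """
    seq = list(prefix)
    target_len = len(prefix) + n_new
    passes = 0
    while len(seq) < target_len:
        guesses = draft(seq, k)
        passes += 1
        for guess in guesses:
            correct = next_token(seq)  # what serial decoding would emit here
            seq.append(correct)        # the verified token is always kept
            if correct != guess or len(seq) >= target_len:
                break                  # later guesses are no longer valid
    return seq, passes


def toy_next_token(seq: List[int]) -> int:
    """Hypothetical toy model: the next token is the last token plus one."""
    return (seq[-1] + 1) % 100


def toy_draft(seq: List[int], k: int) -> List[int]:
    """Hypothetical draft heads: right except when the true token is a multiple of 5."""
    truth = [(seq[-1] + j + 1) % 100 for j in range(k)]
    return [t if t % 5 else -1 for t in truth]


draft_seq, draft_passes = decode_with_drafts(toy_next_token, toy_draft, [0], 8)
print(draft_seq, draft_passes)
```

Because every kept token is checked against the next-token prediction, the output is identical to plain greedy decoding; only the number of passes changes, and it depends on how often the guesses are right.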

Benchmark results demonstrate substantial gains. On standard evaluations like HellaSwag, ARC-Challenge, and MMLU, the MTP variants match or slightly exceed the scores of their vanilla counterparts, indicating no degradation in output quality. More importantly, throughput measurements show roughly a threefold increase in tokens per second. For instance, Gemma-2-2B-MTP achieves 3.1 times the throughput of the base model on a single NVIDIA H100 GPU, while Gemma-2-9B-MTP reaches 2.8 times. These improvements scale well across batch sizes and sequence lengths, making the models suitable for both interactive chat applications and high-volume batch processing.
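As a back-of-the-envelope check, the reported multipliers can be turned into per-token time savings. The two speedup figures are the ones quoted above; the `estimated_speedup` helper and its inputs are a simplified, hypothetical model (tokens accepted per pass divided by the relative cost of a pass), not a formula from the paper.

```python
def relative_latency(speedup: float) -> float:
    """Fraction of the baseline generation time that remains after a speedup."""
    return 1.0 / speedup


def estimated_speedup(tokens_per_pass: float, pass_cost_ratio: float = 1.0) -> float:
    """Rough speedup estimate under a simplified cost model.

    tokens_per_pass: average tokens accepted per forward pass (assumed input).
    pass_cost_ratio: cost of an MTP pass relative to a baseline pass (assumed input).
    """
    return tokens_per_pass / pass_cost_ratio


# Throughput multipliers as reported in the article.
REPORTED = {"Gemma-2-2B-MTP": 3.1, "Gemma-2-9B-MTP": 2.8}

for name, speedup in REPORTED.items():
    remaining = relative_latency(speedup)
    print(f"{name}: {speedup}x throughput, {remaining:.0%} of baseline time per token")
```

Under this model, a 3.1x speedup means each token takes roughly a third of the baseline time, so the model would need to accept a little over three tokens per pass on average if each pass costs about the same as before.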

The technique’s efficiency stems from its lightweight integration. Because the auxiliary losses are computed in parallel, training requires only about 20% more compute than standard supervised fine-tuning. No architectural overhauls, such as mixture-of-experts or custom attention mechanisms, are needed; the multi-token heads add negligible parameters (less than 0.1% of the total). This simplicity makes the method easier to reproduce and to adapt to other model families.

Google has open-sourced the MTP models on Hugging Face, alongside inference code and training recipes. Weights for both sizes are available under the permissive Gemma license, enabling immediate experimentation by developers and researchers. The release includes optimized implementations in frameworks like Transformers and vLLM, further lowering barriers to deployment.

This development aligns with broader trends in AI optimization, where inference efficiency increasingly rivals training scale as a priority. Multi-token prediction offers a low-overhead upgrade path for existing autoregressive models, requiring a fine-tuning run but no architectural changes, and could potentially extend to larger models such as Gemma-2-27B. By reducing hardware demands, it makes state-of-the-art language models more practical for edge devices, cost-sensitive enterprises, and open-source communities.

Early adopters report seamless integration into pipelines for tasks like code generation, summarization, and multilingual translation, where the speed boost translates to real-time responsiveness. As the paper notes, future work could explore longer prediction horizons or hybrid strategies combining MTP with quantization and distillation for even greater efficiencies.

In summary, Google’s multi-token prediction marks a pragmatic leap forward in LLM inference, delivering threefold speedups through clever training objectives rather than brute-force scaling. This positions Gemma models as frontrunners in the race for practical, performant open-weight AI.
