PonderTTT

Adaptive, budget-aware Test-Time Training (TTT) for code generation models built with JAX/Flax NNX.

Core Idea: Binary Gating via Gumbel-Softmax

PonderTTT introduces Adaptive Test-Time Training with learned SKIP/UPDATE decisions. Instead of applying TTT updates uniformly to all input chunks, we learn when to update using a binary gating mechanism trained via Gumbel-Softmax.

| Feature   | Fixed TTT        | PonderTTT (Binary Gating)       |
|-----------|------------------|---------------------------------|
| Decision  | Always UPDATE    | SKIP or UPDATE per chunk        |
| Training  | N/A              | Gumbel-Softmax (differentiable) |
| Inference | Fixed cost       | True computational savings      |
| Cost      | 3.0x (UPDATE_1)  | 2.67x (83% update rate)         |
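
The cost column follows directly from the update rate: taking an UPDATE chunk at 3.0x the cost of a forward-only SKIP chunk, an 83% update rate gives an expected cost of about 0.83 × 3.0 + 0.17 × 1.0 ≈ 2.67x. The sketch below shows one way to implement a straight-through Gumbel-Softmax gate over the two choices; it is a minimal illustration rather than the repository's code, and the names gumbel_softmax_gate, gate_logits, and temperature are placeholders.

import jax
import jax.numpy as jnp

def gumbel_softmax_gate(key, gate_logits, temperature=1.0):
    # gate_logits: [..., 2] unnormalized scores for {SKIP, UPDATE} per chunk.
    # Sample a relaxed (soft) decision, then discretize with a
    # straight-through estimator: the forward pass is hard one-hot,
    # while gradients flow through the soft probabilities.
    gumbel = jax.random.gumbel(key, gate_logits.shape)
    soft = jax.nn.softmax((gate_logits + gumbel) / temperature, axis=-1)
    hard = jax.nn.one_hot(jnp.argmax(soft, axis=-1), 2)
    return hard - jax.lax.stop_gradient(soft) + soft

# Example: gate two chunks; decisions[..., 1] is 1.0 where UPDATE was chosen.
key = jax.random.PRNGKey(0)
decisions = gumbel_softmax_gate(key, jnp.array([[2.0, -1.0], [0.1, 0.3]]))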

Key Results (GPT-2 125M on Python)

Technical Architecture

This project is a pure JAX/Flax NNX rewrite of the official TTT-LM, enhanced with adaptive gating.
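
As a rough picture of how a gated block can be expressed in Flax NNX, the sketch below wraps a per-chunk SKIP/UPDATE decision around an inner transformation, reusing the gumbel_softmax_gate helper from the sketch above. The class and attribute names (GatedTTTBlock, gate, inner) are hypothetical and do not mirror the repository's actual modules.

from flax import nnx
import jax.numpy as jnp

class GatedTTTBlock(nnx.Module):
    # Hypothetical sketch: per chunk, decide whether to apply an update path
    # (here a plain Linear stand-in for the TTT inner-loop update) or skip it.
    def __init__(self, hidden_dim: int, *, rngs: nnx.Rngs):
        self.gate = nnx.Linear(hidden_dim, 2, rngs=rngs)            # SKIP/UPDATE logits
        self.inner = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)  # stand-in update path

    def __call__(self, chunk: jnp.ndarray, key) -> jnp.ndarray:
        # Pool the chunk into one summary vector to score the decision.
        summary = chunk.mean(axis=0)
        decision = gumbel_softmax_gate(key, self.gate(summary))
        # Hard mix of the identity (SKIP) and update paths.
        return decision[0] * chunk + decision[1] * self.inner(chunk)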

Roadmap & Status

The project is currently in active development. Phase 1 is complete with a preprint available.

Phase 1: Complete (Preprint)

Phase 2: Planned (Conference Submission)

See PLAN.md for the detailed roadmap.

Quick Start

Installation

# Install uv if you do not have it yet
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install the project in editable mode
uv pip install -e . --group gpu  # use --group tpu or --group cpu as appropriate
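
Optionally, a quick sanity check (plain JAX, not a PonderTTT command) that the installed backend sees your accelerator:

python -c "import jax; print(jax.devices())"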

Reproduce Paper Results (Recommended)

Run the full suite of experiments (Training, OOD Evaluation, Latency, Ablations) with a single script:

chmod +x scripts/run_all_experiments.sh
./scripts/run_all_experiments.sh

Manual Training

python -m ponderttt.experiments.train_hard_skip \
    --model_scale 125m \
    --target_update_rate 0.5 \
    --num_iterations 10000 \
    --output_dir outputs/hard_skip
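
How --target_update_rate is enforced is defined by the training script itself; purely as an illustration (an assumption, not necessarily what train_hard_skip implements), a common recipe is a quadratic budget penalty that pulls the mean per-chunk UPDATE probability toward the target:

def budget_penalty(update_probs, target_update_rate=0.5, weight=1.0):
    # Hypothetical regularizer: penalize deviation of the mean UPDATE
    # probability from the requested budget (0.5 in the command above).
    return weight * (update_probs.mean() - target_update_rate) ** 2

# total_loss = lm_loss + budget_penalty(update_probs, target_update_rate=0.5)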

Citation

@article{sim2025ponderttt,
  title={Learning to Ponder: Adaptive Compute Allocation via Test-Time Training},
  author={Sim, Gihyeon},
  year={2025}
}