Google Gemma + JAX Agent Cost Control: XLA Recompilation Loops, Device Placement Copies, Model Eviction Failures, and JIT Cache Loss
Google Gemma models — the 2B, 7B, 9B, and 27B open-weight variants — are designed to be run locally. Combined with JAX, Google's numerical computing library that compiles Python functions to XLA kernels for GPU, TPU, or CPU execution, they offer a self-hosted inference stack that eliminates cloud API billing and keeps inference latency under your control. The official Gemma JAX implementation and community ports via Hugging Face Transformers with JAX backend make this stack accessible to any developer with a consumer GPU or Google Cloud TPU slice.
The absence of per-token invoices from a provider creates a dangerous blind spot. JAX's execution model is fundamentally different from PyTorch's eager mode. jax.jit traces a function once per unique combination of input shapes and dtypes to produce an XLA computation graph, and the XLA compiler compiles that graph before execution. This design delivers outstanding throughput on homogeneous workloads: the same batch size and the same sequence length, every time. Agentic pipelines are the opposite. Tool calls return results of unpredictable length, prompts grow across turns, and retry logic introduces new shapes on every iteration. The failure modes below all stem from this mismatch between JAX's static-shape compilation model and the dynamic, variable-length nature of agent execution.
Four failure modes specific to Gemma + JAX agentic pipelines:
- XLA recompilation on every unique sequence length — JAX traces and compiles a new kernel for each unique input shape. Variable-length tool results produce a unique shape on every agent step.
- Silent CPU-to-GPU tensor copies inside agent loops — Mixing numpy arrays, Python lists, or CPU tensors with JAX GPU arrays inside a loop causes implicit, full-tensor device migrations on every iteration.
- Gemma size-switching without model eviction — Routing short prompts to Gemma-2B and complex queries to Gemma-27B loads both parameter sets into GPU/TPU memory simultaneously if the previous model is not explicitly deleted and the accelerator buffers flushed.
- JIT cache loss from subprocess agent isolation — Agent frameworks that spawn a new subprocess per request lose JAX's in-process compiled kernel cache, forcing full XLA recompilation on every agent invocation.
Failure Mode 1 — XLA Recompilation on Every Unique Sequence Length
JAX transforms Python functions into abstract computation graphs via jax.jit. The first call to a @jax.jit-decorated function with a given set of input shapes triggers two phases. Tracing walks your Python code symbolically to record all operations. XLA compilation then lowers the traced graph to hardware-specific machine code. For a Gemma-2B forward pass, this can take roughly 15–45 seconds on a consumer GPU. For Gemma-27B, expect something closer to 90–180 seconds on an A100.
The cache key for the compiled kernel includes the full shape and dtype of every input array, including the sequence-length axis. A tool response that returns 847 tokens therefore produces a different shape than one that returns 1,024 tokens. An agent that runs ten loop iterations with ten different accumulated context lengths triggers ten separate compilations. If each first call into a new shape costs 15–45 seconds, ten unique shapes accumulate 150–450 seconds of XLA work across a single agent run.
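You can observe the shape-keyed cache directly. The sketch below uses a toy function (toy_forward, a hypothetical stand-in for a Gemma forward pass) and a Python counter. The counter increments only while JAX is tracing, so it counts compilations rather than calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces, i.e. once per new input shape


@jax.jit
def toy_forward(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, skipped on cache hits
    return jnp.tanh(x).sum(axis=-1)


# Simulated accumulated context lengths from four agent steps
for length in [847, 1024, 847, 1000]:
    toy_forward(jnp.ones((1, length), dtype=jnp.float32))

print(trace_count)  # 3: the repeated 847-token call reuses the cached kernel
```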
The standard mitigation is input padding to fixed bucket boundaries. Instead of passing the raw sequence length to the model, pad every input up to the next power of two at or above its length, or to a curated set of lengths such as 512, 1024, 2048, 4096, and 8192. JAX then recompiles only when an input first lands in a new bucket. Most agent workflows need about 4–6 unique lengths this way, instead of one per call.
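Here is a minimal sketch of the power-of-two variant. It assumes a 512-token minimum bucket, and the helper name next_pow2_bucket is illustrative. The example context grows by a few hundred tokens per step, so it produces a new raw length every iteration but only a handful of buckets.

```python
def next_pow2_bucket(seq_len: int, min_bucket: int = 512) -> int:
    """Smallest power of two >= seq_len, but never below min_bucket."""
    if seq_len < 1:
        raise ValueError("seq_len must be positive")
    return max(min_bucket, 1 << (seq_len - 1).bit_length())


# Hypothetical agent run: context grows by 340 tokens per step over 20 steps
lengths = [600 + 340 * step for step in range(20)]
buckets = [next_pow2_bucket(n) for n in lengths]

print(len(set(lengths)), "raw shapes ->", len(set(buckets)), "compiled shapes")
print(sorted(set(buckets)))  # [1024, 2048, 4096, 8192]
```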
Shape trace cost rule of thumb: plan for at most 6 unique sequence-length buckets per model. Each bucket costs one XLA compilation, so budget for the first call into each bucket. Subsequent calls reuse the cached kernel at full speed. Batch size is part of the shape as well, so keep it fixed or bucket it the same way.
```python
import time
from dataclasses import dataclass, field
from typing import Callable, Optional

import jax
import jax.numpy as jnp

SEQUENCE_BUCKETS = [512, 1024, 2048, 4096, 8192, 16384]


def pad_to_bucket(token_ids: list[int], pad_token_id: int = 0) -> tuple[list[int], int]:
    """Right-pad token_ids to the smallest bucket that fits; return (padded, actual_len)."""
    actual_len = len(token_ids)
    for bucket in SEQUENCE_BUCKETS:
        if actual_len <= bucket:
            padding = [pad_token_id] * (bucket - actual_len)
            return token_ids + padding, actual_len
    raise ValueError(f"Sequence length {actual_len} exceeds max bucket {SEQUENCE_BUCKETS[-1]}")


@dataclass
class XLARecompileGuard:
    model_id: str
    compiled_buckets: set[int] = field(default_factory=set)
    compilation_times: dict[int, float] = field(default_factory=dict)
    max_unique_buckets: int = 8

    def check_and_record(self, sequence_length: int) -> Optional[str]:
        """Return a warning if this length will cost an unexpected XLA compilation."""
        bucket = next((b for b in SEQUENCE_BUCKETS if sequence_length <= b), None)
        if bucket is None:
            return f"sequence length {sequence_length} exceeds all configured buckets"
        if bucket not in self.compiled_buckets and len(self.compiled_buckets) >= self.max_unique_buckets:
            return (
                f"bucket {bucket} not yet compiled and {len(self.compiled_buckets)} buckets already "
                f"compiled; triggering an additional XLA compilation now"
            )
        return None

    def record_compilation(self, bucket: int, duration_s: float) -> None:
        self.compiled_buckets.add(bucket)
        self.compilation_times[bucket] = duration_s

    def total_compilation_s(self) -> float:
        return sum(self.compilation_times.values())


# Usage: wrap your Gemma forward pass. model_fn is assumed to be jax.jit-compiled and to
# return per-position outputs shaped (batch, seq_len, ...), e.g. logits.
def guarded_generate(model_fn: Callable, token_ids: list[int], guard: XLARecompileGuard):
    padded, actual_len = pad_to_bucket(token_ids)
    bucket = len(padded)  # pad_to_bucket always pads exactly to a bucket boundary
    warning = guard.check_and_record(bucket)
    if warning:
        print(f"[XLARecompileGuard] WARNING: {warning}")
    first_call = bucket not in guard.compiled_buckets
    start = time.perf_counter()
    # JAX dispatches asynchronously; block_until_ready makes the timer cover execution.
    result = jax.block_until_ready(model_fn(jnp.array([padded])))
    elapsed = time.perf_counter() - start
    if first_call:  # this call traced and compiled a kernel for the bucket (plus one run)
        guard.record_compilation(bucket, elapsed)
        print(f"[XLARecompileGuard] compiled bucket={bucket} in {elapsed:.1f}s")
    # With causal attention, right padding leaves outputs at positions < actual_len unchanged.
    return result[:, :actual_len]
```
Failure Mode 2 — Silent CPU-to-GPU Tensor Copies Inside Agent Loops
JAX arrays record the device they live on, which you can inspect via arr.devices(). That device may be a CPU device, a CUDA device such as CudaDevice(id=0), or a TPU device. Host data sometimes meets accelerator computation, for example when you concatenate a JAX GPU array with a numpy array or pass a Python list to a jitted function inside a loop. JAX then silently transfers the host data to the accelerator before computing, and it repeats that transfer on every call. Each transfer is a full memory copy across the PCIe bus. A PCIe 4.0 ×16 link peaks near 32 GB/s per direction, but copies from ordinary (pageable) host memory usually achieve noticeably less. This section assumes an effective 12–16 GB/s.
The cost in absolute terms: a Gemma-7B KV cache for a 4,096-token context occupies approximately 1.9 GB. That figure is 2 (keys and values) × 28 layers × 16 KV heads × 256 head dim × 4,096 tokens × 2 bytes per bfloat16 value = 1,879,048,192 bytes. A single host-to-device copy of this KV cache takes roughly 120–160 ms at 12–16 GB/s. Suppose an agent loop runs 20 iterations with a stale numpy KV accumulation pattern. It moves about 1.9 GB × 20 ≈ 37.6 GB across the PCIe bus, adding roughly 2.3–3.1 seconds of pure transfer overhead, entirely separate from inference latency.
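The arithmetic above can be expressed as a small calculator. It assumes the published Gemma-7B attention configuration (28 layers, 16 key/value heads, head dimension 256) and the 12–16 GB/s effective host-to-device bandwidth assumed in this section. The function names are illustrative.

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int, seq_len: int, bytes_per_value: int = 2) -> int:
    # Leading factor 2: every layer stores one key tensor and one value tensor.
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_value


def transfer_seconds(num_bytes: int, gb_per_s: float) -> float:
    return num_bytes / (gb_per_s * 1e9)


kv = kv_cache_bytes(layers=28, kv_heads=16, head_dim=256, seq_len=4096)  # bfloat16 values
for bandwidth in (12.0, 16.0):  # assumed effective GB/s for pageable host-to-device copies
    per_copy_ms = 1e3 * transfer_seconds(kv, bandwidth)
    loop_s = 20 * transfer_seconds(kv, bandwidth)
    print(f"{kv / 1e9:.2f} GB at {bandwidth:.0f} GB/s: {per_copy_ms:.0f} ms per copy, "
          f"{loop_s:.1f} s over 20 steps")
```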
The failure is insidious because, by default, JAX neither warns about nor logs these implicit host-to-device transfers. You have to opt in with jax.transfer_guard to see them. The operation succeeds silently. The only observable symptom is latency that scales linearly with loop count and disappears when you pin arrays correctly.
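JAX can be told to surface these transfers. The sketch below uses jax.transfer_guard, JAX's context manager for logging or forbidding implicit transfers. Under the "disallow" level, passing a numpy array into a jitted step raises an error, while an explicit jax.device_put remains allowed. The step function and array shapes are placeholders. The "log" level is the gentler choice for production runs.

```python
import jax
import numpy as np


@jax.jit
def step(kv, update):
    return kv + update  # placeholder for one agent-loop computation


kv_on_device = jax.device_put(np.zeros((4, 8), np.float32))  # explicit transfer, done once
host_update = np.ones((4, 8), np.float32)                    # lives in host memory

implicit_transfer_blocked = False
with jax.transfer_guard("disallow"):
    try:
        step(kv_on_device, host_update)  # would copy host_update to the device implicitly
    except Exception as err:  # JAX raises when the guard forbids an implicit transfer
        implicit_transfer_blocked = True
        print(f"blocked implicit transfer: {type(err).__name__}")
    pinned = jax.device_put(host_update)  # explicit transfers are still allowed
    out = step(kv_on_device, pinned)

print(implicit_transfer_blocked, np.asarray(out)[0, 0])  # read back outside the guard
```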
```python
from dataclasses import dataclass, field
from typing import Any

import jax
import numpy as np


@dataclass
class DevicePlacementGuard:
    target_device: str  # e.g. "gpu:0" or "tpu:0": a JAX backend name plus a device index
    copy_events: list = field(default_factory=list)
    copy_bytes_total: int = 0

    def __post_init__(self):
        platform, _, index = self.target_device.partition(":")
        # Resolve the string once to a concrete jax.Device (raises if the backend is missing).
        self._device = jax.devices(platform)[int(index) if index else 0]

    def check_array(self, name: str, arr: Any) -> bool:
        """Returns True (and records the event) if arr is not already on the target device."""
        if isinstance(arr, (np.ndarray, list)):
            source = "host (numpy/list)"
            size_bytes = np.asarray(arr).nbytes
        elif isinstance(arr, jax.Array) and arr.devices() != {self._device}:
            source = ", ".join(str(d) for d in arr.devices())
            size_bytes = arr.nbytes
        else:
            return False
        self.copy_events.append({"name": name, "from": source, "to": str(self._device)})
        self.copy_bytes_total += size_bytes
        print(
            f"[DevicePlacementGuard] WARN: '{name}' ({size_bytes / 1e6:.1f} MB) is on {source}, "
            f"not {self._device}; it will be copied before compute"
        )
        return True

    def pin_to_device(self, tree: Any) -> Any:
        """Explicitly move an array or pytree (e.g. a tuple of KV tensors) to the target
        device. Use at loop boundaries so the copy happens once, not inside every step."""
        if isinstance(tree, list):
            tree = np.asarray(tree)  # otherwise a list of ints is treated as a pytree of scalars
        # device_put accepts numpy arrays, JAX arrays on other devices, and pytrees of either;
        # arrays already on the target device are not transferred again.
        return jax.device_put(tree, self._device)

    def report(self) -> str:
        if not self.copy_events:
            return "no device placement mismatches detected"
        return (
            f"{len(self.copy_events)} implicit device copies detected, "
            f"{self.copy_bytes_total / 1e9:.2f} GB transferred total"
        )


# Pattern: pin arrays to the accelerator at the start of each agent loop iteration
def agent_loop_step(kv_cache, new_tokens, guard: DevicePlacementGuard):
    # Always call pin_to_device before passing arrays into JAX functions
    kv_cache = guard.pin_to_device(kv_cache)
    new_tokens = guard.pin_to_device(new_tokens)
    guard.check_array("kv_cache", kv_cache)  # expected to pass after pinning
    guard.check_array("new_tokens", new_tokens)
    # ... inference call here ...
    return kv_cache, new_tokens
```
Failure Mode 3 — Gemma Size-Switching Without Model Eviction
Gemma's model family spans four sizes: 2B, 7B, 9B, and 27B parameters. Some agent architectures route dynamically between sizes — short, low-complexity prompts go to Gemma-2B for speed; multi-step reasoning or long-context synthesis goes to Gemma-27B. This is a sound architectural idea. The implementation failure is assuming that loading a new model weight set automatically evicts the previous one from accelerator memory.
JAX does not have a global model registry or an automatic least-recently-used memory manager for accelerator buffers. Suppose you load Gemma-27B parameters as a new set of JAX arrays. The previous Gemma-2B arrays remain resident on the device until every Python reference to them is dropped, at which point their buffers are freed, or until you call .delete() on them explicitly. On GPUs with limited VRAM this matters quickly. A 24 GB RTX 4090 holds Gemma-2B in bfloat16 with room to spare. Gemma-7B, at about 8.5B parameters including embeddings, needs roughly 17 GB, and Gemma-27B at bfloat16 requires about 54 GB. If you load a second model without explicit eviction, JAX raises an out-of-memory (RESOURCE_EXHAUSTED) error once the resident parameter sets plus working memory exceed what it can allocate. It does not fall back to CPU execution on its own.
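The sketch below gives a rough version of that memory budget. It assumes approximate total parameter counts including embeddings: about 2.51B for Gemma-2B and 8.54B for Gemma-7B. It also assumes JAX's default GPU preallocation of 75% of device memory, which XLA_PYTHON_CLIENT_MEM_FRACTION controls. Only weights are counted; the KV cache and activations need headroom on top.

```python
def weights_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Weight memory in decimal GB; bytes_per_param=2 for bfloat16/float16."""
    return num_params * bytes_per_param / 1e9


def jax_pool_gb(total_vram_gb: float, mem_fraction: float = 0.75) -> float:
    """Memory JAX reserves on GPU by default (XLA_PYTHON_CLIENT_MEM_FRACTION=0.75)."""
    return total_vram_gb * mem_fraction


gemma_2b = weights_gb(2.51e9)  # approximate total parameter counts (assumption)
gemma_7b = weights_gb(8.54e9)
pool = jax_pool_gb(24.0)       # e.g. a 24 GB RTX 4090 with default JAX settings

print(f"2B: {gemma_2b:.1f} GB, 7B: {gemma_7b:.1f} GB, JAX pool: {pool:.1f} GB")
print("both resident fits:", gemma_2b + gemma_7b <= pool)  # False
print("7B alone leaves", round(pool - gemma_7b, 1), "GB for KV cache and activations")
```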
The correct pattern is to free the previous model's parameter pytree explicitly before loading the next model. Either call .delete() on each array leaf, or make sure no references to the pytree remain so its buffers are released. A ModelEvictionGuard tracks which model is currently loaded and enforces a single-model-in-memory invariant:
```python
import gc
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

import jax

# Approximate weight footprint per model (bfloat16, single device). Weights only:
# the KV cache, activations, and compiled kernels need additional headroom.
GEMMA_VRAM_GB = {
    "gemma-2b": 5.4,
    "gemma-7b": 17.1,
    "gemma-9b": 18.0,
    "gemma-27b": 54.0,
}


@dataclass
class ModelEvictionGuard:
    available_vram_gb: float  # query once at startup: the memory JAX can actually use
    current_model_id: Optional[str] = None
    current_params: Any = None
    eviction_log: list = field(default_factory=list)

    def can_load(self, model_id: str) -> tuple[bool, str]:
        required = GEMMA_VRAM_GB.get(model_id, float("inf"))
        current_used = GEMMA_VRAM_GB.get(self.current_model_id, 0.0) if self.current_model_id else 0.0
        headroom = self.available_vram_gb - current_used
        if required <= headroom:
            return True, f"fits with {headroom - required:.1f} GB headroom"
        return (
            False,
            f"{model_id} needs {required:.1f} GB but only {headroom:.1f} GB free "
            f"({self.available_vram_gb:.1f} GB total, {current_used:.1f} GB in use by {self.current_model_id})",
        )

    def evict_current(self) -> None:
        if self.current_params is None:
            return
        evicted = self.current_model_id
        # Free each device buffer explicitly. Dropping our own reference is not enough if
        # callers still hold the params pytree; deleted arrays become unusable afterwards.
        for leaf in jax.tree_util.tree_leaves(self.current_params):
            if isinstance(leaf, jax.Array) and not leaf.is_deleted():
                leaf.delete()
        self.current_params = None
        self.current_model_id = None
        gc.collect()  # collect Python reference cycles that may still wrap the pytree
        self.eviction_log.append(evicted)
        print(f"[ModelEvictionGuard] evicted {evicted} from accelerator memory")

    def register_loaded(self, model_id: str, params: Any) -> None:
        ok, reason = self.can_load(model_id)
        if not ok:
            raise RuntimeError(
                f"[ModelEvictionGuard] refused to load {model_id}: {reason}. "
                "Call evict_current() first."
            )
        self.current_params = params
        self.current_model_id = model_id
        print(f"[ModelEvictionGuard] registered {model_id}: {reason}")


def load_gemma_with_eviction(
    model_id: str,
    load_fn: Callable[[str], Any],  # hypothetical loader that returns the params pytree
    guard: ModelEvictionGuard,
) -> Any:
    if guard.current_model_id == model_id:
        # Same model already loaded: skip eviction and reload
        return guard.current_params
    # Single-model invariant: always evict before loading a different size
    guard.evict_current()
    ok, reason = guard.can_load(model_id)
    if not ok:  # fail before load_fn allocates anything on the device
        raise RuntimeError(f"[ModelEvictionGuard] cannot load {model_id}: {reason}")
    params = load_fn(model_id)
    guard.register_loaded(model_id, params)
    return params
```
Failure Mode 4 — JIT Cache Loss from Subprocess Agent Isolation
By default, JAX's compiled kernel cache lives entirely in process memory. The mapping from (function, input shapes, dtypes, devices) to compiled XLA executables is held in in-memory caches owned by the active Python process. That cache is lost whenever the process exits. The same happens when an agent framework spawns a new subprocess per request for isolation, timeout enforcement, or memory accounting. The next request starts from scratch: trace, compile, execute.
This is the most damaging failure mode in production because it is completely invisible until you measure it. The first request against a fresh subprocess takes 30–180 seconds depending on Gemma model size. The agent framework's timeout may fire before the first compilation finishes, triggering a retry — which starts another subprocess and another compilation. The resulting spiral: request → timeout → subprocess kill → new subprocess → recompile → timeout → kill → repeat.
JAX does provide a persistent compilation cache via jax.config.update("jax_compilation_cache_dir", "/path/to/cache"). When this is configured, compiled XLA executables are serialized to disk and reloaded on cache hits. That skips XLA compilation for previously seen shapes, though tracing and lowering still run in the new process. A cold cache still pays the compilation cost once. A warm cache can turn a 45-second compilation into a disk read of a few hundred milliseconds. Once the disk cache is warm, the timeout that fires on the first call of a fresh process after a subprocess kill stops being an issue.
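The sketch below shows that setup with the cache's write threshold made explicit. jax_persistent_cache_min_compile_time_secs is a JAX config option (default 1 second) that skips writing entries for compilations faster than that. The wrapper name enable_persistent_cache is illustrative. Configure the cache before the first jitted call so that no early compilations are missed.

```python
import tempfile
from pathlib import Path

import jax


def enable_persistent_cache(cache_dir: str, min_compile_time_s: float = 1.0) -> None:
    """Point JAX's persistent compilation cache at cache_dir."""
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    jax.config.update("jax_compilation_cache_dir", cache_dir)
    # Only compilations slower than this are written to disk. Gemma-sized kernels exceed
    # the default easily; a lower value also persists small helper kernels.
    jax.config.update("jax_persistent_cache_min_compile_time_secs", min_compile_time_s)


cache_dir = tempfile.mkdtemp(prefix="jax_cache_")  # use a durable path in production
enable_persistent_cache(cache_dir, min_compile_time_s=0.5)
```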
```python
import hashlib
import json
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional

import jax
import jax.numpy as jnp

# Configure the persistent compilation cache before the first jitted call in the process
JAX_CACHE_DIR = os.environ.get("JAX_COMPILATION_CACHE_DIR", "/tmp/jax_compilation_cache")

# Rough cold-compile estimates in seconds, used only for the timeout warning below
MODEL_COMPILE_ESTIMATES_S = {"gemma-2b": 30.0, "gemma-7b": 75.0, "gemma-9b": 90.0, "gemma-27b": 165.0}


def configure_jax_persistent_cache(cache_dir: str = JAX_CACHE_DIR) -> None:
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    jax.config.update("jax_compilation_cache_dir", cache_dir)
    print(f"[JITCacheGuard] persistent cache at {cache_dir}")


@dataclass
class JITCacheGuard:
    cache_dir: str
    process_start_time: float = field(default_factory=time.time)
    warmup_completed: bool = False
    warmup_duration_s: Optional[float] = None
    cache_hits: int = 0
    cache_misses: int = 0

    def _cache_key(self, model_id: str, shapes: dict) -> str:
        payload = json.dumps({"model": model_id, "shapes": shapes}, sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()[:16]

    def cache_path(self, model_id: str, shapes: dict) -> Path:
        # Stamps live in a subdirectory, separate from JAX's own cache entries
        stamp_dir = Path(self.cache_dir) / "warmup_stamps"
        stamp_dir.mkdir(parents=True, exist_ok=True)
        return stamp_dir / f"{model_id}_{self._cache_key(model_id, shapes)}.stamp"

    def is_warm(self, model_id: str, shapes: dict) -> bool:
        warm = self.cache_path(model_id, shapes).exists()
        if warm:
            self.cache_hits += 1
        else:
            self.cache_misses += 1
        return warm

    def record_warmup(self, model_id: str, shapes: dict, duration_s: float) -> None:
        stamp = self.cache_path(model_id, shapes)
        stamp.write_text(json.dumps({"model": model_id, "shapes": shapes,
                                     "compiled_at": time.time(), "duration_s": duration_s}))
        self.warmup_completed = True
        self.warmup_duration_s = duration_s
        print(f"[JITCacheGuard] warmed {model_id} shapes={shapes} in {duration_s:.1f}s")

    def run_warmup_if_needed(self, model_id: str, model_fn: Callable, warmup_inputs: dict) -> None:
        """Call at process start so compilation happens before requests arrive.
        Warmup always runs: a fresh process starts with an empty in-memory cache even when
        the disk cache is warm; in that case the call loads from disk and should be fast."""
        shapes = {k: list(v.shape) for k, v in warmup_inputs.items()}
        state = "warm" if self.is_warm(model_id, shapes) else "cold"
        print(f"[JITCacheGuard] {state} disk cache for {model_id}; running warmup")
        start = time.perf_counter()
        jax.block_until_ready(model_fn(**warmup_inputs))  # wait for compile + execution
        self.record_warmup(model_id, shapes, time.perf_counter() - start)

    def assert_timeout_safe(self, request_timeout_s: float, model_id: Optional[str] = None) -> None:
        """Warn if the request timeout is shorter than a typical cold compile."""
        if model_id in MODEL_COMPILE_ESTIMATES_S:
            max_cold = MODEL_COMPILE_ESTIMATES_S[model_id]
        else:
            max_cold = max(MODEL_COMPILE_ESTIMATES_S.values())
        if request_timeout_s < max_cold and not self.warmup_completed:
            print(
                f"[JITCacheGuard] WARNING: request timeout {request_timeout_s}s is shorter than "
                f"worst-case cold XLA compilation time {max_cold}s; the first request to a cold "
                "subprocess will time out before compilation finishes. Run warmup or use "
                "the persistent cache."
            )


# Startup sequence for any process serving Gemma + JAX requests.
# Warm up one input per sequence bucket you expect to serve.
def init_gemma_process(model_id: str, model_fn: Callable, warmup_token_ids: list[int],
                       request_timeout_s: float = 30.0) -> JITCacheGuard:
    configure_jax_persistent_cache()
    guard = JITCacheGuard(cache_dir=JAX_CACHE_DIR)
    guard.assert_timeout_safe(request_timeout_s, model_id)
    warmup_inputs = {"token_ids": jnp.array([warmup_token_ids])}
    guard.run_warmup_if_needed(model_id, model_fn, warmup_inputs)
    return guard
```
Composite Guard for Gemma + JAX Agents
In practice, all four failure modes can compound in a single agent run. Consider an agent that starts fresh subprocesses per request (failure mode 4), mixes numpy KV accumulators with JAX GPU tensors (failure mode 2), produces variable-length tool results (failure mode 1), and occasionally routes to Gemma-27B (failure mode 3). On a cache-cold path it can see latency measured in minutes per agent step, even on high-end hardware. The composite guard below combines all four protections into a single policy object:
```python
from dataclasses import dataclass
from typing import Any, Callable

import jax.numpy as jnp

# Assumes pad_to_bucket, XLARecompileGuard, DevicePlacementGuard, ModelEvictionGuard,
# load_gemma_with_eviction, JITCacheGuard and configure_jax_persistent_cache from the
# sections above are defined in the same module.


@dataclass
class GemmaJAXAgentPolicy:
    available_vram_gb: float
    cache_dir: str = "/tmp/jax_compilation_cache"
    request_timeout_s: float = 60.0
    max_unique_buckets: int = 6
    target_device: str = "gpu:0"

    def __post_init__(self):
        configure_jax_persistent_cache(self.cache_dir)
        self.xla_guard = XLARecompileGuard(model_id="gemma", max_unique_buckets=self.max_unique_buckets)
        self.device_guard = DevicePlacementGuard(target_device=self.target_device)
        self.eviction_guard = ModelEvictionGuard(available_vram_gb=self.available_vram_gb)
        self.jit_guard = JITCacheGuard(cache_dir=self.cache_dir)
        self.jit_guard.assert_timeout_safe(self.request_timeout_s)

    def prepare_input(self, token_ids: list[int]) -> tuple[Any, int]:
        padded, actual_len = pad_to_bucket(token_ids)
        warning = self.xla_guard.check_and_record(len(padded))
        if warning:
            print(f"[GemmaJAXAgentPolicy] {warning}")
        arr = jnp.array([padded])  # created on JAX's default device
        return arr, actual_len

    def pin_kv(self, kv_cache: Any) -> Any:
        return self.device_guard.pin_to_device(kv_cache)

    def switch_model(self, model_id: str, load_fn: Callable) -> Any:
        return load_gemma_with_eviction(model_id, load_fn, self.eviction_guard)

    def report(self) -> dict:
        return {
            "xla_buckets_compiled": len(self.xla_guard.compiled_buckets),
            "xla_total_compile_s": self.xla_guard.total_compilation_s(),
            "device_copies": self.device_guard.report(),
            "models_evicted": self.eviction_guard.eviction_log,
            "jit_cache_hits": self.jit_guard.cache_hits,
            "jit_cache_misses": self.jit_guard.cache_misses,
        }


# --- Example usage (load_gemma_2b and model_forward are hypothetical) ---
# policy = GemmaJAXAgentPolicy(available_vram_gb=18.0)  # 24 GB card, JAX's default 75% pool
#
# At the start of each agent step:
# token_input, actual_len = policy.prepare_input(accumulated_token_ids)
# kv = policy.pin_kv(kv_cache)
# params = policy.switch_model("gemma-2b", load_gemma_2b)
# output = model_forward(params, token_input, kv)
#
# At the end of a session:
# print(policy.report())
```
Failure Mode Summary
| Failure mode | Root cause | Typical cost | Guard |
|---|---|---|---|
| XLA recompilation per unique shape | JAX traces and compiles one kernel per distinct input shape, and the agent produces a new shape on every loop iteration. | 15–180 s per unique length; 10 shapes = 150–1,800 s of overhead across one agent run. | XLARecompileGuard: pad all inputs to fixed length buckets (powers of 2), track the compiled bucket set, and warn before triggering a new compilation. |
| CPU-to-GPU tensor copies inside loops | Passing numpy arrays or Python lists into JAX GPU computations triggers an implicit full-array PCIe copy on every call. | ≈1.9 GB Gemma-7B KV cache (4,096 tokens) × 20 loop iterations ≈ 37.6 GB of PCIe traffic; roughly 2.3–3.1 s of extra latency at 12–16 GB/s. | DevicePlacementGuard: pin all arrays to the target device at the loop boundary with jax.device_put, and check placement before every inference call. |
| Model size switch without eviction | Loading Gemma-27B while Gemma-7B params are still resident exhausts VRAM; JAX does not auto-evict previous parameter pytrees. | Out-of-memory (RESOURCE_EXHAUSTED) error as soon as the resident parameter sets exceed available device memory. | ModelEvictionGuard: call .delete() on every array in the previous params pytree (or drop all references and run gc.collect()) before loading the next model; enforce a single-model-in-memory invariant. |
| JIT cache loss from subprocess isolation | JAX caches compiled kernels in process memory only, so agent frameworks that spawn a new subprocess per request start cold every time. | 30–165 s of cold compilation before the first inference in each subprocess; timeout-triggered retries compound the cost. | JITCacheGuard: configure jax_compilation_cache_dir for a persistent, disk-backed XLA artifact cache, and run a warmup generation at process start before accepting requests. |
Where does RunGuard fit? The GemmaJAXAgentPolicy above is a standalone implementation you can drop into any Gemma + JAX agent today. RunGuard's SDK extends this to cross-framework telemetry, a CI-visible dashboard of trips-per-app, and Slack/PagerDuty alerting when a guard fires — so the on-call engineer sees the XLA recompilation spiral before the end user sees a timeout.
Frequently asked questions
Does padding to fixed buckets reduce generation quality?
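No, as long as the padding is masked. The padding tokens are appended to the end of the input sequence and excluded by the attention mask, so the model never attends to them during generation. With causal attention, right padding also leaves the outputs at the real positions unchanged, so output quality matches the unpadded case. The trade-off is wasted compute on the padded positions during the prefill pass, roughly proportional to (bucket_size - actual_length) / bucket_size. That fraction is not always small. With power-of-two buckets it approaches 50% for an input just past a boundary (513 tokens padded to 1,024), and it averages about 25% when lengths are spread evenly inside a bucket. Finer bucket spacing reduces the waste at the cost of more compilations.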
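You can estimate the padding overhead directly from the bucket list. The minimal sketch below compares power-of-two buckets with finer 256-token spacing; padding_waste is an illustrative helper.

```python
def padding_waste(actual_len: int, buckets: list[int]) -> float:
    """Fraction of prefill positions that are padding for this input."""
    bucket = next(b for b in buckets if actual_len <= b)
    return (bucket - actual_len) / bucket


pow2 = [512, 1024, 2048, 4096, 8192]
finer = list(range(256, 8192 + 1, 256))  # 32 buckets: less waste, more compilations

print(padding_waste(513, pow2))   # 0.4990234375: worst case just past a boundary
print(padding_waste(513, finer))  # 0.33203125: same input, finer spacing
mean_pow2 = sum(padding_waste(n, pow2) for n in range(513, 1025)) / 512
print(round(mean_pow2, 3))        # 0.25: average over every length in the 1,024 bucket
```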
Can JAX's persistent compilation cache be shared across multiple machines in a cluster?
Yes, with caveats. The cache entries are serialized compiled executables plus metadata, and they are specific to the hardware and software that produced them. JAX folds details such as the device kind, backend version, jaxlib version, and XLA flags into each entry's cache key. An entry written on an A100 is therefore simply a cache miss on an H100 or a TPU v4, not a bad load. If your cluster has homogeneous GPU/TPU SKUs, putting the cache directory on shared NFS or a Cloud Storage FUSE mount lets all replicas share the warm cache. If SKUs are mixed, keep one subdirectory per hardware type, for example jax_cache/a100/ and jax_cache/h100/. Set jax_compilation_cache_dir at startup based on the detected device type so that each pool stays small and easy to inspect or purge.
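The sketch below shows one way to key the directory by hardware. It assumes that the first visible device is representative of the process. device_kind and platform are standard attributes of a jax.Device, while the helper names and base path are illustrative.

```python
import re
from pathlib import Path

import jax


def hardware_slug(device_kind: str) -> str:
    """Turn a device kind such as 'NVIDIA A100-SXM4-80GB' into a safe directory name."""
    return re.sub(r"[^a-z0-9]+", "-", device_kind.lower()).strip("-")


def per_hardware_cache_dir(base: str) -> str:
    device = jax.devices()[0]  # assumption: all local devices share one SKU
    subdir = Path(base) / f"{device.platform}-{hardware_slug(device.device_kind)}"
    subdir.mkdir(parents=True, exist_ok=True)
    return str(subdir)


cache_dir = per_hardware_cache_dir("/tmp/jax_cache")  # e.g. a shared NFS path in a cluster
jax.config.update("jax_compilation_cache_dir", cache_dir)
print(cache_dir)
```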
My agent uses Gemma via the Hugging Face transformers library with a JAX backend rather than the native Gemma JAX implementation. Do these failure modes still apply?
All four failure modes apply whenever JAX is the execution backend, regardless of the model-loading abstraction layer. Hugging Face's Flax models (the JAX/Flax variant of transformers) normally run through jax.jit-compiled functions. You will therefore see XLA recompilation on shape changes and host-to-device copies if you mix numpy and Flax arrays. You will also see resident parameter sets if you switch sizes without freeing them, and JIT cache loss on subprocess isolation, all for the same underlying reasons. The GemmaJAXAgentPolicy guards operate at the JAX level and apply directly.
When I call jax.effects_barrier() after deleting a model's params, I still see the old VRAM usage in nvidia-smi for several seconds. Is the eviction actually happening?
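The nvidia-smi reading is expected, but neither it nor jax.effects_barrier() tells you whether eviction happened. jax.effects_barrier() only waits for pending side effects, such as host callbacks, to finish; it does not free memory. Device buffers are released when you call .delete() on the arrays or when the last reference to them is dropped. nvidia-smi reports the memory pool reserved by the process, not the bytes in active use. By default JAX preallocates 75% of total GPU memory into that pool on first use, so the number barely moves after an eviction. The memory is still logically free from JAX's perspective, and a subsequent load into the same pool can succeed. To verify eviction, read JAX's own accounting instead: device.memory_stats()["bytes_in_use"] on GPU. Sometimes you need nvidia-smi to reflect live usage, for example for observability dashboards. In that case, start the process with XLA_PYTHON_CLIENT_PREALLOCATE=false so the pool grows on demand. Alternatively, use XLA_PYTHON_CLIENT_ALLOCATOR=platform, which allocates and frees exactly what is needed at a significant performance cost. torch.cuda.empty_cache() only affects PyTorch's allocator and does nothing for JAX buffers.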
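The sketch below checks eviction through JAX rather than nvidia-smi. Device.memory_stats() returns backend-specific statistics, and the GPU backend includes a bytes_in_use entry. The helper name live_device_bytes is illustrative. It returns None where no statistics are available, as is common on CPU.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def live_device_bytes(device: jax.Device) -> Optional[int]:
    """JAX's count of bytes held by live buffers, or None if the backend has no stats."""
    try:
        stats = device.memory_stats()
    except Exception:  # defensive: treat an unsupported call like a missing stat
        stats = None
    return None if stats is None else stats.get("bytes_in_use")


device = jax.devices()[0]
params = {"w": jnp.ones((2048, 2048), jnp.bfloat16)}  # stand-in for a params pytree
jax.block_until_ready(params)
before = live_device_bytes(device)

for leaf in jax.tree_util.tree_leaves(params):
    leaf.delete()  # frees the device buffer now; later use of this array raises

after = live_device_bytes(device)
print(f"bytes_in_use before={before} after={after}")
```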
Is there a way to run Gemma + JAX with fully dynamic shapes to avoid the padding overhead entirely?
JAX's support for shape polymorphism lives in jax.export, where input dimensions can be declared symbolically with jax.export.symbolic_shape. jax.experimental.enable_x64 is unrelated; it only toggles 64-bit types. A shape-polymorphic export is traced and lowered once, but it is still compiled for each concrete input shape it receives, so on its own it does not remove per-shape compilation. As of mid-2026 it is also not supported for all XLA backends or all Gemma architectural variants, particularly the Gemma 2 GQA path. The production-safe recommendation is fixed-bucket padding as described above. You may experiment with dynamic shapes and find that they work for your specific hardware and model variant. If so, weigh the saved padding overhead against the risk of using less-tested compiler paths. Those paths may produce slower or incorrect code on shape combinations that XLA's test suite did not exercise.