ArXiv TLDR

Clin-JEPA: A Multi-Phase Co-Training Framework for Joint-Embedding Predictive Pretraining on EHR Patient Trajectories

arXiv: 2605.10840

Yixuan Yang, Mehak Arora, Ryan Zhang, Baraa Abed, Junseob Kim + 8 more

cs.LG · cs.AI · q-bio.QM

TLDR

Clin-JEPA is a multi-phase co-training framework for JEPA pretraining on EHR patient trajectories, enabling accurate forecasting and risk prediction.

Key contributions

  • Introduces Clin-JEPA, a five-phase co-training framework for stable JEPA pretraining on EHR patient trajectories.
  • Achieves stable latent trajectory forecasting: ℓ1 rollout drift converges (−15.7%) over 48-hour horizons while baselines and ablations diverge.
  • Learns a clinically discriminative latent space: deteriorating patients displace 4.83× further than stable patients, versus ≤2.62× for baseline encoders.
  • Outperforms strong tabular and sequence baselines on multi-task risk prediction, achieving mean AUROC 0.883 across 8 binary tasks.
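
The rollout-drift result above compares autoregressive latent predictions against encoder targets at each step of the horizon. A minimal sketch of such an ℓ1 drift metric is below; the `predictor` callable, `z0` initial latent, and `encoder_targets` interface are assumptions for illustration, not the paper's actual API:

```python
import numpy as np

def l1_rollout_drift(predictor, z0, encoder_targets, horizon):
    """Mean L1 distance between an autoregressive latent rollout and
    per-step encoder targets (hypothetical interface).

    A converging model shows drift shrinking over the horizon;
    a diverging one shows it growing."""
    z = z0
    drifts = []
    for t in range(horizon):
        z = predictor(z)  # autoregressive step: feed prediction back in
        drifts.append(float(np.abs(z - encoder_targets[t]).mean()))
    return drifts
```

With a contractive toy predictor such as `lambda z: 0.5 * z` and zero targets, the drift sequence halves at each step, i.e. it converges in the sense the paper measures.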

Why it matters

Clin-JEPA addresses the challenge of applying JEPA to EHR data for patient-trajectory forecasting and risk prediction. A single backbone serves diverse downstream tasks without per-task fine-tuning, a step toward more versatile and efficient clinical AI.

Original Abstract

We present Clin-JEPA, a multi-phase co-training framework for joint-embedding predictive (JEPA) pretraining on EHR patient trajectories. JEPA architectures have enabled latent-space planning in robotics and high-quality representation learning in vision, but extending the paradigm to EHR data -- to obtain a single backbone that simultaneously forecasts patient trajectories and serves diverse downstream risk-prediction tasks without per-task fine-tuning -- remains an open challenge. Existing JEPA frameworks either discard the predictor after pretraining (I-JEPA, V-JEPA) or train it on a frozen pretrained encoder (V-JEPA 2-AC), leaving the encoder unaware of the rollout signal that the retained predictor must use at inference; co-training the encoder and predictor under a shared JEPA prediction objective would supply this grounding, but naïve co-training is unstable, with representation collapse and online/target drift causing autoregressive rollout to diverge. Clin-JEPA's five-phase pretraining curriculum -- predictor warmup, joint refinement, EMA target alignment, hard sync, and predictor finalization -- addresses each failure mode by phase, stably co-training a Qwen3-8B-based encoder and a 92M-parameter latent trajectory predictor. On MIMIC-IV ICU data, three independent evaluations support the framework: (1) latent $\ell_1$ rollout drift uniquely converges ($-$15.7%) over 48-hour horizons while baselines and ablations diverge (+3% to +4951%); (2) the encoder learns a clinically discriminative latent geometry (deteriorating-patient cohorts displace 4.83$\times$ further than stable patients in latent space, vs $\leq$2.62$\times$ for baseline encoders); (3) a single backbone outperforms strong tabular and sequence baselines on multi-task downstream evaluation. Clin-JEPA achieves mean AUROC 0.851 on ICareFM EEP and 0.883 on 8 binary risk tasks (+0.038 and +0.041 vs baseline average).
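
The abstract names two mechanisms: a five-phase curriculum that toggles which components train in each phase, and an EMA (exponential moving average) target encoder with a periodic hard sync. The sketch below is a hypothetical rendering of that schedule, assuming per-phase flags and the standard EMA update rule; the phase names come from the abstract, but the flag assignments, step boundaries, and function names are illustrative assumptions, not the authors' code:

```python
import numpy as np

# Hypothetical per-phase training flags for the five-phase curriculum
# named in the abstract (flag assignments are an assumption).
PHASES = [
    ("predictor_warmup",       dict(train_encoder=False, train_predictor=True, ema=False, hard_sync=False)),
    ("joint_refinement",       dict(train_encoder=True,  train_predictor=True, ema=False, hard_sync=False)),
    ("ema_target_alignment",   dict(train_encoder=True,  train_predictor=True, ema=True,  hard_sync=False)),
    ("hard_sync",              dict(train_encoder=True,  train_predictor=True, ema=False, hard_sync=True)),
    ("predictor_finalization", dict(train_encoder=False, train_predictor=True, ema=False, hard_sync=False)),
]

def phase_config(step, boundaries):
    """Return (name, flags) for the phase containing `step`.
    `boundaries` are cumulative step counts at which each phase ends."""
    for (name, flags), end in zip(PHASES, boundaries):
        if step < end:
            return name, flags
    return PHASES[-1]

def ema_update(target, online, tau=0.996):
    """Standard EMA target update: target <- tau*target + (1-tau)*online.
    A hard sync corresponds to copying `online` into `target` directly."""
    return tau * target + (1.0 - tau) * online
```

The intuition matching the abstract: warming up the predictor first avoids destabilizing a cold encoder, the EMA target damps online/target drift, and the hard sync plus a final predictor-only phase ground the retained predictor in the encoder it will roll out against at inference.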
