
MuXauJl11110/iJKOnet


Learning of Population Dynamics:
Inverse Optimization Meets JKO Scheme

Mikhail Persiianov, Jiawei Chen, Petr Mokrov, Alexander Tyurin, Evgeny Burnaev, Alexander Korotin


This repository contains the official implementation of the paper "Learning of Population Dynamics: Inverse Optimization Meets JKO Scheme", accepted at ICLR 2026. The goal of this work is to learn population-level dynamics from snapshot observations using the theory of Wasserstein gradient flows and optimal transport.
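For readers new to the scheme named in the title: the JKO (Jordan, Kinderlehrer and Otto) scheme is the standard time discretization of a Wasserstein gradient flow. In generic notation, which may differ from the paper's, let $\mathcal{J}$ be an energy functional and $\tau > 0$ a step size. Each step then solves the problem below, shown here with a potential energy as the example functional:

$$
\rho_{k+1} \in \operatorname*{arg\,min}_{\rho \in \mathcal{P}_2(\mathbb{R}^d)} \left\{ \mathcal{J}(\rho) + \frac{1}{2\tau} W_2^2(\rho, \rho_k) \right\},
\qquad
\mathcal{J}(\rho) = \int_{\mathbb{R}^d} V(x)\, d\rho(x).
$$

As $\tau \to 0$, these iterates approximate the continuity equation $\partial_t \rho = \nabla \cdot (\rho \nabla V)$, which is the Wasserstein gradient flow of this potential energy. Roughly speaking, learning dynamics from snapshots then amounts to recovering $\mathcal{J}$ (here, $V$) so that the observed $\rho_{k+1}$ are consistent with these steps.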

📌 TL;DR

This paper introduces $\texttt{iJKOnet}$, a method for learning population dynamics from unpaired snapshot distributions. It recovers an energy functional governing a Wasserstein gradient flow via an inverse formulation of the JKO scheme, resulting in a min–max objective between the energy model and transport maps. The approach avoids inner JKO solvers and precomputed optimal transport plans, enabling end-to-end training with standard neural networks.
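The following minimal NumPy sketch makes the idea of an "inner JKO solver" concrete by implementing one forward JKO step. This is the component that iJKOnet is designed to avoid. The sketch assumes the simplest energy, a potential energy $\int V\,d\rho$. For that energy, the step decouples over particles: each particle $x$ moves to $\arg\min_y V(y) + \|y - x\|^2/(2\tau)$. The function name `jko_step_potential` and the plain gradient-descent inner loop are illustrative choices, not code from this repository.

```python
import numpy as np


def jko_step_potential(x, grad_V, tau, n_iters=200, lr=None):
    """One forward JKO step for the potential energy J(rho) = E_rho[V].

    Each particle x_i is moved to argmin_y V(y) + ||y - x_i||^2 / (2 * tau)
    (the proximal map of tau * V), computed here by plain gradient descent.
    """
    y = x.copy()
    if lr is None:
        lr = 0.5 * tau  # conservative step size: the quadratic term has curvature 1/tau
    for _ in range(n_iters):
        grad = grad_V(y) + (y - x) / tau  # gradient of the per-particle objective
        y = y - lr * grad
    return y


# Usage: for V(y) = ||y||^2 / 2 the proximal map has the closed form y = x / (1 + tau).
rng = np.random.default_rng(0)
x0 = rng.normal(size=(500, 2))
tau = 0.1
x1 = jko_step_potential(x0, grad_V=lambda y: y, tau=tau)
print(np.abs(x1 - x0 / (1 + tau)).max())  # close to zero
```

The sketch needs a full inner optimization for every step and every particle batch. The inverse formulation described above sidesteps exactly this cost.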

📦 Project Structure

.
├── configs/               # training configuration files
│   ├── config-base.yaml
│   └── config-method-*.yaml
│
├── data/                  # datasets
├── models/                # JKOnet, JKOnet*, iJKOnet implementations
├── networks/              # neural network architectures
│
├── notebooks/             # tutorial notebooks
│
├── scripts/               # experiment scripts
│   ├── bash/              # generated bash scripts
│   ├── generate_sc_w_lo.py
│   ├── generate_sc_wo_lo.py
│   └── optuna_search.sh
│
├── utils/                 # helper utilities
│   ├── dataset/
│   ├── entropy_estimation/
│   └── evaluation/
│
├── train.py               # main training entrypoint
└── optuna_search.py       # hyperparameter search

📥 Installation & Dependencies

Clone the repository and install the necessary packages by creating the Anaconda environment from environment.yml:

git clone https://github.com/.../iJKOnet.git
cd iJKOnet

conda env create -f environment.yml
conda activate ijkonet  # Replace with actual environment name if different
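After activating the environment, you can run a quick import check to catch missing packages before training. The module names below are assumptions. They are derived from the packages listed under Credits (POT is imported as `ot`, ott-jax as `ott`, and Comet ML as `comet_ml`), with `jax` assumed because ott-jax depends on it. The file environment.yml remains the authoritative list.

```python
import importlib.util

# Assumed top-level module names; adjust to match environment.yml.
EXPECTED_MODULES = ["numpy", "jax", "ot", "ott", "optuna", "comet_ml"]


def check_modules(names):
    """Return {module_name: bool} indicating whether each top-level module is importable."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_modules(EXPECTED_MODULES).items():
        print(f"{name:10s} {'ok' if ok else 'MISSING'}")
```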

📔 Notebooks

The notebooks folder contains example notebooks demonstrating how to generate data and train models. In particular, notebooks/iJKOnet_usage.ipynb demonstrates how to:

  • generate synthetic 2D data for learning a potential energy function,

  • train the available models on this data.
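The notebook is the reference for data generation. The sketch below only illustrates what synthetic 2D snapshot data for a potential energy looks like. It uses a hypothetical ring-shaped potential $V(x) = (\|x\|^2 - 1)^2 / 4$ and moves particles along $dx/dt = -\nabla V(x)$, which is the particle picture of the gradient flow $\partial_t \rho = \nabla \cdot (\rho \nabla V)$, using explicit Euler steps. The repository's generator may use different potentials, step schemes, and sample sizes.

```python
import numpy as np


def grad_V(x):
    # Hypothetical potential V(x) = (||x||^2 - 1)^2 / 4, so grad V(x) = (||x||^2 - 1) * x.
    r2 = np.sum(x**2, axis=1, keepdims=True)
    return (r2 - 1.0) * x


def generate_snapshots(n_particles=1000, n_steps=5, tau=0.05, seed=0):
    """Simulate dx/dt = -grad V(x) with explicit Euler steps of size tau
    and return n_steps + 1 unpaired particle snapshots."""
    rng = np.random.default_rng(seed)
    x = rng.normal(scale=0.5, size=(n_particles, 2))
    snapshots = [x.copy()]
    for _ in range(n_steps):
        x = x - tau * grad_V(x)
        snapshots.append(x.copy())
    # Shuffle each snapshot independently: only the marginals are observed,
    # so particle correspondences across time points are discarded.
    return [rng.permutation(s) for s in snapshots]


snaps = generate_snapshots()
print(len(snaps), snaps[0].shape)  # 6 (1000, 2)
```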

📜 Scripts

The scripts generate_sc_w_lo.py and generate_sc_wo_lo.py generate bash scripts (saved in scripts/bash) for running single-cell experiments:

  • generate_sc_w_lo.py: leave-one-out experiments;
  • generate_sc_wo_lo.py: full-trajectory experiments (no leave-one-out).
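The sketch below illustrates the leave-one-out protocol in plain Python. It assumes that only intermediate time points are held out, so the first and last snapshots always stay in the training set. It does not reproduce the generator scripts, and the helper name `leave_one_out_splits` is hypothetical. The exact splits used in the experiments are defined by generate_sc_w_lo.py.

```python
def leave_one_out_splits(n_snapshots):
    """Yield (train_indices, held_out_index) for every intermediate snapshot.

    Assumption: the first and last time points are never held out.
    """
    all_idx = list(range(n_snapshots))
    for held_out in range(1, n_snapshots - 1):
        train = [i for i in all_idx if i != held_out]
        yield train, held_out


for train, test in leave_one_out_splits(5):
    print(test, train)
# 1 [0, 2, 3, 4]
# 2 [0, 1, 3, 4]
# 3 [0, 1, 2, 4]
```

Each split corresponds to one training run. The model is trained on the remaining snapshots and evaluated on how well it predicts the held-out one.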

The script scripts/optuna_search.sh is used to launch a SLURM job for hyperparameter search.

🧬 Learning Single-Cell Dynamics

For the EB dataset, we followed the preprocessing pipeline described in the $\texttt{JKOnet}^\star$ tutorial.

For the Multi dataset, we used the preprocessing pipeline described in the paper “A Computational Framework for Solving Wasserstein Lagrangian Flows” and implemented in the corresponding repository.

🏋️ Training

To start training, use the following general pattern:

python train.py \
  --solver <solver_name> \
  --dataset <dataset_name> \
  --config <path_to_base_config> \
  --extra_config <path_to_additional_config> \
  --K <K> \
  --array-tau <tau_or_tau_list> \
  --epochs <number_of_epochs> \
  --seed <seed>
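When sweeping over seeds or step sizes, it can help to assemble the command programmatically. The helper below is a hypothetical convenience that only mirrors the flags shown in the pattern above. The solver name, dataset name, and method config path are placeholders, not values taken from the repository, and a single $\tau$ is passed because the list format accepted by `--array-tau` is not documented here. Check configs/ and train.py for valid values.

```python
import shlex
import sys


def build_train_command(solver, dataset, config, extra_config, K, tau, epochs, seed):
    """Assemble the train.py invocation as an argument list (no validation is done)."""
    return [
        sys.executable, "train.py",
        "--solver", str(solver),
        "--dataset", str(dataset),
        "--config", str(config),
        "--extra_config", str(extra_config),
        "--K", str(K),
        "--array-tau", str(tau),
        "--epochs", str(epochs),
        "--seed", str(seed),
    ]


# Placeholder values for illustration only.
for seed in (0, 1, 2):
    cmd = build_train_command(
        "my_solver", "my_dataset", "configs/config-base.yaml",
        "configs/config-method-my_solver.yaml", K=5, tau=0.1, epochs=100, seed=seed,
    )
    # Print the command; pass cmd to subprocess.run(cmd, check=True) to actually launch it.
    print(shlex.join(cmd))
```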

🎓 Citation

If you find this repository useful in your research, please cite:

@inproceedings{
  persiianov2026learning,
  title={Learning of Population Dynamics: Inverse Optimization Meets {JKO} Scheme},
  author={Mikhail Persiianov and Jiawei Chen and Petr Mokrov and Alexander Tyurin and Evgeny Burnaev and Alexander Korotin},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=tVJIKd6CLF}
}

🙏 Credits

  • jkonet-star: our code is primarily based on this repository, with some fixes to the data generation for synthetic experiments;
  • mutinfo: the code in utils/entropy_estimation is based on code from this repository;
  • POT and ott-jax: optimal transport toolkits;
  • optuna: toolkit for hyperparameter search;
  • comet ML: experiment-tracking and visualization toolkit;
  • inkscape: an excellent open-source editor for vector graphics.