A clean and modular deep learning template to predict Gestational Age (GA) from Blindsweep ultrasound datasets using PyTorch. Built for reproducibility, simplicity, and extensibility.
- ✅ **Ready-to-use template**: includes datasets, models, training loops, and inference pipelines for GA prediction.
- ✅ **Attention-based video aggregation**: uses a ResNet backbone and weighted-average attention to handle multiple frames per sweep.
- ✅ **Educational and reproducible**: thoroughly commented code and modular dataset/model classes make it easy to learn and extend.
- ❌ Not optimized for very large datasets (can be adapted).
- ❌ GPU is recommended; CPU training will be slow.
- Handles single and multi-sweep datasets.
- Uses ResNet18 backbone with optional fine-tuning.
- Weighted average attention for frame aggregation.
- Training and validation with MAE loss.
- Saves best model automatically.
- TensorBoard integration for visualization.
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
echo 'eval "$($HOME/miniconda3/bin/conda shell.bash hook)"' >> ~/.bashrc
source ~/.bashrc
conda create -n ga-us python=3.10
conda activate ga-us
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
pip install notebook pandas tensorboard
conda install -c conda-forge nibabel
pip install tqdm
mkdir checkpoints logs
```

- Train dataset: single sweep per sample.
- Validation/Test dataset: multiple sweeps per sample.
- CSV files should include the following columns:

```
study_id, ga, path_nifti_1, path_nifti_2, ...
```
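The column layout above can be parsed into per-study samples before building a PyTorch dataset. A minimal sketch using only the standard library (`parse_sweep_csv` is an illustrative helper, not the exact function in `data.py`), which also skips empty path cells for studies with fewer sweeps:

```python
import csv
from io import StringIO

def parse_sweep_csv(fileobj):
    """Parse a GA CSV into (study_id, ga, sweep_paths) tuples.

    Expected columns: study_id, ga, path_nifti_1, path_nifti_2, ...
    Empty path cells (samples with fewer sweeps) are dropped.
    """
    reader = csv.reader(fileobj)
    next(reader)  # skip the header row
    samples = []
    for row in reader:
        study_id, ga = row[0].strip(), float(row[1])
        paths = [p.strip() for p in row[2:] if p.strip()]
        samples.append((study_id, ga, paths))
    return samples

# Example with an in-memory CSV:
demo = StringIO(
    "study_id,ga,path_nifti_1,path_nifti_2\n"
    "KA-PC-002-1,180,a.nii.gz,b.nii.gz\n"
)
print(parse_sweep_csv(demo))  # [('KA-PC-002-1', 180.0, ['a.nii.gz', 'b.nii.gz'])]
```

A `Dataset.__getitem__` would then load the NIfTI files for one tuple with NiBabel and stack the frames into a tensor.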
```
.
├── .gitignore
├── check.ipynb
├── data.py
├── evaluate_metrics.py
├── infer.py
├── model.py
├── requirements.txt
├── train.py
├── checkpoints/
│   └── best_model.pth
└── logs/
    └── events.out.tfevents...
```
- `.gitignore`: Git ignore file for unnecessary files.
- `check.ipynb`: Jupyter notebook for exploratory analysis or debugging.
- `data.py`: Dataset loading and preprocessing logic.
- `evaluate_metrics.py`: Code to evaluate model performance.
- `infer.py`: Inference logic for prediction.
- `model.py`: Contains the model architecture (ResNet backbone).
- `requirements.txt`: List of required Python packages.
- `train.py`: Training loop and validation logic.
- `checkpoints/`: Directory to store saved model checkpoints.
- `logs/`: Directory to store TensorBoard logs.
- Backbone: ResNet18 (pretrained, optionally frozen).
- Attention: WeightedAverageAttention for frame aggregation.
- Output: Linear layer predicting GA.
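The weighted-average attention step can be sketched as a small PyTorch module: each frame feature gets a scalar score, a softmax over frames turns the scores into weights, and the frames are averaged with those weights. Class and attribute names here are illustrative, not the exact ones in `model.py`:

```python
import torch
import torch.nn as nn

class WeightedAverageAttention(nn.Module):
    """Collapse per-frame features (B, T, D) into one vector (B, D)."""

    def __init__(self, dim):
        super().__init__()
        self.score = nn.Linear(dim, 1)  # one scalar score per frame

    def forward(self, x):                           # x: (B, T, D)
        w = torch.softmax(self.score(x), dim=1)     # weights over T, shape (B, T, 1)
        return (w * x).sum(dim=1)                   # weighted average, shape (B, D)

# Shape check: 2 sweeps, 16 frames each, 512-dim ResNet18 features
pooled = WeightedAverageAttention(512)(torch.randn(2, 16, 512))
print(pooled.shape)  # torch.Size([2, 512])
```

The pooled vector then feeds the final linear layer that predicts GA.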
Extend the model as needed, e.g. with a deeper backbone or a different attention mechanism.
Run training using:

```python
from train import train_and_validate

train_csv = "path/to/train.csv"
val_csv = "path/to/val.csv"

train_and_validate(train_csv, val_csv, epochs=100, batch_size=8,
                   n_sweeps_val=8, save_path='checkpoints/best_model.pth')
```

- Uses MAE (L1) loss. The code can be adapted to use other loss functions.
- Saves the best model automatically to `checkpoints/best_model.pth`.
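Adapting the loss means swapping the criterion used in the training loop for another regression loss. A quick numeric sketch (assuming the criterion is a plain `nn.Module`, with `nn.SmoothL1Loss` shown as one common, more outlier-robust alternative):

```python
import torch
import torch.nn as nn

pred = torch.tensor([180.0, 157.0])    # predicted GA in days
target = torch.tensor([175.0, 160.0])  # ground-truth GA

mae = nn.L1Loss()(pred, target)          # MAE, the template's default
huber = nn.SmoothL1Loss()(pred, target)  # a drop-in alternative criterion

print(mae.item())    # 4.0  = mean(|5|, |3|)
print(huber.item())  # 3.5  = mean(5 - 0.5, 3 - 0.5), since both errors exceed beta=1
```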
You can visualize the training process with TensorBoard. To log training metrics and visualize them:
1. **Start TensorBoard on the server:**
   Run the following command on the server where your model is training. This will start the TensorBoard service:

   ```bash
   tensorboard --logdir=logs --port=6000
   ```

   `--port=6000` specifies the port to use (TensorBoard's default is 6006).

2. **Map the server port to your local machine:**
   If you're connecting to the server remotely via SSH, you'll need to forward the TensorBoard port so you can access it locally. In your terminal (on your local machine), run:

   ```bash
   ssh -p 5555 -L 6000:localhost:6000 user@server_ip
   ```

   Replace `user@server_ip` with your actual username and server IP (`-p 5555` is only needed if your server's SSH daemon listens on a non-default port). This command forwards the server's port 6000 (where TensorBoard is running) to your local machine's port 6000.

3. **Open TensorBoard in your browser:**
   Once the SSH tunnel is established, open your browser and navigate to `http://localhost:6000` to view TensorBoard and monitor metrics like loss, accuracy, etc.
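For metrics to show up in TensorBoard, the training loop must write them with a `SummaryWriter` pointed at the same `logs/` directory passed to `--logdir`. A minimal logging sketch (the tag name and the stand-in metric are illustrative, not the exact ones in `train.py`):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs")  # same directory as --logdir above
for epoch in range(3):
    train_mae = 10.0 / (epoch + 1)      # stand-in for the real epoch MAE
    writer.add_scalar("mae/train", train_mae, epoch)
writer.close()
```

Each `add_scalar` call appends one point to the named curve, so calling it once per epoch produces the loss plots in the Scalars tab.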
To predict GA from test data:

```python
from infer import predict_ga

model_path = "checkpoints/best_model.pth"
test_csv = "path/to/test.csv"

predictions = predict_ga(model_path, test_csv)
```

Here's a sample of how your inference CSV should look:
```
study_id, site, predicted_ga
KA-PC-002-1, Kenya, 180
NL-PC-087-1, Nepal, 157
PN-PC-090-1, Pakistan, 223
```
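If the predictions come back as plain tuples, a small helper can serialize them into that exact format (`write_predictions` is a hypothetical name, not part of `infer.py`):

```python
import csv

def write_predictions(rows, path):
    """Write (study_id, site, predicted_ga) rows as a CSV with a header."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["study_id", "site", "predicted_ga"])
        writer.writerows(rows)

write_predictions(
    [("KA-PC-002-1", "Kenya", 180), ("NL-PC-087-1", "Nepal", 157)],
    "predictions.csv",
)
```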
- PyTorch
- TorchVision
- NiBabel for NIfTI image handling
MIT License.