diff --git a/tutorial/detection/spheres/detection_spheres.ipynb b/tutorial/detection/spheres/detection_spheres.ipynb index c1fd0a5..d28323b 100644 --- a/tutorial/detection/spheres/detection_spheres.ipynb +++ b/tutorial/detection/spheres/detection_spheres.ipynb @@ -1815,7 +1815,7 @@ "metadata": {}, "source": [ "Define the training pipeline using Deeptrack to generate the training dataset\n", - "The training pipeline is composed of the simulated images and the probability maps, together with instructions of value normalization and a selector of images, all of which are instances of Deeplay." + "The training pipeline is composed of the simulated images and the probability maps, together with instructions of value normalization and a selector of images." ] }, { diff --git a/tutorial/linking/spheres/UNet/UNet_model_spheres.pth b/tutorial/linking/spheres/UNet/UNet_model_spheres.pth new file mode 100644 index 0000000..14664ff Binary files /dev/null and b/tutorial/linking/spheres/UNet/UNet_model_spheres.pth differ diff --git a/tutorial/linking/spheres/UNet/UNet_reg_spheres.pth b/tutorial/linking/spheres/UNet/UNet_reg_spheres.pth new file mode 100644 index 0000000..d8076c3 Binary files /dev/null and b/tutorial/linking/spheres/UNet/UNet_reg_spheres.pth differ diff --git a/tutorial/linking/spheres/UNet/training_data/UNet_training_dataset_spheres.npz b/tutorial/linking/spheres/UNet/training_data/UNet_training_dataset_spheres.npz new file mode 100644 index 0000000..4f9cd96 Binary files /dev/null and b/tutorial/linking/spheres/UNet/training_data/UNet_training_dataset_spheres.npz differ diff --git a/tutorial/linking/spheres/linking_with_unet.ipynb b/tutorial/linking/spheres/linking_with_unet.ipynb new file mode 100644 index 0000000..f3c9df6 --- /dev/null +++ b/tutorial/linking/spheres/linking_with_unet.ipynb @@ -0,0 +1,2574 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aarondomenzain/tracking-softmatter-aarond/blob/tracking-softmatter-aarond-dev/tutorial/linking/spheres/linking_with_unet.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K3hcsWD9QOZQ" + }, + "source": [ + "# Particle Tracking Tutorial: Trajectory Linking\n", + "\n", + "In this tutorial, you’ll explore different methods to link particle localizations across time to reconstruct trajectories using both simulated data and real experimental images.\n", + "\n", + "You’ll start by generating a simulated movie of microscopic particles undergoing Brownian motion, mimicking what you might see in a soft matter or biophysics experiment. For each frame, you'll use a U-net, a supervised neural network, to detect and localize particles. Then comes the core challenge: linking localizations into trajectories.\n", + "\n", + "Here you’ll test and compare several methods for linking the trajectories:\n", + "\n", + "- Nearest-neighbor linking (using TrackPy — a classic in particle tracking)\n", + "\n", + "- Linear Assignment Problem (LAP) (using LapTrack - a more flexible and general framework)\n", + "\n", + "- MAGIK (a geometric deep learning method based on graph neural networks)\n", + "\n", + "You’ll be using Python libraries like NumPy, SciPy, Matplotlib, scikit-image, PyTorch, DeepTrack, and Deeplay. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Table of Contents\n", + "\n", + "0. 
[Importing the Required Libraries and Loading Utility Functions](#importing-the-required-libraries-and-loading-utility-functions)\n", + "1. [Loading and Visualizing Experimental Videos](#loading-and-visualizing-experimental-videos)\n", + "2. [Simulating Realistic Videos with DeepTrack](#simulating-realistic-videos-with-deeptrack)\n", + " - [Simulating a Single Particle](#simulating-a-single-particle)\n", + " - [Simulating a Video Frame](#simulating-a-video-frame)\n", + " - [Simulating Brownian Trajectories](#simulating-brownian-trajectories)\n", + " - [Simulating a Video](#simulating-a-video)\n", + "\n", + "3. [Detecting and Localizing Particles with U-net](#detecting-and-localizing-particles-with-u-net)\n", + " - [Training a U-net with Simulated Data](#training-a-u-net-with-simulated-data)\n", + " - [Evaluating U-net on Simulated Data](#evaluating-u-net-on-simulated-data)\n", + " - [Evaluating U-net on Experimental Data](#evaluating-u-net-on-experimental-data)\n", + " - [Applying U-Net to a Simulated Video](#applying-u-net-to-a-simulated-video)\n", + " - [Applying U-Net to an Experimental Video](#applying-u-net-to-an-experimental-video)\n", + "\n", + "4. [Method 1: Nearest-Neighbor Linking with TrackPy](#method-1-nearest-neighbor-linking-with-trackpy)\n", + " - [Linking Localizations in Simulations](#linking-localizations-in-simulations)\n", + " - [Evaluating Linking Performance](#evaluating-linking-performance)\n", + " - [Linking Localizations in Experiments](#linking-localizations-in-experiments)\n", + "\n", + "5. [Method 2: Linear Assignment Problem (LAP) with LapTrack](#method-2-linear-assignment-problem-lap-with-laptrack)\n", + " - [Linking Localizations in Simulations](#linking-localizations-in-simulations)\n", + " - [Evaluating Linking Performance](#evaluating-linking-performance)\n", + " - [Linking Localizations in Experiments](#linking-localizations-in-experiments)\n", + "\n", + "6. [Method 3: MAGIK](#method-3-magik)\n", + " - [Training MAGIK with Simulations](#training-magik-with-simulations)\n", + " - [Linking Localizations in Simulations](#linking-localizations-in-simulations)\n", + " - [Evaluating Linking Performance](#evaluating-linking-performance)\n", + " - [Linking Localizations in Experiments](#linking-localizations-in-experiments)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Importing the Required Libraries and Loading Utility Functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Uncomment the next cell if running on Google Colab/Kaggle."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7MXGmUjyD-Ve" + }, + "outputs": [], + "source": [ + "#!pip install deeptrack deeplay trackpy laptrack -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WadH3picQOZT" + }, + "outputs": [], + "source": [ + "# Standard libraries.\n", + "import logging\n", + "import os\n", + "import random\n", + "import sys\n", + "\n", + "# Configuration\n", + "import matplotlib\n", + "matplotlib.rcParams[\"animation.embed_limit\"] = 60 # Larger animations inline\n", + "logging.disable(logging.WARNING) # Suppress warnings and below\n", + "\n", + "# Core Scientific Stack\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "# Plotting and Display\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Rectangle\n", + "\n", + "# Machine Learning\n", + "import deeplay as dl\n", + "from sklearn.metrics import f1_score\n", + "import torch\n", + "from torchvision.transforms import Compose\n", + "from torch_geometric.loader import DataLoader\n", + "\n", + "# Particle Tracking and Simulation\n", + "import deeptrack as dt\n", + "from laptrack import LapTrack\n", + "import trackpy as tp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load a set of custom functions defined specifically for this notebook from the `utils` directory. For detailed documentation of each function, refer to the comments and docstrings within the files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load functions and utilities for dataset generation and visualization.\n", + "# Sys append a folder to the path.\n", + "sys.path.append(os.path.abspath(os.path.join(\"..\", \"..\")))\n", + "\n", + "# Import all the functions contained in the folder utils.\n", + "import utils" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set random seeds to make results reproducible across runs, especially during training and data simulation. Also select the best device for computations with Torch." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set a fixed seed value.\n", + "seed = 98\n", + "\n", + "# Python, NumPy, and PyTorch (CPU).\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n", + "\n", + "# Only set CUDA seeds if a GPU is available.\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + "print(f\"Seeds set to {seed} (with CUDA: {torch.cuda.is_available()})\")\n", + "\n", + "# Get the best available device for Torch computation.\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + " print(f\"Using CUDA GPU: {torch.cuda.get_device_name(0)}\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + " print(\"Using Apple GPU (MPS)\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + " print(\"Using CPU\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-lhy7J2ePRQw" + }, + "source": [ + "## Loading and Visualizing Experimental Videos" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o2MF8ohRFKZV" + }, + "source": [ + "You'll use experimental video of a system of colloidal particles recorded with fluorescence microscopy. Data from https://www.nature.com/articles/s41467-022-30497-z.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 428 + }, + "id": "ico6y7gkFKZV", + "outputId": "752d216d-df5b-4397-a552-9106b11375d2" + }, + "outputs": [], + "source": [ + "# Define the folder and video file name.\n", + "video_folder = \"videos\"\n", + "video_file_name = \"experimental_video.npy\"\n", + "\n", + "# Construct the full path.\n", + "video_path = os.path.join(video_folder, video_file_name)\n", + "\n", + "# Load the video data.\n", + "exp_video = np.load(video_path)\n", + "\n", + "utils.play_video(exp_video, \"Experimental Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the first frame of the video; then manually select and display a single particle by specifying its centroid coordinates (x, y) and a box width." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select the first frame of the video.\n", + "exp_image = exp_video[0]\n", + "\n", + "# Get the shape of the image.\n", + "assert exp_image.shape[0] == exp_image.shape[1], \"Warning: Image not square!\"\n", + "exp_image_size = exp_image.shape[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the parameters for the viewing box." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Box size to zoom in an individual particle.\n", + "exp_crop_size = 15\n", + "\n", + "# Coordinates of the center of the particle.\n", + "x_center = 96\n", + "y_center = 41" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Calculate top-left corner of the crop.\n", + "x = x_center - exp_crop_size // 2\n", + "y = y_center - exp_crop_size // 2\n", + "\n", + "# Select a crop as a subset of the entire image.\n", + "exp_crop = exp_image[\n", + " y:y + exp_crop_size, # row (y)\n", + " x:x + exp_crop_size, # column (x)\n", + "]\n", + "\n", + "# Initialize figure instance.\n", + "fig = plt.figure()\n", + "\n", + "vmin, vmax = np.percentile(exp_image, [1, 99])\n", + "\n", + "# Draw a red rectangle around the crop.\n", + "fig.add_subplot(111)\n", + "plt.imshow(exp_image, cmap=\"gray\", vmin=vmin, vmax=vmax)\n", + "plt.title(\"Experimental Image\", size=13)\n", + "plt.plot([x, x, x + exp_crop_size, x + exp_crop_size, x],\n", + " [y, y + exp_crop_size, y + exp_crop_size, y, y], 'r-')\n", + "plt.axis(\"off\")\n", + "\n", + "# Plot the rectangle on the top right corner.\n", + "fig.add_subplot(555)\n", + "plt.imshow(exp_crop, cmap=\"gray\", vmin=vmin, vmax=vmax)\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simulating Realistic Videos with DeepTrack" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "DeepTrack allows you to simulate physically realistic microscopy images and videos, enabling precise control over imaging parameters and particle properties. These simulations provide ground-truth data, making them ideal for benchmarking classical and AI-based tracking methods, as well as for training neural networks in a controlled environment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Simulating a Single Particle" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Adjust the simulation parameters to accurately replicate the features observed in the cropped region of the experimental image." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Same as the box width.\n", + "sim_crop_size = exp_crop_size \n", + "\n", + "# Size of a pixel in nanometers in the output image.\n", + "pixel_size_nm = 640 # In nm.\n", + "\n", + "# Radius of the particle.\n", + "particle_radius = 440 # In nm.\n", + "\n", + "# Define central spherical scatterer.\n", + "sphere = dt.Sphere(\n", + " position=0.5 * np.array([sim_crop_size, sim_crop_size]) + (-0.5, 0.5),\n", + " z= 400 * dt.units.nm, # Particle in focus.\n", + " radius= particle_radius * dt.units.nm, # Radius in nm\n", + " intensity= 1E5, # Field magnitude squared\n", + " refractive_index=1.59,\n", + ")\n", + "\n", + "# Simulate the properties of the fluorescence microscope.\n", + "optics = dt.Fluorescence(\n", + " NA=0.3, # Numerical aperture\n", + " wavelength=508 * dt.units.nm,\n", + " refractive_index_medium=1.33,\n", + " output_region=[0, 0, sim_crop_size, sim_crop_size],\n", + " magnification=2.6,\n", + " resolution=pixel_size_nm * dt.units.nm, # Camera effective resolution\n", + ")\n", + "\n", + "# Apply transformations.\n", + "sim_crop = (\n", + " optics(sphere)\n", + " >> dt.Background(750) # Background intensity level\n", + " >> dt.Poisson(snr=6000) # Signal-to-noise ratio (SNR) of the image\n", + ")\n", + "\n", + "# Turn the crop into a NumPy array.\n", + "sim_crop = np.squeeze(sim_crop())\n", + "\n", + "# Plot the simulated and experimental crops.\n", + "fig, axes = plt.subplots(1, 2)\n", + "\n", + "# Simulated crop.\n", + "plot = axes[0].imshow(sim_crop, cmap=\"gray\")\n", + "axes[0].axis(\"off\")\n", + "axes[0].set_title(\"Simulated Crop\")\n", + "\n", + "# Experimental crop.\n", + "axes[1].imshow(exp_crop, cmap=\"gray\") \n", + "axes[1].axis(\"off\")\n", + "axes[1].set_title(\"Experimental Crop\")\n", + "\n", + "# Adjust layout and show plot.\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Extract and visualize the raw intensity profiles along the central horizontal and vertical lines of both the simulated and experimental crops. This comparison helps evaluate how well the simulation reproduces the intensity distribution observed in real microscopy images." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute the center index.\n", + "center = sim_crop_size // 2\n", + "\n", + "# Extract horizontal (row) profiles.\n", + "sim_horiz = sim_crop[center, :]\n", + "exp_horiz = exp_crop[center, :]\n", + "\n", + "# Extract vertical (column) profiles.\n", + "sim_vert = sim_crop[:, center]\n", + "exp_vert = exp_crop[:, center]\n", + "\n", + "# Create a 1×2 subplot.\n", + "fig, axes = plt.subplots(1, 2, figsize=(12, 4), tight_layout=True,sharey=True)\n", + "\n", + "# --- Horizontal profile ---\n", + "axes[0].plot(sim_horiz, label=\"Simulated Crop\", color=\"orange\")\n", + "axes[0].plot(exp_horiz, label=\"Experimental Crop\", color=\"blue\")\n", + "axes[0].set_xlabel(\"Pixel (x)\")\n", + "axes[0].set_ylabel(\"Intensity\")\n", + "axes[0].set_title(\"Horizontal Intensity Profile (Center Row)\")\n", + "axes[0].legend()\n", + "axes[0].grid(True, linestyle=\"--\", alpha=0.5)\n", + "\n", + "# --- Vertical profile ---\n", + "axes[1].plot(sim_vert, label=\"Simulated Crop\", color=\"orange\")\n", + "axes[1].plot(exp_vert, label=\"Experimental Crop\", color=\"blue\")\n", + "axes[1].set_xlabel(\"Pixel (y)\")\n", + "axes[1].set_ylabel(\"Intensity\")\n", + "axes[1].set_title(\"Vertical Intensity Profile (Center Column)\")\n", + "axes[1].legend()\n", + "axes[1].grid(True, linestyle=\"--\", alpha=0.5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Simulating a Video Frame" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a simulated image containing non-overlapping spherical particles. Begin by generating their coordinates to serve as the ground-truth positions. Then, use DeepTrack to render optically realistic particles at these coordinates, resulting in a physically plausible microscopy image." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parameters of the simulation.\n", + "sim_image_size = 256\n", + "N_particles = 40\n", + "particle_radius = 440 # Particle radius in nm\n", + "\n", + "# Dictionary for particle properties. Dimensions are set with lambda\n", + "# functions to introduce variety to the dataset.\n", + "sphere_properties = {\n", + " \"intensity\": lambda: np.random.uniform(1.0, 2.0) * 0.85E5,\n", + " \"z\": lambda: np.random.uniform(-1000, -20000) * dt.units.nm,\n", + " \"radius\": particle_radius * dt.units.nm,\n", + " \"refractive_index\": 1.59,\n", + "}\n", + "# Set the optical properties of the microscope. This dictionary is a DeepTrack\n", + "# optics object.\n", + "optics_properties = dt.Fluorescence(\n", + " NA=0.4, # Numerical aperture\n", + " wavelength=508 * dt.units.nm,\n", + " refractive_index_medium=1.33,\n", + " output_region=[0, 0, sim_image_size, sim_image_size],\n", + " magnification=2.5,\n", + " resolution=640 * dt.units.nm, # Camera effective resolution\n", + ")\n", + "# Generate ground truth positions. 
\n", + "sim_gt_pos = utils.generate_centroids(\n", + " num_particles=N_particles,\n", + " fov_size=sim_image_size,\n", + " particle_radius=particle_radius,\n", + ")\n", + "# Simulate image.\n", + "sim_image = utils.transform_to_video(\n", + " sim_gt_pos,\n", + " fov_size=sim_image_size,\n", + " core_particle_props=sphere_properties,\n", + " optics_props=optics_properties,\n", + " background_props={\"poisson_snr\":6500, \"background_mean\": 650},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Visualize the simulated image and compare it with the experimental one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the simulated and experimental images.\n", + "fig, axes = plt.subplots(1, 2)\n", + "vmin, vmax = np.percentile(exp_image, [1, 99])\n", + "\n", + "# Simulated image.\n", + "axes[0].imshow(sim_image, cmap=\"gray\", vmin=vmin, vmax=vmax)\n", + "axes[0].axis(\"off\")\n", + "axes[0].set_title(\"Simulated Image\")\n", + "axes[0].scatter(sim_gt_pos[:, 1], sim_gt_pos[:, 0], marker=\".\", color=\"red\", s=10, label=\"Ground truth positions\")\n", + "axes[0].legend(loc=\"upper left\", markerscale=5)\n", + "\n", + "# Experimental image.\n", + "axes[1].imshow(exp_image, cmap=\"gray\", vmin=vmin, vmax=vmax) \n", + "axes[1].axis(\"off\")\n", + "axes[1].set_title(\"Experimental Image\")\n", + "\n", + "# Adjust layout and show plot.\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Perform a more quantitative comparison by plotting the intensity histograms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Flatten the arrays to 1D.\n", + "sim_vals = sim_image.ravel()\n", + "exp_vals = exp_image.ravel()\n", + "\n", + "# Compute common bin edges.\n", + "all_vals = np.concatenate([sim_vals, exp_vals])\n", + "num_bins = 60\n", + "bins = np.linspace(all_vals.min(), np.quantile(all_vals, 0.99), num_bins + 1)\n", + "\n", + "# Create figure with two subplots sharing axes.\n", + "fig, axes = plt.subplots(\n", + " 1, 2, \n", + " figsize=(10, 4), \n", + " sharey=True, \n", + " sharex=True, \n", + " tight_layout=True\n", + ")\n", + "\n", + "# Histogram for simulated image.\n", + "axes[0].hist(sim_vals, bins=bins, alpha=0.7, edgecolor=\"black\")\n", + "axes[0].set_title(\"Simulated Image Histogram\")\n", + "axes[0].set_xlabel(\"Intensity\")\n", + "axes[0].set_ylabel(\"Pixel Count\")\n", + "\n", + "# Histogram for experimental image.\n", + "axes[1].hist(exp_vals, bins=bins, alpha=0.7, edgecolor=\"black\")\n", + "axes[1].set_title(\"Experimental Image Histogram\")\n", + "axes[1].set_xlabel(\"Intensity\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L_6ZSzxmQOZU" + }, + "source": [ + "### Simulating Brownian Trajectories" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You'll simulate a set of trajectories that visually resemble the experimental data to evaluate the performance of different tracking methods. The goal is to replicate the Brownian motion of nanoparticles as observed in the experimental videos. This is done using the `simulate_Brownian_trajs()` function from the utility file, which generates 2D trajectories based on a random walk model. Refer to the function for more details." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CQUXCeWaQOZW" + }, + "outputs": [], + "source": [ + "# Simulation parameters.\n", + "number_particles = 30\n", + "number_timesteps = 50\n", + "\n", + "# Simulate trajectories for one video.\n", + "sim_trajs_gt = utils.simulate_Brownian_trajs(\n", + " num_particles=number_particles,\n", + " num_timesteps=number_timesteps,\n", + " fov_size=sim_image_size,\n", + " diffusion_std=0.5, # Corresponds to sqrt(2Dt)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For evaluation purposes, trajectories that move out and back in the field of view due to boundary conditions are treated as separate trajectories. For further analysis, the trajectories are transformed into a list." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Break trajectories going in/out of FOV.\n", + "sim_trajs_gt_list = utils.traj_break(\n", + " trajs=sim_trajs_gt,\n", + " fov_size=sim_image_size,\n", + " num_particles=sim_trajs_gt.shape[1],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zE3jXw3MQOZX" + }, + "source": [ + "### Simulating a Video" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You'll now simulate a video of particle motion that resembles experimental data and compares the two by playing them simultaneously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "gs3F2pQBQOZX", + "outputId": "bb2c3b9d-ef96-4f90-9569-92ea4fea1a4b" + }, + "outputs": [], + "source": [ + "# Simulate video.\n", + "sim_video = utils.transform_to_video(\n", + " np.delete(sim_trajs_gt, 2, 2), # Remove frame axis\n", + " fov_size=sim_image_size,\n", + " core_particle_props=sphere_properties,\n", + " optics_props=optics_properties,\n", + " background_props={\"poisson_snr\": 9500, \"background_mean\": 650},\n", + " save_video=True,\n", + " path=\"videos/simulated_video.tiff\",\n", + ")\n", + "\n", + "# Play both simulated and experimental videos and compare.\n", + "utils.play_video(sim_video, \"Simulated Video\")\n", + "utils.play_video(exp_video, \"Experimental Video\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** Several out-of-focus particles are rather dim and not clearly visible unless the dynamic range of the video is adjusted. However, including them is essential to accurately reproduce the experimental conditions, especially in terms of intensity distribution and background noise characteristics." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9KsF77OGQOZY" + }, + "source": [ + "## Detecting and Localizing Particles with U-net" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A U-net will be used to detect the position of particles at each frame of the video, similarly as shown in the **Detections** notebooks of this tutorial. These positions will be passed to different linking methods to build trajectories and compare their performance." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training a U-net with Simulated Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Implement a simulation pipeline, as shown in the Detection notebooks, to generate a training dataset for a U-Net." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Number of samples, image size, and particles.\n", + "num_samples = 256\n", + "train_image_size = 128\n", + "max_num_particles = 10\n", + "force_simulation = True # Flag to force simulation even if data exists\n", + "\n", + "# Optical properties of spheres with variability.\n", + "sphere_properties = {\n", + " \"intensity\": lambda: np.random.uniform(1.0, 2.0) * 0.85E5, \n", + " \"z\": lambda: np.random.uniform(-4500, -25000) * dt.units.nm,\n", + " \"radius\": particle_radius * dt.units.nm,\n", + " \"refractive_index\": 1.59,\n", + "}\n", + "\n", + "# Set the optical properties of the microscope.\n", + "optics_properties = dt.Fluorescence(\n", + " NA=lambda: np.random.uniform(0.4, 0.6), # Numerical aperture with variability\n", + " wavelength=508 * dt.units.nm,\n", + " refractive_index_medium=1.33,\n", + " output_region=[0, 0, train_image_size, train_image_size],\n", + " magnification= lambda: np.random.uniform(2.0, 3.5),\n", + " resolution=640 * dt.units.nm, # Camera effective resolution\n", + ")\n", + "\n", + "# Create path to store training dataset.\n", + "folder_name = \"UNet\"\n", + "training_dataset_filename = \"UNet_training_dataset_spheres.npz\"\n", + "training_dataset_folder = os.path.join(folder_name, \"training_data\")\n", + "training_dataset_filepath = os.path.join(\n", + " training_dataset_folder, \n", + " training_dataset_filename,\n", + ")\n", + "\n", + "# Create the enclosing directory if not existent already.\n", + "if not os.path.exists(training_dataset_folder):\n", + " os.makedirs(training_dataset_folder, exist_ok=True)\n", + "\n", + "# Try to load preexisting data, if not available or forced, raise an exception\n", + "# error to generate new data.\n", + "try:\n", + " if force_simulation: \n", + " # Raise the exception error if simulation is forced.\n", + " raise FileNotFoundError(\"Forced simulation by user request.\")\n", + " \n", + " if not os.path.isfile(training_dataset_filepath):\n", + " # If file is not found, start training. \n", + " raise FileNotFoundError(\n", + " \"Training dataset file not found. 
Starting simulation.\"\n", + " )\n", + " \n", + " # Load existing data\n", + " data = np.load(training_dataset_filepath)\n", + " images = data[\"images\"]\n", + " print(images.shape)\n", + " maps = data[\"maps\"]\n", + " Nsamples = len(images)\n", + " print(f\"Loaded file: {training_dataset_filepath}\")\n", + " \n", + "# Handle the case of either file not found or forced training.\n", + "except FileNotFoundError:\n", + " # Generate new dataset if file not found or simulation is forced.\n", + " images, maps = utils.generate_particle_dataset(\n", + " num_samples = num_samples,\n", + " fov_size = train_image_size,\n", + " max_num_particles = max_num_particles,\n", + " core_particle_dict=sphere_properties,\n", + " optics_properties=optics_properties,\n", + " background_props={\"poisson_snr\": 6500, \"background_mean\": 2500},\n", + " )\n", + " \n", + " # Save the simulated training dataset.\n", + " np.savez(training_dataset_filepath, images=images, maps=maps)\n", + " print(f\"Training dataset saved in: {training_dataset_filepath}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2tUtpRtJQOZY" + }, + "outputs": [], + "source": [ + "# Select an image and its corresponding probability maps and mask to show.\n", + "selected_image_index = np.random.randint(0, len(images))\n", + "\n", + "# Extract the image and probability map from 4D arrays.\n", + "selected_image = np.squeeze(images[selected_image_index])\n", + "selected_probability_map = np.squeeze(maps[selected_image_index])\n", + "\n", + "# Plot the image as the first subplot.\n", + "utils.plot_image_mask_ground_truth_map(\n", + " image=selected_image,\n", + " gt_map=selected_probability_map,\n", + " title=f\"Training dataset element: {selected_image_index+1}/{len(images)}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a U-net model and a regressor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unet = dl.UNet2d(\n", + " in_channels=1, \n", + " channels=[32, 64, 128, 256, 512], \n", + " out_channels=1,\n", + ")\n", + "regressor_unet = dl.Regressor(\n", + " model=unet, \n", + " loss=torch.nn.MSELoss(), \n", + " optimizer=dl.AdamW(),\n", + ").create()\n", + "\n", + "# Image selector with a random picker. 
This is performed in order to properly \n", + "# link an element in maps array with its corresponding element in images \n", + "# array.\n", + "selector = dt.Lambda(\n", + " lambda i: lambda x: x[i], i=lambda l: np.random.randint(l), l=len(images)\n", + ")\n", + "\n", + "# Apply augmentations of added Gaussian noise only to images.\n", + "images_augmentation_pipeline = (\n", + " dt.Value(images)\n", + " >> dt.NormalizeMinMax(0.0, 1.0)\n", + " >> dt.Gaussian(0, 0.002)\n", + " >> dt.Poisson(snr=50)\n", + " >> dt.NormalizeMinMax(\n", + " lambda: np.random.uniform(0.0, 0.1), \n", + " lambda: np.random.uniform(0.9, 1.0),\n", + " )\n", + ")\n", + "\n", + "maps_pipeline = dt.Value(maps) >> dt.NormalizeMinMax(0.0, 1.0)\n", + "\n", + "pipeline = (\n", + " (images_augmentation_pipeline & maps_pipeline)\n", + " >> selector\n", + " # >> dt.FlipUD()\n", + " # >> dt.FlipLR()\n", + " >> dt.MoveAxis(-1, 0)\n", + " >> dt.pytorch.ToTensor(dtype=torch.float)\n", + ")\n", + "# Sanity check.\n", + "sanity_check_pipeline_augmentation = np.squeeze(pipeline.update().resolve())\n", + "sanity_check_image_augmentation = sanity_check_pipeline_augmentation[0]\n", + "sanity_check_map_augmentation = sanity_check_pipeline_augmentation[1]\n", + "\n", + "# Plot the image as the first subplot.\n", + "utils.plot_image_mask_ground_truth_map(\n", + " image=sanity_check_image_augmentation,\n", + " gt_map=sanity_check_map_augmentation,\n", + " title=f\"Random Augmentation from Training Pipeline\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = dt.pytorch.Dataset(pipeline, length=256)\n", + "\n", + "train_loader = DataLoader(\n", + " train_dataset, \n", + " batch_size=8, \n", + " shuffle=True,\n", + ")\n", + "\n", + "trainer_unet = dl.Trainer(max_epochs=200, accelerator=\"auto\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zBNEX15S5idR" + }, + "source": [ + "Check whether pre-trained U-net weights already exist on disk. If they are missing or if training is explicitly forced, the model is trained on the experimental crops using the specified training pipeline, and the resulting weights are saved for future use.\n", + "\n", + "If the weights are already available and training is not forced, they are simply loaded from file.\n", + "\n", + "Afterward, the model is set to evaluation mode, which disables training-specific behaviors, ensuring consistent behavior during inference." 
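The cell below handles the loading or training itself. If you want the evaluation-mode step mentioned above to be explicit, a standard PyTorch pattern that you could optionally add after that cell is sketched here (the dummy input tensor is purely illustrative and is not part of the tutorial's pipeline):

```python
# Optional, after the next cell: make inference-time behavior explicit.
regressor_unet.eval()  # Disable dropout and use running batch-norm statistics.
with torch.no_grad():  # Skip gradient bookkeeping during inference.
    dummy_input = torch.zeros(1, 1, 128, 128, device=device)  # Illustrative input only.
    dummy_output = regressor_unet(dummy_input)
print(dummy_output.shape)  # Expected: torch.Size([1, 1, 128, 128])
```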
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bBIHM4HFFKZa" + }, + "outputs": [], + "source": [ + "# Force training if desired.\n", + "force_training = True\n", + "\n", + "# Define the file paths for the model weights.\n", + "unet_path = \"UNet_model_spheres.pth\"\n", + "regressor_unet_path = \"UNet_reg_spheres.pth\"\n", + "\n", + "# Define folder and construct full file paths.\n", + "folder_name = \"UNet\"\n", + "unet_path = os.path.join(folder_name, unet_path)\n", + "regressor_unet_path = os.path.join(folder_name, regressor_unet_path)\n", + "\n", + "# Load preexisting weights if they exist and training is not forced.\n", + "if (not force_training and os.path.exists(unet_path)\n", + " and os.path.exists(regressor_unet_path)):\n", + " unet.load_state_dict(torch.load(unet_path, weights_only=True))\n", + " regressor_unet.load_state_dict(\n", + " torch.load(regressor_unet_path, weights_only=True)\n", + " )\n", + " print(f\"Loaded preexisting U-Net weights from '{folder_name}/'.\")\n", + "else:\n", + " print(\"Training U-Net model (either forced or weights not found).\")\n", + " \n", + " # Ensure the save directory exists.\n", + " os.makedirs(folder_name, exist_ok=True)\n", + "\n", + " # Train the U-Net model.\n", + " trainer_unet.fit(regressor_unet, train_loader)\n", + "\n", + " # Monitor training history.\n", + " trainer_unet.history.plot()\n", + " \n", + " # Save the weights.\n", + " torch.save(unet.state_dict(), unet_path)\n", + " torch.save(regressor_unet.state_dict(), regressor_unet_path)\n", + " print(f\"Saved trained U-Net weights to '{folder_name}/'.\")\n", + " \n", + "# Transfer the model to the best available device (optional).\n", + "regressor_unet.to(device);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluating U-net on Simulated Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply the trained U-net model to a frame of the simulated video. Set the inference parameters that control detection sensitivity. Get the prediction features, which can be useful for visualizing detections. Extract the final coordinates of the detected particles and print how many particles were detected. Plot the predicted localizations from U-Net alongside the ground truth on the simulated image and quantify the performance." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extract intensity corresponding to first and 99th percentile of intensity distribution.\n", + "p1, p99 = np.percentile(sim_video[0], [1, 99])\n", + "\n", + "# Apply contrast stretching to enhance contrast.\n", + "test_frame = np.clip(sim_video[0], p1, p99)\n", + "\n", + "# Normalize intensity to [0,1] for inference with U-Net.\n", + "test_frame = utils.normalize_min_max(test_frame)\n", + "\n", + "# Convert the image to analyze into a PyTorch tensor.\n", + "formatted_sim_image = utils.format_image(test_frame)\n", + "\n", + "# Transfer the tensor to the best available device (optional).\n", + "formatted_sim_image = formatted_sim_image.to(device)\n", + "\n", + "# Apply the UNet to the loaded image.\n", + "sim_image_pred_map_tensor = regressor_unet(formatted_sim_image.to(device))\n", + "\n", + "# Convert to NumPy array.\n", + "sim_image_pred_map = sim_image_pred_map_tensor[0, 0, :, :].cpu().detach().numpy()\n", + "\n", + "# Normalize the predicted ground truth map for thresholding purposes.\n", + "sim_image_pred_map = utils.normalize_min_max(sim_image_pred_map, minimum_value=0.0, maximum_value=1.0)\n", + "\n", + "# Apply contrast stretching to the predicted ground truth map for automatic thresholding.\n", + "p1, p98 = np.percentile(sim_image_pred_map, [1, 98])\n", + "\n", + "# Apply thresholding to the predicted ground truth map.\n", + "sim_image_pred_mask_unet = sim_image_pred_map > p98\n", + "\n", + "# Convert the masked ground truth map to positions.\n", + "sim_locs_pred_method = \\\n", + " utils.mask_to_positions(sim_image_pred_mask_unet, sim_image_pred_map)\n", + "\n", + "# Plot the simulated image with the positions predicted by U-Net.\n", + "utils.plot_predicted_positions(\n", + " image=test_frame, \n", + " pred_positions=sim_locs_pred_method, \n", + " gt_positions=sim_trajs_gt[0][:,:2],\n", + " title=\"Method 3 - Simulated Image\",\n", + ")\n", + "\n", + "# Plot the predicted ground truth map and its masked version.\n", + "utils.plot_image_mask_ground_truth_map(\n", + " mask=sim_image_pred_mask_unet,\n", + " gt_map=sim_image_pred_map,\n", + " title=\"Method 3 - Simulated Image\",\n", + ")\n", + "\n", + "# Print the number of detections.\n", + "print(f\"Found {len(sim_locs_pred_method[:,1])} detections.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zMrw8fRXFKZb" + }, + "source": [ + "### Evaluating U-net on Experimental Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply the trained U-net model to a frame of the experimental video. Set the inference parameters that control detection sensitivity. Get the prediction features, which can be useful for visualizing detections. Extract the final coordinates of the detected particles and print how many particles were detected. Plot the predicted localizations from U-Net alongside the ground truth on the simulated image and quantify the performance." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 428 + }, + "id": "irPMdWcLc75g", + "outputId": "d30c5e20-1d0a-45e2-8073-f185c6e027ed" + }, + "outputs": [], + "source": [ + "# Extract intensity corresponding to first and 98th percentile of intensity distribution.\n", + "p1, p98 = np.percentile(exp_video[0], [1, 98])\n", + "\n", + "# Apply contrast stretching to enhance contrast.\n", + "test_frame = np.clip(exp_video[0], p1, p98)\n", + "\n", + "# Normalize intensity to [0,1] for inference with U-Net.\n", + "test_frame = utils.normalize_min_max(test_frame)\n", + "\n", + "# Convert the image to analyze into a PyTorch tensor.\n", + "formatted_exp_image = utils.format_image(test_frame)\n", + "\n", + "# Transfer the tensor to the best available device (optional).\n", + "formatted_exp_image = formatted_exp_image.to(device)\n", + "\n", + "# Apply the UNet to the loaded image.\n", + "exp_image_pred_map_tensor = regressor_unet(formatted_exp_image.to(device))\n", + "\n", + "# Convert to NumPy array.\n", + "exp_image_pred_map = exp_image_pred_map_tensor[0, 0, :, :].cpu().detach().numpy()\n", + "\n", + "# Normalize the predicted ground truth map for thresholding purposes.\n", + "exp_image_pred_map = utils.normalize_min_max(exp_image_pred_map, minimum_value=0.0, maximum_value=1.0)\n", + "\n", + "# Apply contrast stretching to the predicted ground truth map for automatic thresholding.\n", + "p1, p98 = np.percentile(exp_image_pred_map, [1, 98])\n", + "\n", + "# Apply thresholding to the predicted ground truth map.\n", + "exp_image_pred_mask_unet = exp_image_pred_map > p98\n", + "\n", + "# Convert the masked ground truth map to positions.\n", + "exp_locs_pred_method = \\\n", + " utils.mask_to_positions(exp_image_pred_mask_unet, exp_image_pred_map)\n", + "\n", + "# Plot the experimental image with the positions predicted by U-Net.\n", + "utils.plot_predicted_positions(\n", + " image=test_frame, \n", + " pred_positions=exp_locs_pred_method, \n", + " title=\"Method 3 - Experimental Image\",\n", + ")\n", + "\n", + "# Plot the predicted ground truth map and its masked version.\n", + "utils.plot_image_mask_ground_truth_map(\n", + " mask=exp_image_pred_mask_unet,\n", + " gt_map=exp_image_pred_map,\n", + " title=\"Method 3 - Experimental Image\",\n", + ")\n", + "\n", + "# Print the number of detections.\n", + "print(f\"Found {len(exp_locs_pred_method[:,1])} detections.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Applying U-Net to a Simulated Video\n", + "Iteratively apply U-Net to every frame of the simulated video and store localizations in a dataframe." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_sim_video = []\n", + "for frame_index, sim_frame in enumerate(sim_video):\n", + "\n", + " # Extract intensity corresponding to first and 99th percentile of intensity distribution.\n", + " p1, p99 = np.percentile(sim_frame, [1, 99])\n", + "\n", + " # Apply contrast stretching to enhance contrast.\n", + " sim_frame = np.clip(sim_frame, p1, p99)\n", + "\n", + " # Normalize intensity to [0,1] for inference with U-Net.\n", + " sim_frame = utils.normalize_min_max(sim_frame, minimum_value=0.0, maximum_value=1.0)\n", + "\n", + " # Convert the image to analyze into a PyTorch tensor.\n", + " sim_frame_formatted = utils.format_image(sim_frame)\n", + "\n", + " # Transfer the tensor to the best available device (optional).\n", + " sim_frame_formatted = sim_frame_formatted.to(device)\n", + "\n", + " # Apply the UNet to the loaded image.\n", + " sim_image_pred_map_tensor = regressor_unet(sim_frame_formatted)\n", + "\n", + " # Convert to NumPy array.\n", + " sim_image_pred_map = sim_image_pred_map_tensor[0, 0, :, :].cpu().detach().numpy()\n", + "\n", + " # Normalize the predicted ground truth map for thresholding purposes.\n", + " sim_image_pred_map = utils.normalize_min_max(sim_image_pred_map, minimum_value=0.0, maximum_value=1.0)\n", + "\n", + " # Apply contrast stretching to the predicted ground truth map for automatic thresholding.\n", + " p1, p98 = np.percentile(sim_image_pred_map, [1, 98])\n", + " \n", + " # Apply thresholding to the predicted ground truth map.\n", + " sim_image_pred_mask_unet = sim_image_pred_map > p98\n", + "\n", + " # Convert the masked ground truth map to positions.\n", + " sim_locs_pred = \\\n", + " utils.mask_to_positions(sim_image_pred_mask_unet, sim_image_pred_map)\n", + "\n", + " # Store detections in a DataFrame.\n", + " df_frame = pd.DataFrame(sim_locs_pred, columns=[\"x\", \"y\"])\n", + " df_frame[\"frame\"] = frame_index\n", + " df_sim_video.append(df_frame)\n", + "\n", + " # Print no. of detections every 10 frames.\n", + " if frame_index % 10 == 0:\n", + " print(f\"Detections in frame {frame_index}: {len(sim_locs_pred)}\")\n", + "\n", + "# Combine all detections into a single DataFrame.\n", + "df_sim_video = pd.concat(df_sim_video, ignore_index=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Applying U-Net to an Experimental Video\n", + "Iteratively apply U-Net to every frame of the experimental video and store localizations in a dataframe." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_exp_video = []\n", + "for frame_index, exp_frame in enumerate(exp_video):\n", + "\n", + " # Extract intensity corresponding to first and 98th percentile of intensity distribution.\n", + " p1, p98 = np.percentile(exp_frame, [1, 99])\n", + "\n", + " # Apply contrast stretching to enhance contrast.\n", + " exp_frame = np.clip(exp_frame, p1, p98)\n", + "\n", + " # Normalize intensity to [0,1] for inference with U-Net.\n", + " exp_frame = utils.normalize_min_max(exp_frame, minimum_value=0.0, maximum_value=1.0)\n", + "\n", + " # Convert the image to analyze into a PyTorch tensor.\n", + " formatted_exp_image = utils.format_image(exp_frame)\n", + "\n", + " # Transfer the tensor to the best available device (optional).\n", + " formatted_exp_image = formatted_exp_image.to(device)\n", + "\n", + " # Apply the UNet to the loaded image.\n", + " exp_image_pred_map_tensor = regressor_unet(formatted_exp_image)\n", + "\n", + " # Convert to NumPy array.\n", + " exp_image_pred_map = exp_image_pred_map_tensor[0, 0, :, :].cpu().detach().numpy()\n", + "\n", + " # Normalize the predicted ground truth map for thresholding purposes.\n", + " exp_image_pred_map = utils.normalize_min_max(exp_image_pred_map, minimum_value=0.0, maximum_value=1.0)\n", + "\n", + " # Apply contrast stretching to the predicted ground truth map for automatic thresholding.\n", + " p1, p98 = np.percentile(exp_image_pred_map, [1, 98])\n", + "\n", + " # Apply thresholding to the predicted ground truth map.\n", + " exp_image_pred_mask_unet = exp_image_pred_map > p98\n", + "\n", + " # Convert the masked ground truth map to positions.\n", + " exp_locs_pred = \\\n", + " utils.mask_to_positions(exp_image_pred_mask_unet, exp_image_pred_map)\n", + "\n", + " # Store detections in a DataFrame.\n", + " df_frame = pd.DataFrame(exp_locs_pred, columns=[\"x\", \"y\"])\n", + " df_frame[\"frame\"] = frame_index\n", + " df_exp_video.append(df_frame)\n", + "\n", + " # Print no. of detections every 10 frames.\n", + " if frame_index % 10 == 0:\n", + " print(f\"Detections in frame {frame_index}: {len(exp_locs_pred)}\")\n", + "\n", + "# Combine all detections into a single DataFrame.\n", + "df_exp_video = pd.concat(df_exp_video, ignore_index=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6oMlJeXkQOZZ" + }, + "source": [ + "## Method 1: Nearest-Neighbor Linking with TrackPy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TrackPy constructs trajectories by linking localized positions across frames using a predictive nearest-neighbor algorithm. The input must be a Pandas DataFrame containing the particle positions, usually with columns `x`, `y`, and `frame`.\n", + "\n", + "See the [TrackPy tutorial on prediction and linking](https://soft-matter.github.io/trackpy/dev/tutorial/prediction.html) for more details on how the algorithm works and how to tune parameters like `search_range` and `memory`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Linking Localizations in Simulations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply the method to the localization dataframe." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "q8DHTIAUQOZZ", + "outputId": "01b92ec1-c2dc-4d78-dc25-11cb62ae8ce5" + }, + "outputs": [], + "source": [ + "# Link detections across frames into trajectories using trackpy.link().\n", + "# The `search_range` parameter sets the maximum allowed displacement in pixels\n", + "# between frames, and the `memory` parameter allows particles to vanish for a\n", + "# given number of frames and still be linked to the same trajectory.\n", + "sim_trajs_pred_method1 = tp.link(\n", + " df_sim_video,\n", + " search_range=3,\n", + " memory=4,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "alwnOK-jQOZa" + }, + "source": [ + "Create a trajectory list from the output dataframe." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Gwni1QQXQOZa" + }, + "outputs": [], + "source": [ + "sim_trajs_pred_method1_list = []\n", + "for i in sim_trajs_pred_method1.particle.drop_duplicates():\n", + " traj = sim_trajs_pred_method1.loc[\n", + " sim_trajs_pred_method1.particle == i,\n", + " [\"frame\", \"x\", \"y\"]\n", + " ].values\n", + " sim_trajs_pred_method1_list.append(traj)\n", + "\n", + "print(f\"Number of trajectories found: {len(sim_trajs_pred_method1_list)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Filter out trajectories shorter than 10 frames" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filter trajectory lists shorter than 10 frames.\n", + "sim_trajs_pred_method1_list = [trajectory for trajectory in sim_trajs_pred_method1_list if len(trajectory) >= 15]\n", + "\n", + "print(f\"Number of trajectories found: {len(sim_trajs_pred_method1_list)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7vZUMzWYQOZa" + }, + "source": [ + "Create a video with overlayed localizations and trajectories." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sim_video_method1_results = utils.make_video_with_trajs(\n", + " trajs_pred_list=sim_trajs_pred_method1_list,\n", + " video=sim_video,\n", + " fov_size=sim_image_size,\n", + " trajs_gt_list=sim_trajs_gt_list,\n", + " figure_title=\"Simulated video\"\n", + ")\n", + "\n", + "# Display the video.\n", + "sim_video_method1_results\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SzFm5YceQOZb" + }, + "source": [ + "### Evaluating Linking Performance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Evaluate the overall performance of the tracking (detection + linking) method using the following metrics:\n", + "\n", + "- **TP (True Positives):** Number of ground-truth particles correctly matched to estimated positions.\n", + "\n", + "- **FP (False Positives):** Number of estimated particles that do not correspond to any ground-truth particle.\n", + "\n", + "- **FN (False Negatives):** Number of ground-truth particles that were not matched to any estimated position.\n", + "\n", + "- **α:** A measure of the overall agreement between ground-truth and estimated tracks, ignoring unmatched (spurious) estimated tracks.\n", + "\n", + "- **β:** A stricter version of α that penalizes unmatched (spurious) tracks, providing a more realistic performance score. 
\n", + "\n", + "See the detailed definitions in [Chenouard et al., Nature Methods, 2014](https://www.nature.com/articles/nmeth.2808).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nk-MJ37YQOZb", + "outputId": "6406de28-a117-4201-981e-7a5c9b71145a" + }, + "outputs": [], + "source": [ + "# Evaluate performance metrics.\n", + "utils.trajectory_metrics(\n", + " sim_trajs_gt_list,\n", + " sim_trajs_pred_method1_list,\n", + " eps=5,\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bOwhbz58QOZc" + }, + "source": [ + "Display the reconstructed trajectories together with their groud truth." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 522 + }, + "id": "sCM3GhhBQOZf", + "outputId": "103e1986-3966-467e-893b-a1a24517ce66" + }, + "outputs": [], + "source": [ + "# Compute the total squared distance between all trajectories to match\n", + "# predicted trajectories with ground truth.\n", + "matched_pairs, _, _ = utils.trajectory_assignment(\n", + " sim_trajs_gt_list,\n", + " sim_trajs_pred_method1_list,\n", + " eps=5,\n", + ")\n", + "\n", + "# Plot the trajectories.\n", + "utils.plot_trajectory_matches(\n", + " sim_trajs_gt_list, sim_trajs_pred_method1_list, matched_pairs,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calculate the time-averaged MSD for all the trajectories and compare curves obtained for matching trajectories (same color)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "utils.plot_TAMSDs(\n", + " trajs_pred=sim_trajs_pred_method1_list,\n", + " trajs_gt=sim_trajs_gt_list,\n", + " matched_pairs=matched_pairs,\n", + ") " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Linking Localizations in Experiments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply the same steps to track the experiment and visualize the results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Link detections across frames into trajectories using trackpy.link().\n", + "exp_trajs_pred_method1 = tp.link(\n", + " df_exp_video,\n", + " search_range=3, \n", + " memory=3,\n", + ")\n", + "# Create a list to store trajectories.\n", + "exp_trajs_pred_method1_list = []\n", + "for i in exp_trajs_pred_method1.particle.drop_duplicates():\n", + " traj = exp_trajs_pred_method1.loc[\n", + " exp_trajs_pred_method1.particle == i,\n", + " [\"frame\", \"x\", \"y\"]\n", + " ].values\n", + " exp_trajs_pred_method1_list.append(traj)\n", + "\n", + "# Filter trajectory lists shorter than 10 frames.\n", + "exp_trajs_pred_method1_list = [trajectory for trajectory in exp_trajs_pred_method1_list if len(trajectory) >= 15]\n", + "print(f\"Number of trajectories found: {len(exp_trajs_pred_method1_list)}\")\n", + "\n", + "# Display the experimental video with the localizations and trajectories.\n", + "exp_video_method1_results = utils.make_video_with_trajs(\n", + " trajs_pred_list=exp_trajs_pred_method1_list,\n", + " video=exp_video,\n", + " fov_size=exp_image_size,\n", + ")\n", + "\n", + "# Display the video.\n", + "exp_video_method1_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calculate the time-averaged MSD for the trajectories." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "utils.plot_TAMSDs(trajs_pred=exp_trajs_pred_method1_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WRf0iHOZQOZf" + }, + "source": [ + "## Method 2: Linear Assignment Problem (LAP) with LapTrack\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "LapTrack solves the trajectory linking problem by formulating it as a Linear Assignment Problem (LAP), a well-established optimization approach in multi-object tracking. It builds a cost matrix that quantifies dissimilarity—typically based on spatial distance—between particle detections in consecutive frames. Lower distances correspond to lower costs and indicate higher likelihoods of correspondence.\n", + "\n", + "LapTrack uses the Hungarian algorithm to solve this assignment problem efficiently, minimizing the total cost across the matrix. This allows it to determine the globally optimal set of assignments across frames, enabling robust trajectory reconstruction even under challenging conditions such as high particle density or noisy detections.\n", + "\n", + "Examples and tutorials using LapTrack are available in the [LapTrack documentation](https://github.com/yfukai/laptrack/tree/main/docs/examples)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1uo-448OQOZf" + }, + "source": [ + "### Linking Localizations in Simulations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply the method to the localization dataframe." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 262 + }, + "id": "4mDW4fARQOZf", + "outputId": "66a9f87a-1447-487a-b158-778243f5cbd5" + }, + "outputs": [], + "source": [ + "# Define the LapTrack instance to later link the detections.\n", + "laptrack = LapTrack(\n", + " track_cost_cutoff=3**2, # Maximum allowed distance in pixels for linking detections across frames.\n", + " gap_closing_cost_cutoff=5**2, # Max maximum allowed linking cost for missing detections.\n", + " gap_closing_max_frame_count=5, # Maximum number of missing frames, or \"memory\".\n", + " splitting_cutoff=False, # Disable cell division-like events.\n", + ")\n", + "\n", + "# Fetch the predicted trajectories as the first output of the function.\n", + "sim_trajs_pred_method2, _, _ = laptrack.predict_dataframe(\n", + " df=df_sim_video, # DataFrame with detections\n", + " coordinate_cols=[\"x\", \"y\"], # Name of columns containing coordinates\n", + ")\n", + "\n", + "# Reset the indexing order to ensure sequential trajectory IDs.\n", + "sim_trajs_pred_method2 = sim_trajs_pred_method2.reset_index()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZXP-woQlQOZf" + }, + "source": [ + "Create a trajectory list from the output dataframe." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Sqc3NWqhQOZg" + }, + "outputs": [], + "source": [ + "sim_trajs_pred_method2_list=[]\n", + "\n", + "# Eliminate duplicates in track_id and create a list of trajectories.\n", + "for i in sim_trajs_pred_method2.track_id.drop_duplicates():\n", + " traj = sim_trajs_pred_method2.loc[\n", + " sim_trajs_pred_method2.track_id == i,\n", + " [\"frame\", \"x\", \"y\"]\n", + " ].values\n", + " sim_trajs_pred_method2_list.append(traj)\n", + "\n", + "# Filter trajectory lists shorter than 10 frames.\n", + "sim_trajs_pred_method2_list = [trajectory for trajectory in sim_trajs_pred_method2_list if len(trajectory) >= 15]\n", + "print(f\"Number of trajectories found: {len(sim_trajs_pred_method2_list)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "caLZE0LCQOZg" + }, + "source": [ + "Create a video with overlayed localizations and trajectories." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 596 + }, + "id": "yPjFBBGwQOZg", + "outputId": "16b4baaf-8dca-4a9b-db57-fdc1c8e80350" + }, + "outputs": [], + "source": [ + "sim_video_method2_results = utils.make_video_with_trajs(\n", + " trajs_pred_list=sim_trajs_pred_method2_list,\n", + " video=sim_video,\n", + " fov_size=sim_image_size,\n", + " trajs_gt_list=sim_trajs_gt_list,\n", + ")\n", + "\n", + "# Display the video.\n", + "sim_video_method2_results" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YpEBsU6kQOZg" + }, + "source": [ + "### Evaluating Linking Performance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Evaluate the overall performance of the tracking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uwU4hSwWQOZh", + "outputId": "2618676e-40b2-413d-d75f-17c25c95562a" + }, + "outputs": [], + "source": [ + "# Evaluate performance metrics.\n", + "utils.trajectory_metrics(\n", + " sim_trajs_gt_list,\n", + " sim_trajs_pred_method2_list,\n", + " eps=5,\n", + ");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the reconstructed trajectories together with the ground truth." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 522 + }, + "id": "R7cu60ZWFKZm", + "outputId": "43c30302-d3df-4959-dcd9-9a12e8dfaaf7" + }, + "outputs": [], + "source": [ + "# Compute the total squared distance between all trajectories to match\n", + "# predicted trajectories with ground truth.\n", + "matched_pairs, _, _ = utils.trajectory_assignment(\n", + " sim_trajs_gt_list,\n", + " sim_trajs_pred_method2_list,\n", + " eps=5,\n", + ")\n", + "\n", + "# Plot the trajectories.\n", + "utils.plot_trajectory_matches(\n", + " sim_trajs_gt_list, sim_trajs_pred_method2_list, matched_pairs,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calculate the time-averaged MSD for all the trajectories and compare curves obtained for matching trajectories (same color)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "utils.plot_TAMSDs(\n", + " trajs_pred=sim_trajs_pred_method2_list,\n", + " trajs_gt=sim_trajs_gt_list,\n", + " matched_pairs=matched_pairs,\n", + ") " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Linking Localizations in Experiments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply the same steps to track the experiment and visualize the results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Link detections across frames into trajectories using laptrack.\n", + "exp_trajs_pred_method2, _, _ = laptrack.predict_dataframe(\n", + " df_exp_video,\n", + " [\"x\", \"y\"],\n", + " only_coordinate_cols=True,\n", + ")\n", + "\n", + "exp_trajs_pred_method2 = exp_trajs_pred_method2.reset_index()\n", + "\n", + "# Create a list to store trajectories.\n", + "exp_trajs_pred_method2_list=[]\n", + "for i in exp_trajs_pred_method2.track_id.drop_duplicates():\n", + " traj = exp_trajs_pred_method2.loc[\n", + " exp_trajs_pred_method2.track_id == i,\n", + " [\"frame\", \"x\", \"y\"]\n", + " ].values\n", + " exp_trajs_pred_method2_list.append(traj)\n", + "\n", + "# Filter trajectory lists shorter than 10 frames.\n", + "exp_trajs_pred_method2_list = [trajectory for trajectory in exp_trajs_pred_method2_list if len(trajectory) >= 15]\n", + "print(f\"Number of trajectories found: {len(exp_trajs_pred_method2_list)}\")\n", + "\n", + "# Visualize the video with the trajectories overlaid.\n", + "exp_video_method2_results = utils.make_video_with_trajs(\n", + " trajs_pred_list=exp_trajs_pred_method2_list,\n", + " video=exp_video,\n", + " fov_size=exp_image_size,\n", + ")\n", + "\n", + "# Display the video.\n", + "exp_video_method2_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calculate the time-averaged MSD for the trajectories." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "utils.plot_TAMSDs(trajs_pred=exp_trajs_pred_method2_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r45j5Z2LQOZh" + }, + "source": [ + "## Method 3: MAGIK" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "MAGIK is a tracking framework designed to analyze the motion of dynamic systems, including cells, bacteria, individual molecules, colloids, and other active particles. The name stands for Motion Analysis through Graph Neural Network Inductive Knowledge.\n", + "\n", + "At its core, MAGIK uses graph neural networks (GNNs) to learn patterns in particle movement and to infer trajectories across frames. This data-driven approach enables MAGIK to outperform traditional tracking methods in challenging conditions, such as dense particle fields, complex interaction dynamics, or non-Brownian motion.\n", + "\n", + "Thanks to its ability to learn and generalize motion priors, MAGIK is particularly effective in noisy or ambiguous experimental settings, making it a strong complement—or even an alternative—to classical tools like TrackPy and LapTrack.\n", + "\n", + "For more details, see the publication \n", + "[Geometric Deep Learning Reveals the Spatiotemporal Features of Microscopic Motion](https://www.nature.com/articles/s42256-022-00595-0) *Nat Mach Intell* **5**, 71–82 (2023)." 
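Concretely, MAGIK represents the detections as a graph: each localization is a node (with features such as its normalized coordinates and frame index), and candidate edges connect detections that are close in space (within a connectivity radius) and in time (within a maximum frame gap). The network then classifies each candidate edge as a true or a false link. The utility `utils.GraphFromTrajectories` used below builds this structure; the following sketch only illustrates the candidate-edge idea on a toy dataframe with the tutorial's column names, and is not the actual implementation.

```python
# Toy illustration of candidate-edge construction for MAGIK
# (nodes = detections, edges = pairs close in space and time).
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "frame": [0, 0, 1, 1, 2],
    "centroid-0": [0.10, 0.80, 0.12, 0.79, 0.13],
    "centroid-1": [0.20, 0.60, 0.21, 0.62, 0.23],
})

connectivity_radius = 0.05  # Maximum spatial distance for a candidate edge.
max_frame_distance = 2      # Maximum temporal gap for a candidate edge.

coords = df[["centroid-0", "centroid-1"]].values
frames = df["frame"].values

edges = []
for i in range(len(df)):
    for j in range(i + 1, len(df)):
        dt = frames[j] - frames[i]
        dist = np.linalg.norm(coords[j] - coords[i])
        if 0 < dt <= max_frame_distance and dist <= connectivity_radius:
            edges.append((i, j))

print(edges)  # Candidate links that MAGIK would classify as true or false.
```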
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Training MAGIK with Simulations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To train MAGIK effectively, you first need to generate appropriate training data. Since the experimental dataset in this case features colloids undergoing Brownian motion, you'll use the `simulate_Brownian_trajs()` function to produce groups of synthetic trajectories that replicate this behavior.\n",
+ "\n",
+ "**Note:** It is crucial that the simulated training data accurately reflect the motion characteristics of your experimental particles. If your system exhibits pure diffusion (Brownian motion), the training simulations should mirror that. Conversely, if your experimental data involve additional dynamics—such as drift, confinement, or driven flow (e.g. in nanofluidic systems)—these should be incorporated into the training data to ensure MAGIK learns the correct motion priors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 571
+ },
+ "id": "j1OEgXgKQOZl",
+ "outputId": "d718a10c-227c-405e-cedd-1014c22c1cb0"
+ },
+ "outputs": [],
+ "source": [
+ "# Parameters for training dataset.\n",
+ "train_dataset_size = 100 # Number of videos\n",
+ "train_number_particles = 20 # Number of particles per video\n",
+ "train_number_timesteps = 50 # Number of frames per video\n",
+ "train_fov_size = 128 # Size of the field of view\n",
+ "\n",
+ "# Collect per-trajectory dataframes (concatenated into one dataframe below).\n",
+ "df_train_dataset = []\n",
+ "for video_index in range(train_dataset_size):\n",
+ " # Simulate trajectories for one video.\n",
+ " sim_trajs_train_dataset = utils.simulate_Brownian_trajs(\n",
+ " num_particles=train_number_particles,\n",
+ " num_timesteps=train_number_timesteps,\n",
+ " fov_size=train_fov_size,\n",
+ " diffusion_std=np.random.uniform(0.1, 2.0),\n",
+ " )\n",
+ " # Break trajectories going in/out of FOV.\n",
+ " sim_trajs_train_dataset_list = utils.traj_break(\n",
+ " trajs=sim_trajs_train_dataset,\n",
+ " fov_size=train_fov_size,\n",
+ " num_particles=train_number_particles,\n",
+ " )\n",
+ " # Make into a dataframe with \"frame\" (which frame in the video),\n",
+ " # \"label\" (which particle in that frame), and \"set\" (which video).\n",
+ " for traj_index, traj in enumerate(sim_trajs_train_dataset_list):\n",
+ " df_traj = pd.DataFrame(\n",
+ " traj[:, 1:],\n",
+ " columns=[\"centroid-0\", \"centroid-1\"],\n",
+ " )\n",
+ " df_traj[\"frame\"] = traj[:, 0].astype(int)\n",
+ " df_traj[\"label\"] = traj_index\n",
+ " df_traj[\"set\"] = f\"{video_index}\"\n",
+ " df_train_dataset.append(df_traj)\n",
+ "\n",
+ "# Concatenate to dataframe.\n",
+ "df_train_dataset = pd.concat(df_train_dataset, ignore_index=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Normalize the trajectory coordinates between 0 and 1 by dividing by the FOV size."
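For reference, the Brownian trajectories generated above amount to cumulative sums of independent Gaussian steps. The sketch below illustrates that underlying model only; it is not the actual `utils.simulate_Brownian_trajs` implementation, which also handles FOV bookkeeping and output formatting.

```python
# Sketch of Brownian (random-walk) trajectories: cumulative sums of
# Gaussian displacements. Illustrative only.
import numpy as np

rng = np.random.default_rng(0)

num_particles, num_timesteps, fov_size = 20, 50, 128
diffusion_std = 1.0  # Standard deviation of the per-frame displacement (pixels).

# Random starting positions inside the field of view.
start = rng.uniform(0, fov_size, size=(num_particles, 1, 2))

# Gaussian steps accumulated over time: shape (particles, timesteps, 2).
steps = rng.normal(0, diffusion_std, size=(num_particles, num_timesteps, 2))
trajs = start + np.cumsum(steps, axis=1)

print(trajs.shape)  # (20, 50, 2): per-particle (x, y) positions over time.
```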
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Normalize centroids between 0 and 1.\n", + "norm_factor = np.array([train_fov_size, train_fov_size])\n", + "df_train_dataset.loc[:, df_train_dataset.columns.str.contains(\"centroid\")] = (\n", + " df_train_dataset.loc[\n", + " :,\n", + " df_train_dataset.columns.str.contains(\"centroid\")\n", + " ] / norm_factor\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6HO6aEdSQOZm" + }, + "source": [ + "To train MAGIK with the simulated trajectories, you need to produce a graph representation with the function `GraphFromTrajectories()`, defined in the utility file `utils.py`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NIg528NaQOZm", + "outputId": "fbcbda7f-b87c-4dea-ee0e-8976ec4e604b" + }, + "outputs": [], + "source": [ + "# Instance the graph constructor.\n", + "graph_constructor = utils.GraphFromTrajectories(\n", + " connectivity_radius=0.01,\n", + " max_frame_distance=3,\n", + ")\n", + "\n", + "# Generate graph from training data using graph constructor.\n", + "train_dataset_graph = graph_constructor(df=df_train_dataset)\n", + "\n", + "print(train_dataset_graph)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uTmRBM2CnNCm" + }, + "source": [ + "Define the augmentation pipeline and the dataloader." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IQo0pfTbQOZm" + }, + "outputs": [], + "source": [ + "# Initialize the graph dataset class.\n", + "# `Dt` is the time difference between frames to sample from the graph.\n", + "# Specify augmentations with transform,\n", + "# NodeDropout() should be last.\n", + "train_dataset = utils.GraphDataset(\n", + " train_dataset_graph,\n", + " dataset_size=train_dataset_size,\n", + " Dt=5,\n", + " transform=Compose(\n", + " [\n", + " utils.RandomRotation(),\n", + " utils.RandomFlip(),\n", + " utils.NodeDropout(),\n", + " ]\n", + " )\n", + ")\n", + "\n", + "# Initialize the training data loader.\n", + "train_loader = DataLoader(\n", + " train_dataset,\n", + " batch_size=4,\n", + " shuffle=True,\n", + " drop_last=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BT1k6wjBQOZn" + }, + "source": [ + "Define the hyperparameters of the architecture." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TBdnAwjgQOZn" + }, + "outputs": [], + "source": [ + "magik = dl.GraphToEdgeMAGIK(\n", + " [96,] * 4,1,\n", + " out_activation=torch.nn.Sigmoid,\n", + ")\n", + "\n", + "magik.encoder[0].configure(\n", + " hidden_features=[32, 64],\n", + " out_features=96,\n", + " out_activation=torch.nn.ReLU,\n", + ")\n", + "\n", + "magik.encoder[1].configure(\n", + " hidden_features=[32, 64],\n", + " out_features=96,\n", + " out_activation=torch.nn.ReLU,\n", + ")\n", + "\n", + "magik.head.configure(hidden_features=[64, 32]);\n", + "\n", + "classifier_magik = dl.BinaryClassifier(\n", + " model=magik,\n", + " optimizer=dl.Adam(lr=1e-3),\n", + ").build()\n", + "\n", + "\n", + "# Set training parameters and train.\n", + "trainer_magik = dl.Trainer(max_epochs=64, accelerator=\"auto\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jd08re_GQOZn" + }, + "source": [ + "Train the model or load preexisting weights." 
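Before doing so, a brief note on the augmentations configured above: Brownian motion is isotropic, so rotating or flipping a whole set of trajectories produces an equally valid training sample, while `NodeDropout` mimics missed detections. The sketch below shows a random rotation applied to coordinates normalized to the unit square; it is illustrative only (points near the corners can leave the unit square, which the actual `utils.RandomRotation` transform is assumed to handle).

```python
# Sketch of a rotation augmentation on coordinates normalized to [0, 1].
# Illustrative only; the tutorial uses utils.RandomRotation, utils.RandomFlip,
# and utils.NodeDropout.
import numpy as np

rng = np.random.default_rng(0)

coords = rng.uniform(0, 1, size=(100, 2))  # Toy normalized (x, y) positions.

theta = rng.uniform(0, 2 * np.pi)          # Random rotation angle.
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

# Rotate around the center of the field of view at (0.5, 0.5).
augmented = (coords - 0.5) @ rotation.T + 0.5
```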
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define whether to force training even if weights exist.\n", + "force_training = False\n", + "\n", + "# Define folder and model path.\n", + "folder_name = \"MAGIK\"\n", + "magik_model_path = os.path.join(folder_name, \"magik_weights.pth\")\n", + "\n", + "# Train or load weights.\n", + "if not os.path.isfile(magik_model_path) or force_training:\n", + " print(\"Training MAGIK model (either forced or weights not found).\")\n", + "\n", + " # Ensure save directory exists.\n", + " os.makedirs(folder_name, exist_ok=True)\n", + "\n", + " # Train the model.\n", + " trainer_magik.fit(classifier_magik, train_loader)\n", + "\n", + " # Plot training history.\n", + " trainer_magik.history.plot()\n", + " \n", + " # Save trained weights.\n", + " torch.save(magik.state_dict(), magik_model_path)\n", + " print(f\"Saved MAGIK weights to '{magik_model_path}'.\")\n", + "else:\n", + " # Load pre-trained weights.\n", + " magik.load_state_dict(torch.load(magik_model_path, weights_only=True))\n", + " print(f\"Loaded preexisting MAGIK weights from '{magik_model_path}'.\")\n", + "\n", + "# Set the model to evaluation mode.\n", + "classifier_magik.eval();\n", + "\n", + "# Transfer the model to the best available device (optional).\n", + "classifier_magik.to(device);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wh6ldS5tQOZn" + }, + "source": [ + "### Linking Localizations in Simulations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You'll use the trajectories in the simulated video as the test dataset for the trained model of MAGIK. First, format the simulated dataframe. Then, convert the localizations corresponding to the simulated trajectories into a graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Rename for compatibility with label format.\n", + "df_sim_video_formatted = df_sim_video.rename(\n", + " columns={\"x\": \"centroid-0\", \"y\": \"centroid-1\"}\n", + ")\n", + "\n", + "# Add label, set, and solution columns.\n", + "df_sim_video_formatted[[\"label\", \"set\", \"solution\"]] = 0\n", + "\n", + "# Normalize coordinates to [0, 1].\n", + "frame_height, frame_width, _ = sim_image.shape\n", + "df_sim_video_formatted[[\"centroid-0\", \"centroid-1\"]] /= [frame_width,\n", + " frame_height]\n", + "\n", + "# Generate a graph from graph_constructor. As test_graph returns a list of\n", + "# graphs, we select the first element from the list as it only has 1 element.\n", + "sim_video_graph = graph_constructor(df=df_sim_video_formatted)[0]\n", + "\n", + "# Transfer graph to the best available device (optional).\n", + "sim_video_graph = sim_video_graph.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qVWmgFQvQOZo" + }, + "source": [ + "Apply the trained model of MAGIK to predict the edge features." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 353
+ },
+ "id": "5lvw_Xc6QOZo",
+ "outputId": "9828a7c9-6457-4e4d-9180-0970e845207e"
+ },
+ "outputs": [],
+ "source": [
+ "# Perform prediction on test graph.\n",
+ "sim_trajs_edges_pred_method3 = classifier_magik(sim_video_graph)\n",
+ "\n",
+ "# Apply threshold to get binary edge predictions.\n",
+ "sim_trajs_edges_pred_method3 = \\\n",
+ " sim_trajs_edges_pred_method3.cpu().detach().numpy() > 0.5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Get the ground-truth edge features and, as a first performance metric, use the F1 score for the classification of the edges."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the ground truth edges from the graph.\n",
+ "sim_trajs_edges_gt = sim_video_graph.y.cpu() # Transfer to CPU\n",
+ "\n",
+ "# Compute the F1 score.\n",
+ "F1 = f1_score(sim_trajs_edges_gt, sim_trajs_edges_pred_method3)\n",
+ "print(f\"Test F1 score: {F1}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_MNRExDHQOZo"
+ },
+ "source": [
+ "The predicted edge features can be used to obtain the trajectories using the dedicated class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0Ht5ZdaFQOZp"
+ },
+ "outputs": [],
+ "source": [
+ "# Compute the trajectories from the predicted edges.\n",
+ "trajectory_constructor = utils.ComputeTrajectories()\n",
+ "sim_trajs_pred_method3 = trajectory_constructor(\n",
+ " sim_video_graph.cpu(),\n",
+ " sim_trajs_edges_pred_method3.squeeze(),\n",
+ ")\n",
+ "\n",
+ "# Convert the predicted trajectories to a list format.\n",
+ "sim_trajs_pred_method3_list = utils.make_list(\n",
+ " sim_trajs_pred_method3, sim_video_graph, sim_image_size,\n",
+ ")\n",
+ "\n",
+ "# Keep only trajectories spanning at least 15 frames.\n",
+ "sim_trajs_pred_method3_list = [trajectory for trajectory in sim_trajs_pred_method3_list if len(trajectory) >= 15]\n",
+ "print(f\"Number of trajectories found: {len(sim_trajs_pred_method3_list)}\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a video with overlayed localizations and trajectories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a video with the predicted trajectories.\n",
+ "sim_video_method3_results = utils.make_video_with_trajs(\n",
+ " trajs_pred_list=sim_trajs_pred_method3_list,\n",
+ " video=sim_video,\n",
+ " fov_size=sim_image_size,\n",
+ " trajs_gt_list=sim_trajs_gt_list,\n",
+ ")\n",
+ "\n",
+ "# Plot the video.\n",
+ "sim_video_method3_results"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Evaluating Linking Performance"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Evaluate the overall performance of the tracking."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Evaluate performance metrics.\n",
+ "utils.trajectory_metrics(\n",
+ " sim_trajs_gt_list,\n",
+ " sim_trajs_pred_method3_list,\n",
+ " eps=5,\n",
+ ");"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Display the reconstructed trajectories together with the ground truth."
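The matching performed in the next cell pairs each predicted trajectory with a ground-truth one based on how far apart they are over the frames they share. A sketch of such a pairwise trajectory distance is shown below; the actual criterion lives in `utils.trajectory_assignment` (with `eps` presumably acting as the acceptance threshold) and may differ in detail.

```python
# Sketch of a pairwise trajectory distance over shared frames.
# Trajectories are (N, 3) arrays of (frame, x, y), as elsewhere in this
# tutorial; the actual matching is done by utils.trajectory_assignment.
import numpy as np

def mean_squared_distance(traj_a, traj_b):
    """Mean squared distance between two trajectories on their common frames."""
    common, idx_a, idx_b = np.intersect1d(
        traj_a[:, 0], traj_b[:, 0], return_indices=True
    )
    if len(common) == 0:
        return np.inf  # No temporal overlap: the pair cannot be matched.
    diff = traj_a[idx_a, 1:3] - traj_b[idx_b, 1:3]
    return np.mean(np.sum(diff ** 2, axis=1))
```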
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compute the total squared distance between all trajectories to match\n",
+ "# predicted trajectories with ground truth.\n",
+ "matched_pairs, _, _ = utils.trajectory_assignment(\n",
+ " sim_trajs_gt_list,\n",
+ " sim_trajs_pred_method3_list,\n",
+ " eps=5,\n",
+ ")\n",
+ "\n",
+ "# Plot the trajectories.\n",
+ "utils.plot_trajectory_matches(\n",
+ " sim_trajs_gt_list, sim_trajs_pred_method3_list, matched_pairs,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Calculate the time-averaged MSD for all the trajectories and compare curves obtained for matching trajectories (same color)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "utils.plot_TAMSDs(\n",
+ " trajs_pred=sim_trajs_pred_method3_list,\n",
+ " trajs_gt=sim_trajs_gt_list,\n",
+ " matched_pairs=matched_pairs,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "acu57jkGQOZq"
+ },
+ "source": [
+ "### Linking Localizations in Experiments"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Apply the same steps to track the experiment and visualize the results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ti6PUVMLQOZw"
+ },
+ "outputs": [],
+ "source": [
+ "# Rename for compatibility with label format.\n",
+ "df_exp_video_formatted = df_exp_video.rename(\n",
+ " columns={\"x\": \"centroid-0\", \"y\": \"centroid-1\"}\n",
+ ")\n",
+ "\n",
+ "# Add label, set, and solution columns.\n",
+ "df_exp_video_formatted[[\"label\", \"set\", \"solution\"]] = 0\n",
+ "\n",
+ "# Normalize coordinates to [0, 1] using the experimental frame size.\n",
+ "frame_height, frame_width = exp_video[0].shape[:2]\n",
+ "df_exp_video_formatted[[\"centroid-0\", \"centroid-1\"]] /= [frame_width,\n",
+ " frame_height]\n",
+ "\n",
+ "# Generate a graph from graph_constructor. 
As test_graph returns a list of\n", + "# graphs, we select the first element from the list as it only has 1 element.\n", + "exp_video_graph = graph_constructor(df=df_exp_video_formatted)[0].to(device)\n", + "\n", + "# Perform prediction on graph.\n", + "exp_trajs_edges_pred_method3 = classifier_magik(exp_video_graph)\n", + "exp_trajs_edges_pred_method3 = \\\n", + " exp_trajs_edges_pred_method3.cpu().detach().numpy() > 0.5\n", + "\n", + "# Compute the trajectories from the predicted edges.\n", + "trajectory_constructor = utils.ComputeTrajectories()\n", + "sim_trajs_pred_method3 = trajectory_constructor(\n", + " exp_video_graph.cpu(),\n", + " exp_trajs_edges_pred_method3.squeeze(),\n", + ")\n", + "\n", + "# Convert the predicted trajectories to a list format.\n", + "exp_trajs_pred_method3_list = utils.make_list(\n", + " sim_trajs_pred_method3, exp_video_graph, exp_image_size,\n", + ")\n", + "\n", + "# Filter trajectory lists shorter than 10 frames.\n", + "exp_trajs_pred_method3_list = [trajectory for trajectory in exp_trajs_pred_method3_list if len(trajectory) >= 15]\n", + "print(f\"Number of trajectories found: {len(exp_trajs_pred_method3_list)}\")\n", + "\n", + "# Create a video with the predicted trajectories.\n", + "exp_video_method3_results = utils.make_video_with_trajs(\n", + " trajs_pred_list=exp_trajs_pred_method3_list,\n", + " video=exp_video,\n", + " fov_size=exp_image_size,\n", + ")\n", + "\n", + "exp_video_method3_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calculate the time-averaged MSD for the trajectories." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "utils.plot_TAMSDs(trajs_pred=exp_trajs_pred_method3_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "PT", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "a4dd91e5328d494ab4ce1517fe9fec01": { + "model_module": "@jupyter-widgets/output", + "model_module_version": "1.0.0", + "model_name": "OutputModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/output", + "_model_module_version": "1.0.0", + "_model_name": "OutputModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/output", + "_view_module_version": "1.0.0", + "_view_name": "OutputView", + "layout": "IPY_MODEL_a89634731d994290871e11f60cf0e5de", + "msg_id": "", + "outputs": [ + { + "data": { + "text/html": "
Epoch 199/199 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50/50 0:00:04 • 0:00:00 12.11it/s v_num: 1 train_loss_step: 0.000133\n                                                                                 trainBinaryAccuracy_step: 1       \n                                                                                 train_loss_epoch: 0.0175          \n                                                                                 trainBinaryAccuracy_epoch: 0.993  \n
\n", + "text/plain": "Epoch 199/199 \u001b[38;2;98;6;224m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[37m50/50\u001b[0m \u001b[38;5;245m0:00:04 • 0:00:00\u001b[0m \u001b[38;5;249m12.11it/s\u001b[0m \u001b[37mv_num: 1 train_loss_step: 0.000133\u001b[0m\n \u001b[37mtrainBinaryAccuracy_step: 1 \u001b[0m\n \u001b[37mtrain_loss_epoch: 0.0175 \u001b[0m\n \u001b[37mtrainBinaryAccuracy_epoch: 0.993 \u001b[0m\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ] + } + }, + "a89634731d994290871e11f60cf0e5de": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tutorial/utils/utils_plotting.py b/tutorial/utils/utils_plotting.py index 3982d33..2349dbf 100644 --- a/tutorial/utils/utils_plotting.py +++ b/tutorial/utils/utils_plotting.py @@ -51,6 +51,7 @@ import numpy as np from .utils_evaluation import compute_TAMSD +from .utils_imageproc import normalize_min_max def play_video( video: np.ndarray, @@ -522,8 +523,13 @@ def make_video_with_trajs( if figure_title is not None: ax.set_title(figure_title) + vmin, vmax = np.percentile(video, [1, 99]) + # Normalize video intensities for better plotting. + # for frame_idx in range(len(video)): + # video[frame_idx] = normalize_min_max(video[frame_idx], minimum_value=vmin, maximum_value=vmax) + # Image artist (static background per frame). - im = ax.imshow(video[0], cmap="gray", animated=True) + im = ax.imshow(video[0], cmap="gray", animated=True, vmin=vmin, vmax=vmax) # Predicted trajectories: one line + one scatter (last point) per traj. pred_lines, pred_scatters = [], []