import sys

# parse_args() reads sys.argv at import time; clear it so this module can be
# imported from notebooks or test runners whose argv holds unrelated flags.
sys.argv = [""]

from functools import partial

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

from train.parsers_sc import parse_args

args = parse_args()


class WeightedBranchedCellDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()
        self.data_path = "./data/mouse_hematopoiesis.csv"
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20
        self.n_samples = 1429
        self.num_timesteps = 3  # t=0, t=1, t=2
        self.split_ratios = args.split_ratios
        self.metric_clusters = args.metric_clusters
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        print("Preparing cell data in WeightedBranchedCellDataModule")
        df = pd.read_csv(self.data_path)

        # Build a dictionary of coordinates keyed by timepoint.
        coords_by_t = {
            t: df[df["samples"] == t][["x1", "x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0]  # number of t=0 points
        self.n_samples = n0  # update n_samples to match the actual data

        # Cluster the t=2 cells into two branches.
        km = KMeans(n_clusters=2, random_state=42).fit(coords_by_t[2])
        df2 = df[df["samples"] == 2].copy()
        df2["branch"] = km.labels_
        cluster_counts = df2["branch"].value_counts().sort_index()
        print(cluster_counts)

        # Sample n0 points from each branch so each endpoint distribution
        # matches the size of the t=0 source distribution.
        endpoints = {}
        for b in (0, 1):
            endpoints[b] = (
                df2[df2["branch"] == b]
                .sample(n=n0, random_state=42)[["x1", "x2"]]
                .values
            )

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)  # t=0 coordinates
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)  # t=1 coordinates
        x1_1 = torch.tensor(endpoints[0], dtype=torch.float32)  # t=2, branch 0
        x1_2 = torch.tensor(endpoints[1], dtype=torch.float32)  # t=2, branch 1

        self.coords_t0 = x0
        self.coords_t1 = x_inter
        self.coords_t2_1 = x1_1
        self.coords_t2_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),       # t=0
            np.ones(len(self.coords_t1)),        # t=1
            np.ones(len(self.coords_t2_1)) * 2,  # t=2, branch 0
            np.ones(len(self.coords_t2_2)) * 2,  # t=2, branch 1
        ])

        split_index = int(n0 * self.split_ratios[0])
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size
        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]
        self.val_x0 = val_x0

        # The source carries weight 1.0; each of the two branch endpoints
        # carries 0.5, so together they match the source mass (0.5 + 0.5 = 1.0).
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)
        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(
                TensorDataset(train_x0, train_x0_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
            "x1_1": DataLoader(
                TensorDataset(train_x1_1, train_x1_1_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
            "x1_2": DataLoader(
                TensorDataset(train_x1_2, train_x1_2_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
        }
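        # Shape sketch of one training batch from these loaders (a reading aid,
        # assuming the "min_size" CombinedLoader wrapping used in
        # train_dataloader below). Each key yields a (coords, weights) pair:
        #   {"x0": ((B, 2), (B, 1)), "x1_1": ((B, 2), (B, 1)), "x1_2": ((B, 2), (B, 1))}
        # with B = self.batch_size.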
        self.val_dataloaders = {
            "x0": DataLoader(
                TensorDataset(val_x0, val_x0_weights),
                batch_size=self.batch_size, shuffle=False, drop_last=True,
            ),
            "x1_1": DataLoader(
                TensorDataset(val_x1_1, val_x1_1_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
            "x1_2": DataLoader(
                TensorDataset(val_x1_2, val_x1_2_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
        }

        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)
        # If whitening is enabled, it would need to be applied to the full
        # dataset (currently disabled here):
        # if self.whiten:
        #     self.scaler = StandardScaler()
        #     self.dataset = torch.tensor(
        #         self.scaler.fit_transform(all_data), dtype=torch.float32
        #     )

        self.test_dataloaders = {
            "x0": DataLoader(
                TensorDataset(val_x0, val_x0_weights),
                batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False,
            ),
            "dataset": DataLoader(
                TensorDataset(self.dataset),
                batch_size=self.dataset.shape[0], shuffle=False, drop_last=False,
            ),
        }

        # Metric dataloaders: k-means clustering of ALL points into
        # self.metric_clusters (2 or 3) groups; each cluster becomes one
        # full-batch dataloader.
        if self.metric_clusters == 3:
            km_all = KMeans(n_clusters=3, random_state=45).fit(self.dataset.numpy())
            cluster_labels = km_all.labels_
            samples = self.dataset.cpu().numpy()
            cluster_0_data = samples[cluster_labels == 0]
            cluster_1_data = samples[cluster_labels == 1]
            cluster_2_data = samples[cluster_labels == 2]
            self.metric_samples_dataloaders = [
                DataLoader(
                    torch.tensor(cluster_1_data, dtype=torch.float32),
                    batch_size=cluster_1_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
                DataLoader(
                    torch.tensor(cluster_2_data, dtype=torch.float32),
                    batch_size=cluster_2_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
                DataLoader(
                    torch.tensor(cluster_0_data, dtype=torch.float32),
                    batch_size=cluster_0_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
            ]
        else:
            km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
            cluster_labels = km_all.labels_
            samples = self.dataset.cpu().numpy()
            cluster_0_data = samples[cluster_labels == 0]
            cluster_1_data = samples[cluster_labels == 1]
            self.metric_samples_dataloaders = [
                DataLoader(
                    torch.tensor(cluster_1_data, dtype=torch.float32),
                    batch_size=cluster_1_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
                DataLoader(
                    torch.tensor(cluster_0_data, dtype=torch.float32),
                    batch_size=cluster_0_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
            ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")
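    # The smoothing operator returned by get_manifold_proj below computes, for
    # each query point x_i with k nearest dataset neighbors {n_j}:
    #     w_ij  = exp(-||x_i - n_j||^2 / temp) / sum_l exp(-||x_i - n_l||^2 / temp)
    #     s_i   = sum_j w_ij * n_j
    #     out_i = (1 - alpha) * x_i + alpha * s_i
    # i.e. a Gaussian-kernel weighted average of neighbors, blended with the
    # original point.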
    def get_manifold_proj(self, points):
        """Adapted for 2D cell data: uses local neighborhood averaging instead
        of plane fitting. `points` is unused and kept for interface
        compatibility."""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on the k nearest neighbors in the full
        dataset. This replaces plane projection for 2D manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        # Move neighbors to x's device in case x lives on the GPU.
        nearest_pts = dataset[idx].to(x.device)  # shape: (batch_size, k, 2)

        # Gaussian-kernel weights over squared distances to the neighbors.
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors.
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend the original point with its smoothed version.
        alpha = 0.3  # how much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoint for visualization."""
        return {
            "t0": self.coords_t0,
            "t1": self.coords_t1,
            "t2_1": self.coords_t2_1,
            "t2_2": self.coords_t2_2,
            "time_labels": self.time_labels,
        }
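# SingleBranchCellDataModule below mirrors the weighted module but treats all
# t=2 cells as a single endpoint distribution instead of splitting them into
# two KMeans branches; it otherwise builds the same loader structure.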
class SingleBranchCellDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()
        self.data_path = "./data/mouse_hematopoiesis.csv"
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20
        self.n_samples = 1429
        self.num_timesteps = 3  # t=0, t=1, t=2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3  # note: the metric loaders below use 2 clusters
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        print("Preparing cell data in SingleBranchCellDataModule")
        df = pd.read_csv(self.data_path)

        # Build a dictionary of coordinates keyed by timepoint.
        coords_by_t = {
            t: df[df["samples"] == t][["x1", "x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0]  # number of t=0 points
        self.n_samples = n0  # update n_samples to match the actual data

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)  # t=0 coordinates
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)  # t=1 coordinates
        x1 = torch.tensor(coords_by_t[2], dtype=torch.float32)  # all t=2 cells, single endpoint

        split_index = int(n0 * self.split_ratios[0])
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size
        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]
        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=0.5)
        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(
                TensorDataset(train_x0, train_x0_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
            "x1": DataLoader(
                TensorDataset(train_x1, train_x1_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
        }
        self.val_dataloaders = {
            "x0": DataLoader(
                TensorDataset(val_x0, val_x0_weights),
                batch_size=self.batch_size, shuffle=False, drop_last=True,
            ),
            "x1": DataLoader(
                TensorDataset(val_x1, val_x1_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
        }

        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)
        # If whitening is enabled, apply it to the full dataset.
        if self.whiten:
            self.scaler = StandardScaler()
            self.dataset = torch.tensor(
                self.scaler.fit_transform(all_data), dtype=torch.float32
            )

        self.test_dataloaders = {
            "x0": DataLoader(
                TensorDataset(val_x0, val_x0_weights),
                batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False,
            ),
            "dataset": DataLoader(
                TensorDataset(self.dataset),
                batch_size=self.dataset.shape[0], shuffle=False, drop_last=False,
            ),
        }

        # Metric dataloaders: k-means clustering of ALL points into 2 groups;
        # each cluster becomes one full-batch dataloader.
        km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
        cluster_labels = km_all.labels_
        samples = self.dataset.cpu().numpy()
        cluster_0_data = samples[cluster_labels == 0]
        cluster_1_data = samples[cluster_labels == 1]
        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data: uses local neighborhood averaging instead
        of plane fitting. `points` is unused and kept for interface
        compatibility."""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on the k nearest neighbors in the full
        dataset. This replaces plane projection for 2D manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        # Move neighbors to x's device in case x lives on the GPU.
        nearest_pts = dataset[idx].to(x.device)  # shape: (batch_size, k, 2)

        # Gaussian-kernel weights over squared distances to the neighbors.
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors.
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend the original point with its smoothed version.
        alpha = 0.3  # how much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed


def get_datamodule():
    datamodule = WeightedBranchedCellDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
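
# A minimal smoke-test sketch, separate from the training pipeline. It assumes
# ./data/mouse_hematopoiesis.csv exists, that parse_args() supplies batch_size,
# dim, whiten, split_ratios, and metric_clusters, and a Lightning 2.x
# CombinedLoader, whose iterator yields (batch, batch_idx, dataloader_idx).
if __name__ == "__main__":
    dm = get_datamodule()
    loader = dm.train_dataloader()
    batch, batch_idx, dataloader_idx = next(iter(loader))
    # batch is a dict keyed by "train_samples" and "metric_samples"; how the
    # nested CombinedLoaders unpack depends on the installed Lightning
    # version, so only the top level is inspected here.
    print("top-level keys:", list(batch.keys()))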