|
|
import torch |
|
|
import sys |
|
|
sys.argv = [''] |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
import pytorch_lightning as pl |
|
|
from torch.utils.data import DataLoader |
|
|
from lightning.pytorch.utilities.combined_loader import CombinedLoader |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from functools import partial |
|
|
from scipy.spatial import cKDTree |
|
|
from sklearn.cluster import KMeans |
|
|
from torch.utils.data import TensorDataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DrugResponseDataModule(pl.LightningDataModule):
    """Data module pairing DMSO control cells with two Clonidine clusters.

    On construction the module reads a CSV whose first 50 columns (after the
    leading column is dropped) are treated as coordinates and whose
    ``leiden_*`` columns hold cluster labels. It then builds:

    * ``x0``   - source points sampled from the DMSO rows,
    * ``x1_1`` - target points from Clonidine cluster ``'0.0'``,
    * ``x1_2`` - target points from Clonidine cluster ``'4.0'``,

    all subsampled to the same size, plus train / validation / test loaders,
    a KD-tree over every row with complete coordinates, and three
    "metric sample" loaders obtained by K-means clustering of those rows.
    """

    def __init__(self, args):
        """Store the configuration and eagerly load the data.

        Args:
            args: Configuration object. The attributes ``batch_size``,
                ``dim``, ``whiten`` and ``split_ratios`` are read here;
                ``split_ratios[0]`` is used as the training fraction.
        """
        super().__init__()
        # Must be called directly inside __init__: Lightning inspects the
        # constructor's frame to collect the init arguments.
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim        # stored only; not used by this class
        self.whiten = args.whiten      # stored only; not used by this class
        self.split_ratios = args.split_ratios

        self.data_path = "/raid/st512/branchsbm/data/pca_and_leiden_labels.csv"
        self.num_timesteps = 2
        self.args = args
        self._prepare_data()

    @staticmethod
    def _subsample_rows(coords, size):
        """Return ``size`` rows of ``coords`` drawn without replacement.

        ``coords`` is returned untouched (and NumPy's global RNG is not
        advanced) when it has ``size`` rows or fewer.
        """
        if len(coords) > size:
            chosen = np.random.choice(len(coords), size, replace=False)
            return coords[chosen]
        return coords

    @staticmethod
    def _full_batch_loader(points):
        """Wrap a NumPy array in a loader that yields it as one single batch."""
        return DataLoader(
            torch.tensor(points, dtype=torch.float32),
            batch_size=points.shape[0],
            shuffle=False,
            drop_last=False,
        )

    def _prepare_data(self):
        """Read the CSV and build every tensor, loader and index of the module.

        Sets ``n_samples``, ``coords_t0``, ``coords_t1``, ``time_labels``,
        ``val_x0``, ``train_dataloaders``, ``val_dataloaders``,
        ``test_dataloaders``, ``dataset``, ``tree`` and
        ``metric_samples_dataloaders``.

        Side effect: NumPy's *global* random state is re-seeded with 42
        (twice) so that the subsampling is reproducible.
        """
        # ---- load and clean the table ----------------------------------
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]  # discard the leading column
        df = df.replace('', np.nan)
        pc_cols = df.columns[:50]
        for col in pc_cols:
            # Unparseable entries become NaN instead of raising.
            df[col] = pd.to_numeric(df[col], errors='coerce')

        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'

        # A row belongs to a condition when its label for that condition is set.
        dmso_mask = df[leiden_dmso_col].notna()
        clonidine_mask = df[leiden_clonidine_col].notna()
        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # ---- target endpoints: two fixed Clonidine clusters ------------
        top_clonidine_clusters = ['0.0', '4.0']
        clonidine_labels = clonidine_data[leiden_clonidine_col].astype(str)
        x1_1_data = clonidine_data[clonidine_labels == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_labels == top_clonidine_clusters[1]]
        x1_1_coords = x1_1_data[pc_cols].values.astype(float)
        x1_2_coords = x1_2_data[pc_cols].values.astype(float)

        # Every point set is cut down to the size of the smaller endpoint.
        target_size = min(len(x1_1_coords), len(x1_2_coords))

        np.random.seed(42)
        x1_1_coords = self._subsample_rows(x1_1_coords, target_size)
        x1_2_coords = self._subsample_rows(x1_2_coords, target_size)

        # ---- source endpoint: DMSO cells --------------------------------
        # value_counts() sorts by frequency, so index[0] is the largest cluster.
        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
        dmso_coords = dmso_cluster_data[pc_cols].values

        np.random.seed(42)
        if len(dmso_coords) >= target_size:
            idx0 = np.random.choice(len(dmso_coords), target_size, replace=False)
            x0_coords = dmso_coords[idx0]
        else:
            # Largest cluster is too small: keep all of it and top up with
            # cells from the remaining DMSO clusters.
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                idx_other = np.random.choice(len(other_dmso_coords), remaining_needed, replace=False)
                x0_coords = np.vstack([dmso_coords, other_dmso_coords[idx_other]])
            else:
                # Still not enough: shrink the common size to what DMSO offers.
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                idx0 = np.random.choice(len(all_dmso_coords), target_size, replace=False)
                x0_coords = all_dmso_coords[idx0]

        # Only has an effect (and only draws random numbers) when target_size
        # was reduced in the last branch above.
        x1_1_coords = self._subsample_rows(x1_1_coords, target_size)
        x1_2_coords = self._subsample_rows(x1_2_coords, target_size)

        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)

        self.coords_t0 = x0
        self.coords_t1 = torch.cat([x1_1, x1_2], dim=0)
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),
            np.ones(len(self.coords_t1)),
        ])

        # ---- train / validation split -----------------------------------
        split_index = int(target_size * self.split_ratios[0])
        # Guarantee at least one full validation batch (loaders drop the
        # last incomplete batch).
        # NOTE(review): this goes negative when target_size < batch_size;
        # the slices below then still work but no loader yields a batch.
        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size

        train_x0, val_x0 = x0[:split_index], x0[split_index:]
        train_x1_1, val_x1_1 = x1_1[:split_index], x1_1[split_index:]
        train_x1_2, val_x1_2 = x1_2[:split_index], x1_2[split_index:]

        self.val_x0 = val_x0

        # Per-sample weights: 1.0 for the source, 0.5 for each of the two targets.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(
                TensorDataset(train_x0, train_x0_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
            "x1_1": DataLoader(
                TensorDataset(train_x1_1, train_x1_1_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
            "x1_2": DataLoader(
                TensorDataset(train_x1_2, train_x1_2_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
        }

        # Validation keeps x0 in order but still shuffles the two targets.
        self.val_dataloaders = {
            "x0": DataLoader(
                TensorDataset(val_x0, val_x0_weights),
                batch_size=self.batch_size, shuffle=False, drop_last=True,
            ),
            "x1_1": DataLoader(
                TensorDataset(val_x1_1, val_x1_1_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
            "x1_2": DataLoader(
                TensorDataset(val_x1_2, val_x1_2_weights),
                batch_size=self.batch_size, shuffle=True, drop_last=True,
            ),
        }

        # ---- full point cloud and neighbour index -----------------------
        # Only rows with complete coordinates go into the KD-tree.
        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        # Test loaders deliver everything in one batch each.
        self.test_dataloaders = {
            "x0": DataLoader(
                TensorDataset(val_x0, val_x0_weights),
                batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False,
            ),
            "dataset": DataLoader(
                TensorDataset(self.dataset),
                batch_size=self.dataset.shape[0], shuffle=False, drop_last=False,
            ),
        }

        # ---- metric samples: K-means partition of the full cloud --------
        km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy())
        cluster_labels = km_all.labels_
        samples = self.dataset.cpu().numpy()

        # The loader order (cluster 2, then 0, then 1) is part of the
        # contract with consumers and must be kept.
        self.metric_samples_dataloaders = [
            self._full_batch_loader(samples[cluster_labels == cluster_id])
            for cluster_id in (2, 0, 1)
        ]

    def _with_metric_samples(self, key, loaders):
        """Combine ``loaders`` (under ``key``) with the metric-sample loaders.

        Each group is zipped with ``min_size``; the two groups are then
        joined with ``max_size_cycle``.
        """
        combined_loaders = {
            key: CombinedLoader(loaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def train_dataloader(self):
        """Return the training loaders (key ``"train_samples"``) plus metric samples."""
        return self._with_metric_samples("train_samples", self.train_dataloaders)

    def val_dataloader(self):
        """Return the validation loaders (key ``"val_samples"``) plus metric samples."""
        return self._with_metric_samples("val_samples", self.val_dataloaders)

    def test_dataloader(self):
        """Return the test loaders (key ``"test_samples"``) plus metric samples."""
        return self._with_metric_samples("test_samples", self.test_dataloaders)

    def get_manifold_proj(self, points):
        """Return a callable that applies :meth:`local_smoothing_op` to points.

        Args:
            points: Unused; kept so the signature stays unchanged.

        Returns:
            ``local_smoothing_op`` with ``tree`` and ``dataset`` bound to
            this module's KD-tree and full point cloud.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """Pull each point towards a weighted mean of its nearest data points.

        For every row of ``x`` the ``k`` nearest rows of ``dataset`` are
        looked up through ``tree``; they are averaged with weights
        ``softmax(-squared_distance / temp)`` and the result is blended
        with the input as ``0.7 * x + 0.3 * smoothed``.

        Args:
            x: 2-D tensor of query points, one row per point. May live on
                any device.
            tree: KD-tree built over the rows of ``dataset`` (queried on CPU).
            dataset: Tensor holding the points the tree was built from.
            k: Number of neighbours to average over.
            temp: Temperature of the distance kernel; smaller values put
                more weight on the closest neighbour.

        Returns:
            Tensor shaped like ``x``. Gradients flow through ``x`` but not
            through the neighbour lookup.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)

        # query() drops the neighbour axis when k == 1; restore it so the
        # broadcasting below stays (N, k, D) rather than (N, N, D).
        idx = np.asarray(idx)
        if idx.ndim == points_np.ndim - 1:
            idx = np.expand_dims(idx, -1)

        # Move the gathered neighbours next to x; without this the
        # subtraction fails whenever x is on an accelerator.
        nearest_pts = dataset[torch.as_tensor(idx, dtype=torch.long)].to(x.device)

        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        # Same value as exp(-d/temp) / sum(exp(-d/temp)), but softmax
        # subtracts the maximum first, so a small temp cannot underflow
        # every weight to zero and produce NaN.
        weights = torch.softmax(-dists / temp, dim=1)

        smoothed = (weights * nearest_pts).sum(dim=1)

        alpha = 0.3  # blend factor between the input and its smoothed version
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return the coordinates grouped by time point, for visualization.

        Returns:
            Dict with ``'t0'`` (source tensor), ``'t1'`` (both target sets
            concatenated) and ``'time_labels'`` (NumPy array of 0s then 1s).
        """
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }
|
|
|
|
|
def get_datamodule(args=None):
    """Build a ready-to-use :class:`DrugResponseDataModule`.

    Args:
        args: Configuration object forwarded to the data module. When
            omitted, a module-level variable named ``args`` is used, which
            keeps existing zero-argument calls working.

    Returns:
        The data module after ``setup(stage="fit")`` has been called.

    Raises:
        NameError: If ``args`` is not given and no module-level ``args``
            exists.
    """
    if args is None:
        # Fall back to the module global the original implementation relied on.
        try:
            args = globals()["args"]
        except KeyError:
            raise NameError(
                "get_datamodule() needs `args`: pass it explicitly or define "
                "a module-level `args`"
            ) from None

    datamodule = DrugResponseDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule