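# Multiview camera data module for threestudio: loads posed images from a
# Nerfstudio-style transforms.json (OPENCV camera model) and serves per-view
# ray bundles for training and evaluation.
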
import json
import math
import os
import random
from dataclasses import dataclass

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm

import threestudio
from threestudio import register
from threestudio.utils.config import parse_structured
from threestudio.utils.ops import get_mvp_matrix, get_ray_directions, get_rays
from threestudio.utils.typing import *
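

# Flip the Y and Z axes of a camera-to-world matrix, converting between the
# OpenCV camera convention (x right, y down, z forward) and the OpenGL/Blender
# convention (x right, y up, z back).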
def convert_pose(C2W):
    flip_yz = torch.eye(4)
    flip_yz[1, 1] = -1
    flip_yz[2, 2] = -1
    C2W = torch.matmul(C2W, flip_yz)
    return C2W
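

# Build a 4x4 OpenGL-style projection matrix from OpenCV intrinsics K for an
# image of size H x W and the given near/far clipping planes.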
def convert_proj(K, H, W, near, far):
    return [
        [2 * K[0, 0] / W, -2 * K[0, 1] / W, (W - 2 * K[0, 2]) / W, 0],
        [0, -2 * K[1, 1] / H, (H - 2 * K[1, 2]) / H, 0],
        [0, 0, (-far - near) / (far - near), -2 * far * near / (far - near)],
        [0, 0, -1, 0],
    ]
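

# Interpolate between two camera-to-world poses: invert to world-to-camera,
# slerp the rotations and linearly interpolate the translations there, then
# invert back to camera-to-world.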
def inter_pose(pose_0, pose_1, ratio):
    pose_0 = pose_0.detach().cpu().numpy()
    pose_1 = pose_1.detach().cpu().numpy()
    pose_0 = np.linalg.inv(pose_0)
    pose_1 = np.linalg.inv(pose_1)
    rot_0 = pose_0[:3, :3]
    rot_1 = pose_1[:3, :3]
    rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
    key_times = [0, 1]
    slerp = Slerp(key_times, rots)
    rot = slerp(ratio)
    pose = np.diag([1.0, 1.0, 1.0, 1.0])
    pose = pose.astype(np.float32)
    pose[:3, :3] = rot.as_matrix()
    pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
    pose = np.linalg.inv(pose)
    return pose
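

# eval_interpolation, when set, is (start_view_index, end_view_index, n_steps)
# and switches evaluation to a pose-interpolated trajectory between the two views.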
@dataclass
class MultiviewsDataModuleConfig:
    dataroot: str = ""
    train_downsample_resolution: int = 4
    eval_downsample_resolution: int = 4
    train_data_interval: int = 1
    eval_data_interval: int = 1
    batch_size: int = 1
    eval_batch_size: int = 1
    camera_layout: str = "around"
    camera_distance: float = -1
    eval_interpolation: Optional[Tuple[int, int, int]] = None  # (0, 1, 30)
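

# Training dataset: loads all views up front, then streams forever; collate()
# assembles each batch from a single randomly sampled view.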
class MultiviewIterableDataset(IterableDataset):
    def __init__(self, cfg: Any) -> None:
        super().__init__()
        self.cfg: MultiviewsDataModuleConfig = cfg
        assert self.cfg.batch_size == 1
        scale = self.cfg.train_downsample_resolution

        camera_dict = json.load(
            open(os.path.join(self.cfg.dataroot, "transforms.json"), "r")
        )
        assert camera_dict["camera_model"] == "OPENCV"

        frames = camera_dict["frames"]
        frames = frames[:: self.cfg.train_data_interval]
        frames_proj = []
        frames_c2w = []
        frames_position = []
        frames_direction = []
        frames_img = []

        self.frame_w = frames[0]["w"] // scale
        self.frame_h = frames[0]["h"] // scale
        threestudio.info("Loading frames...")
        self.n_frames = len(frames)

        c2w_list = []
        for frame in tqdm(frames):
            extrinsic: Float[Tensor, "4 4"] = torch.as_tensor(
                frame["transform_matrix"], dtype=torch.float32
            )
            c2w = extrinsic
            c2w_list.append(c2w)
        c2w_list = torch.stack(c2w_list, dim=0)
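
        # Recenter the scene: "around" moves the mean camera position to the
        # origin; "front" additionally pushes the cameras back along their
        # average viewing direction (-z in OpenGL convention) by camera_distance.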
        if self.cfg.camera_layout == "around":
            c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
        elif self.cfg.camera_layout == "front":
            assert self.cfg.camera_distance > 0
            c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
            z_vector = torch.zeros(c2w_list.shape[0], 3, 1)
            z_vector[:, 2, :] = -1
            rot_z_vector = c2w_list[:, :3, :3] @ z_vector
            rot_z_vector = torch.mean(rot_z_vector, dim=0).unsqueeze(0)
            c2w_list[:, :3, 3] -= rot_z_vector[:, :, 0] * self.cfg.camera_distance
        else:
            raise ValueError(
                f"Unknown camera layout {self.cfg.camera_layout}. Only 'around' and 'front' are supported."
            )
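
        # Per-frame loading: scale the intrinsics to the downsampled resolution,
        # read the image (OpenCV loads BGR, so reverse the channels to RGB), and
        # precompute per-pixel ray directions and the projection matrix.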
        for idx, frame in tqdm(enumerate(frames)):
            intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
            intrinsic[0, 0] = frame["fl_x"] / scale
            intrinsic[1, 1] = frame["fl_y"] / scale
            intrinsic[0, 2] = frame["cx"] / scale
            intrinsic[1, 2] = frame["cy"] / scale

            frame_path = os.path.join(self.cfg.dataroot, frame["file_path"])
            img = cv2.imread(frame_path)[:, :, ::-1].copy()
            img = cv2.resize(img, (self.frame_w, self.frame_h))
            img: Float[Tensor, "H W 3"] = torch.FloatTensor(img) / 255
            frames_img.append(img)

            direction: Float[Tensor, "H W 3"] = get_ray_directions(
                self.frame_h,
                self.frame_w,
                (intrinsic[0, 0], intrinsic[1, 1]),
                (intrinsic[0, 2], intrinsic[1, 2]),
                use_pixel_centers=False,
            )

            c2w = c2w_list[idx]
            camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)

            near = 0.1
            far = 1000.0
            proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
            proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
            frames_proj.append(proj)
            frames_c2w.append(c2w)
            frames_position.append(camera_position)
            frames_direction.append(direction)
        threestudio.info("Loaded frames.")

        self.frames_proj: Float[Tensor, "B 4 4"] = torch.stack(frames_proj, dim=0)
        self.frames_c2w: Float[Tensor, "B 4 4"] = torch.stack(frames_c2w, dim=0)
        self.frames_position: Float[Tensor, "B 3"] = torch.stack(frames_position, dim=0)
        self.frames_direction: Float[Tensor, "B H W 3"] = torch.stack(
            frames_direction, dim=0
        )
        self.frames_img: Float[Tensor, "B H W 3"] = torch.stack(frames_img, dim=0)

        self.rays_o, self.rays_d = get_rays(
            self.frames_direction, self.frames_c2w, keepdim=True
        )
        self.mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(
            self.frames_c2w, self.frames_proj
        )
        self.light_positions: Float[Tensor, "B 3"] = torch.zeros_like(
            self.frames_position
        )
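
    # __iter__ yields empty placeholders; collate() does the real work, sampling
    # one random view per training step.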
    def __iter__(self):
        while True:
            yield {}

    def collate(self, batch):
        index = torch.randint(0, self.n_frames, (1,)).item()
        return {
            "index": index,
            "rays_o": self.rays_o[index : index + 1],
            "rays_d": self.rays_d[index : index + 1],
            "mvp_mtx": self.mvp_mtx[index : index + 1],
            "c2w": self.frames_c2w[index : index + 1],
            "camera_positions": self.frames_position[index : index + 1],
            "light_positions": self.light_positions[index : index + 1],
            "gt_rgb": self.frames_img[index : index + 1],
            "height": self.frame_h,
            "width": self.frame_w,
        }
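

# Usage sketch (hypothetical dataroot; a minimal check, not part of the module):
#   cfg = parse_structured(MultiviewsDataModuleConfig, {"dataroot": "data/scene"})
#   dataset = MultiviewIterableDataset(cfg)
#   batch = dataset.collate(None)  # collate ignores its argument
#   batch["gt_rgb"].shape          # torch.Size([1, H, W, 3])


# Evaluation dataset: serves the captured views in order or, when
# cfg.eval_interpolation is set, a pose trajectory interpolated between two of
# the captured cameras (the ground-truth images are then all zeros).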
class MultiviewDataset(Dataset):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: MultiviewsDataModuleConfig = cfg
        assert self.cfg.eval_batch_size == 1
        scale = self.cfg.eval_downsample_resolution

        camera_dict = json.load(
            open(os.path.join(self.cfg.dataroot, "transforms.json"), "r")
        )
        assert camera_dict["camera_model"] == "OPENCV"

        frames = camera_dict["frames"]
        frames = frames[:: self.cfg.eval_data_interval]
        frames_proj = []
        frames_c2w = []
        frames_position = []
        frames_direction = []
        frames_img = []

        self.frame_w = frames[0]["w"] // scale
        self.frame_h = frames[0]["h"] // scale
        threestudio.info("Loading frames...")
        self.n_frames = len(frames)

        c2w_list = []
        for frame in tqdm(frames):
            extrinsic: Float[Tensor, "4 4"] = torch.as_tensor(
                frame["transform_matrix"], dtype=torch.float32
            )
            c2w = extrinsic
            c2w_list.append(c2w)
        c2w_list = torch.stack(c2w_list, dim=0)

        if self.cfg.camera_layout == "around":
            c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
        elif self.cfg.camera_layout == "front":
            assert self.cfg.camera_distance > 0
            c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
            z_vector = torch.zeros(c2w_list.shape[0], 3, 1)
            z_vector[:, 2, :] = -1
            rot_z_vector = c2w_list[:, :3, :3] @ z_vector
            rot_z_vector = torch.mean(rot_z_vector, dim=0).unsqueeze(0)
            c2w_list[:, :3, 3] -= rot_z_vector[:, :, 0] * self.cfg.camera_distance
        else:
            raise ValueError(
                f"Unknown camera layout {self.cfg.camera_layout}. Only 'around' and 'front' are supported."
            )
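
        # Interpolated evaluation trajectory: reuse the intrinsics of the first
        # endpoint view and sweep the extrinsics from view idx0 to view idx1.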
        if self.cfg.eval_interpolation is not None:
            idx0 = self.cfg.eval_interpolation[0]
            idx1 = self.cfg.eval_interpolation[1]
            eval_nums = self.cfg.eval_interpolation[2]
            frame = frames[idx0]
            intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
            intrinsic[0, 0] = frame["fl_x"] / scale
            intrinsic[1, 1] = frame["fl_y"] / scale
            intrinsic[0, 2] = frame["cx"] / scale
            intrinsic[1, 2] = frame["cy"] / scale
            for ratio in np.linspace(0, 1, eval_nums):
                img: Float[Tensor, "H W 3"] = torch.zeros(
                    (self.frame_h, self.frame_w, 3)
                )
                frames_img.append(img)
                direction: Float[Tensor, "H W 3"] = get_ray_directions(
                    self.frame_h,
                    self.frame_w,
                    (intrinsic[0, 0], intrinsic[1, 1]),
                    (intrinsic[0, 2], intrinsic[1, 2]),
                    use_pixel_centers=False,
                )

                c2w = torch.FloatTensor(
                    inter_pose(c2w_list[idx0], c2w_list[idx1], ratio)
                )
                camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)

                near = 0.1
                far = 1000.0
                proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
                proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
                frames_proj.append(proj)
                frames_c2w.append(c2w)
                frames_position.append(camera_position)
                frames_direction.append(direction)
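
        # Otherwise evaluate on the captured views themselves, loading images
        # exactly as in the training dataset.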
        else:
            for idx, frame in tqdm(enumerate(frames)):
                intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
                intrinsic[0, 0] = frame["fl_x"] / scale
                intrinsic[1, 1] = frame["fl_y"] / scale
                intrinsic[0, 2] = frame["cx"] / scale
                intrinsic[1, 2] = frame["cy"] / scale

                frame_path = os.path.join(self.cfg.dataroot, frame["file_path"])
                img = cv2.imread(frame_path)[:, :, ::-1].copy()
                img = cv2.resize(img, (self.frame_w, self.frame_h))
                img: Float[Tensor, "H W 3"] = torch.FloatTensor(img) / 255
                frames_img.append(img)

                direction: Float[Tensor, "H W 3"] = get_ray_directions(
                    self.frame_h,
                    self.frame_w,
                    (intrinsic[0, 0], intrinsic[1, 1]),
                    (intrinsic[0, 2], intrinsic[1, 2]),
                    use_pixel_centers=False,
                )

                c2w = c2w_list[idx]
                camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)

                near = 0.1
                far = 1000.0
                proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
                proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
                frames_proj.append(proj)
                frames_c2w.append(c2w)
                frames_position.append(camera_position)
                frames_direction.append(direction)
        threestudio.info("Loaded frames.")

        self.frames_proj: Float[Tensor, "B 4 4"] = torch.stack(frames_proj, dim=0)
        self.frames_c2w: Float[Tensor, "B 4 4"] = torch.stack(frames_c2w, dim=0)
        self.frames_position: Float[Tensor, "B 3"] = torch.stack(frames_position, dim=0)
        self.frames_direction: Float[Tensor, "B H W 3"] = torch.stack(
            frames_direction, dim=0
        )
        self.frames_img: Float[Tensor, "B H W 3"] = torch.stack(frames_img, dim=0)

        self.rays_o, self.rays_d = get_rays(
            self.frames_direction, self.frames_c2w, keepdim=True
        )
        self.mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(
            self.frames_c2w, self.frames_proj
        )
        self.light_positions: Float[Tensor, "B 3"] = torch.zeros_like(
            self.frames_position
        )
    def __len__(self):
        return self.frames_proj.shape[0]

    def __getitem__(self, index):
        return {
            "index": index,
            "rays_o": self.rays_o[index],
            "rays_d": self.rays_d[index],
            "mvp_mtx": self.mvp_mtx[index],
            "c2w": self.frames_c2w[index],
            "camera_positions": self.frames_position[index],
            "light_positions": self.light_positions[index],
            "gt_rgb": self.frames_img[index],
        }

    def __iter__(self):
        while True:
            yield {}

    def collate(self, batch):
        batch = torch.utils.data.default_collate(batch)
        batch.update({"height": self.frame_h, "width": self.frame_w})
        return batch
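

# Lightning data module tying the two datasets together. The @register name
# below follows threestudio's convention of registering data modules so they
# can be selected by name from a YAML config.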
@register("multiview-camera-datamodule")
class MultiviewDataModule(pl.LightningDataModule):
    cfg: MultiviewsDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(MultiviewsDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        if stage in [None, "fit"]:
            self.train_dataset = MultiviewIterableDataset(self.cfg)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = MultiviewDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = MultiviewDataset(self.cfg, "test")

    def prepare_data(self):
        pass
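
    # batch_size=None disables the DataLoader's automatic batching, which is what
    # the iterable training dataset needs: its collate() builds the full batch
    # from each yielded placeholder. The eval loaders pass batch_size=1 instead.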
    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        return DataLoader(
            dataset,
            num_workers=1,  # type: ignore
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate
        )

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate
        )

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
        )

    def predict_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
        )
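

# Minimal smoke test: a sketch, not part of the original module. Assumes a
# dataset directory containing a Nerfstudio-style transforms.json with
# "camera_model": "OPENCV"; "data/my_scene" is a hypothetical path.
if __name__ == "__main__":
    dm = MultiviewDataModule({"dataroot": "data/my_scene"})
    dm.setup("fit")
    batch = dm.train_dataset.collate(None)  # one random training view
    print(batch["gt_rgb"].shape)  # torch.Size([1, H, W, 3])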