from dataclasses import dataclass, field

import torch
import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *

from gaussiansplatting.gaussian_renderer import render
from gaussiansplatting.scene import Scene, GaussianModel
from gaussiansplatting.arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussiansplatting.scene.cameras import Camera
from gaussiansplatting.utils.sh_utils import SH2RGB
from gaussiansplatting.scene.gaussian_model import BasicPointCloud

from argparse import ArgumentParser, Namespace
import os
import io
from pathlib import Path

import numpy as np
from plyfile import PlyData, PlyElement
from PIL import Image
import open3d as o3d

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config as diffusion_from_config_shape
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.notebooks import decode_latent_mesh
def load_ply(path, save_path):
    """Read a saved Gaussian .ply, recover RGB from the degree-0 SH
    coefficients, and write a colored point cloud for inspection."""
    C0 = 0.28209479177387814

    def SH2RGB(sh):
        # Degree-0 spherical harmonics coefficient -> RGB.
        return sh * C0 + 0.5

    plydata = PlyData.read(path)
    xyz = np.stack(
        (
            np.asarray(plydata.elements[0]["x"]),
            np.asarray(plydata.elements[0]["y"]),
            np.asarray(plydata.elements[0]["z"]),
        ),
        axis=1,
    )

    features_dc = np.zeros((xyz.shape[0], 3, 1))
    features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
    features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
    features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])

    color = SH2RGB(features_dc[:, :, 0])
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(xyz)
    point_cloud.colors = o3d.utility.Vector3dVector(color)
    o3d.io.write_point_cloud(save_path, point_cloud)
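
# Usage sketch (the paths are hypothetical, not part of the training code):
# convert a trained Gaussian checkpoint into a viewer-friendly colored cloud.
#   load_ply("outputs/last.ply", "outputs/last-color.ply")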

def storePly(path, xyz, rgb):
    # Define the dtype for the structured array
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    normals = np.zeros_like(xyz)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    elements[:] = list(map(tuple, attributes))

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)

def fetchPly(path):
    plydata = PlyData.read(path)
    vertices = plydata['vertex']
    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
    colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
    normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
    return BasicPointCloud(points=positions, colors=colors, normals=normals)
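
# storePly and fetchPly form a round trip: storePly expects uint8 RGB in
# [0, 255], while fetchPly returns float colors in [0, 1]. Sketch (the arrays
# below are made up for illustration):
#   xyz = np.random.rand(100, 3).astype(np.float32)
#   rgb = (np.random.rand(100, 3) * 255).astype(np.uint8)
#   storePly("points.ply", xyz, rgb)
#   pcd = fetchPly("points.ply")  # BasicPointCloud, colors in [0, 1]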

@threestudio.register("gaussiandreamer-system")  # name assumed; must match `system_type` in the YAML config
class GaussianDreamer(BaseLift3DSystem):
    @dataclass
    class Config(BaseLift3DSystem.Config):
        radius: float = 4
        sh_degree: int = 0

    cfg: Config

    def configure(self) -> None:
        self.radius = self.cfg.radius
        self.sh_degree = self.cfg.sh_degree
        self.gaussian = GaussianModel(sh_degree=self.sh_degree)
        bg_color = [0, 0, 0]  # black background; use [1, 1, 1] for white
        self.background_tensor = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    def save_gif_to_file(self, images, output_file):
        # Encode the PIL frames as an animated GIF in memory, then flush to disk.
        with io.BytesIO() as writer:
            images[0].save(
                writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0
            )
            writer.seek(0)
            with open(output_file, 'wb') as file:
                file.write(writer.read())

    def shape(self):
        """Sample a coarse textured shape from Shap-E for the text prompt and
        return its vertices, per-vertex colors, and a global scale factor."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        xm = load_model('transmitter', device=device)
        model = load_model('text300M', device=device)
        model.load_state_dict(torch.load('./load/shapE_finetuned_with_330kdata.pth', map_location=device)['model_state_dict'])
        diffusion = diffusion_from_config_shape(load_config('diffusion'))

        batch_size = 1
        guidance_scale = 15.0
        prompt = str(self.cfg.prompt_processor.prompt)
        print('prompt', prompt)

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        render_mode = 'nerf'  # you can change this to 'stf'
        size = 256  # size of the renders; higher values take longer to render
        cameras = create_pan_cameras(size, device)
        self.shapeimages = decode_latent_images(xm, latents[0], cameras, rendering_mode=render_mode)

        # Decode the latent to a triangle mesh and subsample its vertices.
        pc = decode_latent_mesh(xm, latents[0]).tri_mesh()
        skip = 4
        coords = pc.verts
        rgb = np.concatenate(
            [pc.vertex_channels['R'][:, None], pc.vertex_channels['G'][:, None], pc.vertex_channels['B'][:, None]],
            axis=1,
        )
        coords = coords[::skip]
        rgb = rgb[::skip]

        self.num_pts = coords.shape[0]
        point_cloud = o3d.geometry.PointCloud()
        point_cloud.points = o3d.utility.Vector3dVector(coords)
        point_cloud.colors = o3d.utility.Vector3dVector(rgb)
        self.point_cloud = point_cloud

        return coords, rgb, 0.4
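
    # shape() returns roughly len(pc.verts) / skip vertices with RGB in
    # [0, 1], plus a fixed 0.4 scale that pcb() multiplies by cfg.radius to
    # fit the Shap-E geometry inside the camera volume.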

    def add_points(self, coords, rgb):
        pcd_by3d = o3d.geometry.PointCloud()
        pcd_by3d.points = o3d.utility.Vector3dVector(np.array(coords))

        bbox = pcd_by3d.get_axis_aligned_bounding_box()
        np.random.seed(0)

        num_points = 1000000
        points = np.random.uniform(low=np.asarray(bbox.min_bound), high=np.asarray(bbox.max_bound), size=(num_points, 3))

        kdtree = o3d.geometry.KDTreeFlann(pcd_by3d)

        points_inside = []
        color_inside = []
        for point in points:
            _, idx, _ = kdtree.search_knn_vector_3d(point, 1)
            nearest_point = np.asarray(pcd_by3d.points)[idx[0]]
            if np.linalg.norm(point - nearest_point) < 0.01:  # this threshold may need tuning
                points_inside.append(point)
                color_inside.append(rgb[idx[0]] + 0.2 * np.random.random(3))

        all_coords = np.array(points_inside)
        all_rgb = np.array(color_inside)
        all_coords = np.concatenate([all_coords, coords], axis=0)
        all_rgb = np.concatenate([all_rgb, rgb], axis=0)
        return all_coords, all_rgb
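
    # add_points densifies the Shap-E surface by rejection sampling: uniform
    # candidates inside the bounding box are kept only if their nearest
    # surface vertex lies within 0.01, inheriting a color-jittered copy of
    # that vertex's color. With 1e6 candidates, typically only a small
    # fraction survives, and the KNN loop can take a little while.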

    def pcb(self):
        # Since this data set has no colmap data, we start with random points
        coords, rgb, scale = self.shape()
        bound = self.radius * scale

        all_coords, all_rgb = self.add_points(coords, rgb)
        # Normals must match the densified point count (add_points returns
        # more points than self.num_pts), so size them from all_coords.
        pcd = BasicPointCloud(points=all_coords * bound, colors=all_rgb, normals=np.zeros_like(all_coords))
        return pcd

    def forward(self, batch: Dict[str, Any], renderbackground=None) -> Dict[str, Any]:
        if renderbackground is None:
            renderbackground = self.background_tensor

        images = []
        depths = []
        self.viewspace_point_list = []

        for id in range(batch['c2w_3dgs'].shape[0]):
            viewpoint_cam = Camera(c2w=batch['c2w_3dgs'][id], FoVy=batch['fovy'][id], height=batch['height'], width=batch['width'])

            render_pkg = render(viewpoint_cam, self.gaussian, self.pipe, renderbackground)
            image, viewspace_point_tensor, _, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
            self.viewspace_point_list.append(viewspace_point_tensor)

            # Track each Gaussian's maximum screen-space radius across views.
            if id == 0:
                self.radii = radii
            else:
                self.radii = torch.max(radii, self.radii)

            depth = render_pkg["depth_3dgs"]
            depth = depth.permute(1, 2, 0)
            image = image.permute(1, 2, 0)
            images.append(image)
            depths.append(depth)

        images = torch.stack(images, 0)
        depths = torch.stack(depths, 0)
        self.visibility_filter = self.radii > 0.0

        render_pkg["comp_rgb"] = images
        render_pkg["depth"] = depths
        render_pkg["opacity"] = depths / (depths.max() + 1e-5)

        return {
            **render_pkg,
        }
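
    # forward() renders one image per camera in the batch and returns the
    # last render_pkg augmented with stacked tensors: comp_rgb and depth are
    # (B, H, W, C). Note that "opacity" is a max-normalized depth proxy, not
    # a true alpha map; the sparsity and opaque losses below operate on it.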

    def on_fit_start(self) -> None:
        super().on_fit_start()
        # only used in training
        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

    def training_step(self, batch, batch_idx):
        self.gaussian.update_learning_rate(self.true_global_step)

        if self.true_global_step > 500:
            # Anneal the SDS timestep range after the warm-up phase.
            self.guidance.set_min_max_steps(min_step_percent=0.02, max_step_percent=0.55)

        out = self(batch)

        prompt_utils = self.prompt_processor()
        images = out["comp_rgb"]

        guidance_eval = (self.true_global_step % 200 == 0)

        guidance_out = self.guidance(
            images, prompt_utils, **batch, rgb_as_latents=False, guidance_eval=guidance_eval
        )

        loss = guidance_out['loss_sds'] * self.C(self.cfg.loss.lambda_sds)

        loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
        self.log("train/loss_opaque", loss_opaque)
        loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque)

        if guidance_eval:
            self.guidance_evaluation_save(
                out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                guidance_out["eval"],
            )

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}
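
    # The total training objective, with the lambdas scheduled by self.C(...):
    #   loss = lambda_sds      * L_SDS
    #        + lambda_sparsity * mean(sqrt(opacity^2 + 0.01))
    #        + lambda_opaque   * BCE(opacity, opacity)
    # BCE(x, x) is minimized as x approaches 0 or 1, pushing the rendered
    # opacity toward a binary foreground/background separation.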

    def on_before_optimizer_step(self, optimizer):
        with torch.no_grad():
            if self.true_global_step < 900:
                # Accumulate view-space position gradients across all rendered views.
                viewspace_point_tensor_grad = torch.zeros_like(self.viewspace_point_list[0])
                for idx in range(len(self.viewspace_point_list)):
                    viewspace_point_tensor_grad = viewspace_point_tensor_grad + self.viewspace_point_list[idx].grad

                # Keep track of max radii in image-space for pruning
                self.gaussian.max_radii2D[self.visibility_filter] = torch.max(
                    self.gaussian.max_radii2D[self.visibility_filter], self.radii[self.visibility_filter]
                )
                self.gaussian.add_densification_stats(viewspace_point_tensor_grad, self.visibility_filter)

                if self.true_global_step > 300 and self.true_global_step % 100 == 0:
                    size_threshold = 20 if self.true_global_step > 500 else None
                    self.gaussian.densify_and_prune(0.0002, 0.05, self.cameras_extent, size_threshold)
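
    # Densification schedule, as hard-coded above: gradient statistics are
    # accumulated every step before step 900; from step 300 on, every 100
    # steps Gaussians with view-space gradient above 0.0002 are densified and
    # those with opacity below 0.05 (or, after step 500, screen-space radius
    # above 20) are pruned.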

    def validation_step(self, batch, batch_idx):
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-{batch['index'][0]}.png",
            (
                [
                    {
                        "type": "rgb",
                        "img": batch["rgb"][0],
                        "kwargs": {"data_format": "HWC"},
                    }
                ]
                if "rgb" in batch
                else []
            )
            + [
                {
                    "type": "rgb",
                    "img": out["comp_rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            ),
            name="validation_step",
            step=self.true_global_step,
        )
        # save_path = self.get_save_path(f"it{self.true_global_step}-val.ply")
        # self.gaussian.save_ply(save_path)
        # load_ply(save_path, self.get_save_path(f"it{self.true_global_step}-val-color.ply"))

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        only_rgb = True
        bg_color = [0, 0, 0]  # black test background; use [1, 1, 1] for white
        testbackground_tensor = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        out = self(batch, testbackground_tensor)
        if only_rgb:
            self.save_image_grid(
                f"it{self.true_global_step}-test/{batch['index'][0]}.png",
                (
                    [
                        {
                            "type": "rgb",
                            "img": batch["rgb"][0],
                            "kwargs": {"data_format": "HWC"},
                        }
                    ]
                    if "rgb" in batch
                    else []
                )
                + [
                    {
                        "type": "rgb",
                        "img": out["comp_rgb"][0],
                        "kwargs": {"data_format": "HWC"},
                    },
                ]
                + (
                    [
                        {
                            "type": "rgb",
                            "img": out["comp_normal"][0],
                            "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                        }
                    ]
                    if "comp_normal" in out
                    else []
                ),
                name="test_step",
                step=self.true_global_step,
            )
        else:
            self.save_image_grid(
                f"it{self.true_global_step}-test/{batch['index'][0]}.png",
                (
                    [
                        {
                            "type": "rgb",
                            "img": batch["rgb"][0],
                            "kwargs": {"data_format": "HWC"},
                        }
                    ]
                    if "rgb" in batch
                    else []
                )
                + [
                    {
                        "type": "rgb",
                        "img": out["comp_rgb"][0],
                        "kwargs": {"data_format": "HWC"},
                    },
                ]
                + (
                    [
                        {
                            "type": "rgb",
                            "img": out["comp_normal"][0],
                            "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                        }
                    ]
                    if "comp_normal" in out
                    else []
                )
                + (
                    [
                        {
                            "type": "grayscale",
                            "img": out["depth"][0],
                            "kwargs": {},
                        }
                    ]
                    if "depth" in out
                    else []
                )
                + [
                    {
                        "type": "grayscale",
                        "img": out["opacity"][0, :, :, 0],
                        "kwargs": {"cmap": None, "data_range": (0, 1)},
                    },
                ],
                name="test_step",
                step=self.true_global_step,
            )

    def on_test_epoch_end(self):
        self.save_img_sequence(
            f"it{self.true_global_step}-test",
            f"it{self.true_global_step}-test",
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.true_global_step,
        )
        save_path = self.get_save_path("last.ply")
        self.gaussian.save_ply(save_path)
        # self.pointefig.savefig(self.get_save_path("pointe.png"))
        o3d.io.write_point_cloud(self.get_save_path("shape.ply"), self.point_cloud)
        self.save_gif_to_file(self.shapeimages, self.get_save_path("shape.gif"))
        load_ply(save_path, self.get_save_path(f"it{self.true_global_step}-test-color.ply"))

    def configure_optimizers(self):
        self.parser = ArgumentParser(description="Training script parameters")
        opt = OptimizationParams(self.parser)

        point_cloud = self.pcb()
        self.cameras_extent = 4.0
        self.gaussian.create_from_pcd(point_cloud, self.cameras_extent)

        self.pipe = PipelineParams(self.parser)
        self.gaussian.training_setup(opt)

        ret = {
            "optimizer": self.gaussian.optimizer,
        }
        return ret
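
    # configure_optimizers doubles as initialization: Lightning calls it once
    # before the first training step, so the Shap-E point cloud is built here
    # and handed to GaussianModel.create_from_pcd. The returned optimizer is
    # the 3DGS Adam instance with the per-parameter-group learning rates set
    # up by training_setup(opt).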