import sys
sys.path.insert(0, '.')  # make project-local packages importable before importing them

import os
import numpy as np
from argparse import ArgumentParser

import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from tqdm import tqdm

import dataset.dataset as myDataLoader
import dataset.Transforms as myTransforms
from model.trainer import Trainer
from model.metric_tool import ConfuseMatrixMeter
from model.utils import BCEDiceLoss, init_seed


@torch.no_grad()
def validate(args, val_loader, model, save_masks=False):
    model.eval()
    # Make sure every BatchNorm layer uses its global running statistics
    for m in model.modules():
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
            m.track_running_stats = True
            m.eval()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    if save_masks:
        mask_dir = f"{args.savedir}/pred_masks"
        os.makedirs(mask_dir, exist_ok=True)
        print(f"Saving prediction masks to: {mask_dir}")

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), desc="Validating")
    for batch_idx, batched_inputs in pbar:
        img, target = batched_inputs
        # Filenames for the current batch (the loader uses shuffle=False,
        # so dataset order matches iteration order)
        batch_file_names = val_loader.dataset.file_list[
            batch_idx * args.batch_size : (batch_idx + 1) * args.batch_size
        ]

        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        if args.onGPU:
            pre_img = pre_img.cuda()
            post_img = post_img.cuda()
            target = target.cuda()
        target = target.float()

        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)
        pred = (output > 0.5).long()

        if save_masks:
            pred_np = pred.cpu().numpy().astype(np.uint8)
            print(f"\nDebug - Batch {batch_idx}: {len(batch_file_names)} files, Mask shape: {pred_np.shape}")
            try:
                for i in range(pred_np.shape[0]):
                    if i >= len(batch_file_names):
                        # Guard against a shortage of filenames
                        print(f"Warning: Missing filename for mask {i}, using default")
                        base_name = f"batch_{batch_idx}_mask_{i}"
                    else:
                        base_name = os.path.splitext(os.path.basename(batch_file_names[i]))[0]

                    single_mask = pred_np[i, 0]  # take the (256, 256) mask out of (1, 256, 256)
                    if single_mask.ndim != 2:
                        raise ValueError(f"Invalid mask shape: {single_mask.shape}")

                    mask_path = f"{mask_dir}/{base_name}_pred.png"
                    Image.fromarray(single_mask * 255).save(mask_path)
                    print(f"Saved: {mask_path}")
            except Exception as e:
                print(f"\nError saving batch {batch_idx}: {str(e)}")
                print(f"Current mask shape: {single_mask.shape if 'single_mask' in locals() else 'N/A'}")
                print(f"Current file: {base_name if 'base_name' in locals() else 'N/A'}")

        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target.cpu().numpy())
        epoch_loss.append(loss.item())
        pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'F1': f"{f1:.4f}"})

    average_loss = sum(epoch_loss) / len(epoch_loss)
    scores = salEvalVal.get_scores()
    return average_loss, scores
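
# Illustrative helper (hypothetical; not called anywhere in this script): shows
# how a mask written by validate() above can be read back as a binary array,
# assuming the 0/255 PNG format produced by Image.fromarray(single_mask * 255).
def load_pred_mask(mask_path):
    """Load a saved prediction PNG back into a {0, 1} uint8 array."""
    mask = np.array(Image.open(mask_path))  # pixel values are 0 or 255
    return (mask > 127).astype(np.uint8)    # threshold back to {0, 1}
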
def ValidateSegmentation(args):
    """Main entry point for the full validation pipeline."""
    # Environment setup
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    cudnn.benchmark = True
    init_seed(args.seed)  # fix the random seed for reproducibility

    # Result directory (mirrors the naming scheme used during training)
    args.savedir = os.path.join(args.savedir, f"{args.file_root}_iter_{args.max_steps}_lr_{args.lr}")
    os.makedirs(args.savedir, exist_ok=True)

    # Map dataset names to their directories
    dataset_mapping = {
        'LEVIR': './levir_cd_256',
        'WHU': './whu_cd_256',
        'CLCD': './clcd_256',
        'SYSU': './sysu_256',
        'OSCD': './oscd_256'
    }
    args.file_root = dataset_mapping.get(args.file_root, args.file_root)

    # Initialize the model
    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    # Preprocessing - identical to the validation transforms used during training
    # (per-channel stats for the concatenated pre/post images)
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]
    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    # Data loading
    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=valDataset)
    testLoader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )

    # Logging setup: record whether the log exists *before* opening it,
    # otherwise the header would never be written (opening creates the file)
    logFileLoc = os.path.join(args.savedir, args.logFile)
    log_exists = os.path.exists(logFileLoc)
    logger = open(logFileLoc, 'a' if log_exists else 'w')
    if not log_exists:
        logger.write("\n%s\t%s\t%s\t%s\t%s\t%s\t%s" %
                     ('Epoch', 'Kappa', 'IoU', 'F1', 'Recall', 'Precision', 'OA'))
        logger.flush()

    # Load the best checkpoint
    model_file_name = os.path.join(args.savedir, 'best_model.pth')
    if not os.path.exists(model_file_name):
        raise FileNotFoundError(f"Model file not found: {model_file_name}")
    state_dict = torch.load(model_file_name, map_location='cuda' if args.onGPU else 'cpu')
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_file_name}")

    # Run validation
    loss_test, score_test = validate(args, testLoader, model, save_masks=args.save_masks)

    # Report results
    print("\nTest Results:")
    print(f"Loss: {loss_test:.4f}")
    print(f"Kappa: {score_test['Kappa']:.4f}")
    print(f"IoU: {score_test['IoU']:.4f}")
    print(f"F1: {score_test['F1']:.4f}")
    print(f"Recall: {score_test['recall']:.4f}")
    print(f"Precision: {score_test['precision']:.4f}")
    print(f"OA: {score_test['OA']:.4f}")

    # Append results to the log
    logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" %
                 ('Test', score_test['Kappa'], score_test['IoU'], score_test['F1'],
                  score_test['recall'], score_test['precision'], score_test['OA']))
    logger.close()


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--file_root', default="LEVIR", help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD')
    parser.add_argument('--inWidth', type=int, default=256, help='Width of input image')
    parser.add_argument('--inHeight', type=int, default=256, help='Height of input image')
    parser.add_argument('--max_steps', type=int, default=80000, help='Max. number of iterations (for path naming)')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--model_type', type=str, default='small', help='Model type | tiny | small')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for validation')
    parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate (for path naming)')
    parser.add_argument('--seed', type=int, default=16, help='Random seed for reproducibility')
    parser.add_argument('--savedir', default='./results', help='Base directory to save results')
    parser.add_argument('--logFile', default='testLog.txt', help='File to save validation logs')
    parser.add_argument('--onGPU', default=True, type=lambda x: (str(x).lower() == 'true'), help='Run on GPU if True')
    parser.add_argument('--gpu_id', type=int, default=0, help='GPU device id')
    parser.add_argument('--save_masks', action='store_true', help='Save predicted masks to disk')

    args = parser.parse_args()
    print('Validation with args:')
    print(args)

    ValidateSegmentation(args)
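
# Example invocation (illustrative; assumes this file is saved as test.py and
# that ./results/LEVIR_iter_80000_lr_0.0002/best_model.pth exists from a prior
# training run):
#   python test.py --file_root LEVIR --model_type small --batch_size 16 --save_masks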