import argparse
import os
import subprocess
from functools import partial

from tqdm.contrib.concurrent import process_map

# Ground-truth root that run_retrieve substitutes for every object whose
# directory path contains neither 'StorageFurniture' nor 'Table'.
# NOTE(review): hard-coded relative path, presumably tied to the original
# experiment layout — confirm before reusing elsewhere.
MERGED_DATA_ROOT = '../acd_data/merged-data'


def run_retrieve(src_dir, json_name, data_root):
    """Run scripts/mesh_retrieval/retrieve.py on a single object directory.

    Args:
        src_dir: Directory containing the object json to retrieve meshes for.
        json_name: Json file name forwarded to retrieve.py as --json_name.
        data_root: Ground-truth data root forwarded as --gt_data_root. It is
            only honoured when ``src_dir`` contains 'StorageFurniture' or
            'Table'; otherwise MERGED_DATA_ROOT is used instead.

    Returns:
        The executed command as a single space-joined string.

    Failures (non-zero exit, or the interpreter cannot be launched) are
    printed rather than raised, so one bad object does not abort the batch.
    """
    if 'StorageFurniture' not in src_dir and 'Table' not in src_dir:
        data_root = MERGED_DATA_ROOT
    fn_call = [
        'python', 'scripts/mesh_retrieval/retrieve.py',
        '--src_dir', src_dir,
        '--json_name', json_name,
        '--gt_data_root', data_root,
    ]
    try:
        # The child's output is discarded; only the exit status is checked.
        subprocess.run(fn_call, check=True,
                       stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
    return ' '.join(fn_call)


def _collect_src_dirs(exp_path, json_name):
    """Return <exp_path>/<model_id>/<sub_id> dirs needing mesh retrieval.

    A directory qualifies when it holds at least one entry whose name ends
    with ``json_name`` and does not yet contain 'object.ply'. Entries whose
    names contain '.' are skipped at both levels, as are non-directories.
    """
    src_dirs = []
    for model_id in sorted(os.listdir(exp_path)):
        model_id_path = os.path.join(exp_path, model_id)
        if '.' in model_id or not os.path.isdir(model_id_path):
            continue
        for model_id_id in sorted(os.listdir(model_id_path)):
            model_id_id_path = os.path.join(model_id_path, model_id_id)
            if '.' in model_id_id or not os.path.isdir(model_id_id_path):
                continue
            # any(): queue each directory at most once even if several
            # file names end with json_name (e.g. 'a_object.json').
            if not any(f.endswith(json_name)
                       for f in os.listdir(model_id_id_path)):
                continue
            if os.path.exists(os.path.join(model_id_id_path, 'object.ply')):
                print(f"Found {model_id_id_path} with object.ply")
            else:
                src_dirs.append(model_id_id_path)
                print(len(src_dirs), model_id_id_path)
    return src_dirs


def main():
    """Parse CLI arguments, find pending objects and retrieve them in parallel."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True, help="path to the experiment folder")
    parser.add_argument("--json_name", type=str, default="object.json", help="name of the json file")
    parser.add_argument("--gt_data_root", type=str, default="../data", help="path to the ground truth data")
    parser.add_argument("--max_workers", type=int, default=6, help="number of worker processes used for retrieval")
    args = parser.parse_args()

    # parser.error instead of assert: validation must survive `python -O`.
    if not os.path.exists(args.src):
        parser.error(f"Src path does not exist: {args.src}")
    if not os.path.exists(args.gt_data_root):
        parser.error(f"GT data root does not exist: {args.gt_data_root}")

    print('----------- Retrieve Part Meshes -----------')
    src_dirs = _collect_src_dirs(args.src, args.json_name)
    print(f"Found {len(src_dirs)} jsons to retrieve part meshes")
    if not src_dirs:
        return

    process_map(
        partial(run_retrieve, json_name=args.json_name, data_root=args.gt_data_root),
        src_dirs,
        max_workers=args.max_workers,
        chunksize=1,
    )


if __name__ == '__main__':
    main()