import os
import random
import time
import torch
import torch.nn as nn
import numpy as np

from transformers import AdamW
from colbert.utils.runs import Run
from colbert.utils.amp import MixedPrecisionManager

from colbert.training.lazy_batcher import LazyBatcher
from colbert.training.eager_batcher import EagerBatcher
from colbert.parameters import DEVICE

from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints
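
# train(args): seed the RNGs, build the triples reader, initialize ColBERT
# (optionally from a checkpoint), and run the pairwise cross-entropy training
# loop with optional DistributedDataParallel and mixed precision.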
def train(args):
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    if args.distributed:
        torch.cuda.manual_seed_all(12345)
    if args.distributed:
        assert args.bsize % args.nranks == 0, (args.bsize, args.nranks)
        assert args.accumsteps == 1
        args.bsize = args.bsize // args.nranks

        print("Using args.bsize =", args.bsize, "(per process) and args.accumsteps =", args.accumsteps)
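
    # Build the training-triples reader; each process reads its own shard via (rank, nranks).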
    if args.lazy:
        reader = LazyBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)
    else:
        reader = EagerBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)
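
    # Non-master ranks block here and are released by the matching barrier further
    # below, after rank 0 has built the model (typically so the pretrained weights
    # are downloaded/cached only once).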
    if args.rank not in [-1, 0]:
        torch.distributed.barrier()
    colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    if args.checkpoint is not None:
        assert args.resume_optimizer is False, "TODO: This would mean reload optimizer too."
        print_message(f"#> Starting from checkpoint {args.checkpoint} -- but NOT the optimizer!")

        checkpoint = torch.load(args.checkpoint, map_location='cpu')

        try:
            colbert.load_state_dict(checkpoint['model_state_dict'])
        except:
            print_message("[WARNING] Loading checkpoint with strict=False")
            colbert.load_state_dict(checkpoint['model_state_dict'], strict=False)
    if args.rank == 0:
        torch.distributed.barrier()
    colbert = colbert.to(DEVICE)
    colbert.train()

    if args.distributed:
        colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[args.rank],
                                                            output_device=args.rank,
                                                            find_unused_parameters=True)
    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=args.lr, eps=1e-8)
    optimizer.zero_grad()
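
    # All-zero labels: `scores` is reshaped in the loop below so that each row is a
    # (positive, negative) score pair, making class 0 the target for every example.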
    amp = MixedPrecisionManager(args.amp)
    criterion = nn.CrossEntropyLoss()
    labels = torch.zeros(args.bsize, dtype=torch.long, device=DEVICE)
    start_time = time.time()
    train_loss = 0.0

    start_batch_idx = 0
    if args.resume:
        assert args.checkpoint is not None
        start_batch_idx = checkpoint['batch']

        reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])
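
    # Main loop: the reader yields BatchSteps, i.e. `accumsteps` sub-batches of
    # (queries, passages). Each sub-batch loss is scaled by accumsteps and its
    # gradients accumulate until amp.step() applies one optimizer update per batch.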
    for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps), reader):
        this_batch_loss = 0.0

        for queries, passages in BatchSteps:
            with amp.context():
                scores = colbert(queries, passages).view(2, -1).permute(1, 0)
                loss = criterion(scores, labels[:scores.size(0)])
                loss = loss / args.accumsteps

            if args.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            train_loss += loss.item()
            this_batch_loss += loss.item()

        amp.step(colbert, optimizer)
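
        # Only the master process (rank 0, or rank -1 in single-process runs) logs
        # metrics and writes checkpoints.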
        if args.rank < 1:
            avg_loss = train_loss / (batch_idx+1)

            num_examples_seen = (batch_idx - start_batch_idx) * args.bsize * args.nranks
            elapsed = float(time.time() - start_time)

            log_to_mlflow = (batch_idx % 20 == 0)
            Run.log_metric('train/avg_loss', avg_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/batch_loss', this_batch_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/examples', num_examples_seen, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/throughput', num_examples_seen / elapsed, step=batch_idx, log_to_mlflow=log_to_mlflow)

            print_message(batch_idx, avg_loss)
            manage_checkpoints(args, colbert, optimizer, batch_idx+1)