| import os | |
| import ujson | |
| from functools import partial | |
| from colbert.utils.utils import print_message | |
| from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples | |
| from colbert.utils.runs import Run | |
class LazyBatcher():
    """
    Streams training triples in fixed-size batches, looking up query/passage
    text lazily and tensorizing each batch on demand.

    Iteration stops at the last *full* batch: a trailing partial batch is
    dropped, which keeps `collate`'s size assertion (and the gradient
    accumulation split) valid.
    """

    def __init__(self, args, rank=0, nranks=1):
        self.bsize, self.accumsteps = args.bsize, args.accumsteps

        self.query_tokenizer = QueryTokenizer(args.query_maxlen)
        self.doc_tokenizer = DocTokenizer(args.doc_maxlen)
        # Pre-bind the tokenizers so collate only supplies the batch texts.
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0

        self.triples = self._load_triples(args.triples, rank, nranks)
        self.queries = self._load_queries(args.queries)
        self.collection = self._load_collection(args.collection)

    def _load_triples(self, path, rank, nranks):
        """
        Load this rank's stripe of (qid, pos, neg) triples from a JSONL file.

        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.
        """
        print_message("#> Loading triples...")

        triples = []

        with open(path) as f:
            for idx, line in enumerate(f):
                # Round-robin striping: rank r keeps lines where idx % nranks == r.
                if idx % nranks != rank:
                    continue

                qid, pos, neg = ujson.loads(line)
                triples.append((qid, pos, neg))

        return triples

    def _load_queries(self, path):
        """Load a TSV of `qid \\t query text` into a {int qid: text} dict."""
        print_message("#> Loading queries...")

        with open(path) as f:
            fields = (line.strip().split('\t') for line in f)
            return {int(qid): text for qid, text in fields}

    def _load_collection(self, path):
        """
        Load a TSV collection of `pid \\t passage \\t title [\\t ...]` rows,
        storing each passage as "title | passage". Positions must match pids
        (a leading header row with pid == 'id' is tolerated).
        """
        print_message("#> Loading collection...")

        collection = []

        with open(path) as f:
            for idx, line in enumerate(f):
                pid, passage, title, *_ = line.strip().split('\t')
                assert pid == 'id' or int(pid) == idx
                collection.append(title + ' | ' + passage)

        return collection

    def __iter__(self):
        return self

    def __len__(self):
        # Number of triples assigned to this rank (not the number of batches).
        return len(self.triples)

    def __next__(self):
        """Return the next full batch as tensorized (queries, positives, negatives)."""
        start = self.position
        end = min(start + self.bsize, len(self.triples))
        self.position = end

        # A partial final batch is discarded rather than padded.
        if start + self.bsize > len(self.triples):
            raise StopIteration

        batch_queries, batch_positives, batch_negatives = [], [], []

        for qid, pos_pid, neg_pid in self.triples[start:end]:
            batch_queries.append(self.queries[qid])
            batch_positives.append(self.collection[pos_pid])
            batch_negatives.append(self.collection[neg_pid])

        return self.collate(batch_queries, batch_positives, batch_negatives)

    def collate(self, queries, positives, negatives):
        """Tensorize one full batch, split into bsize // accumsteps sub-batches."""
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

    def skip_to_batch(self, batch_idx, intended_batch_size):
        """Fast-forward the cursor, e.g. when resuming training from a checkpoint."""
        Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')

        self.position = intended_batch_size * batch_idx