`position_ids` buffer is corrupted when loading with `transformers>=5.0` (meta device loading)

#30
by AdrienB134 - opened

Hi! When loading this model with `transformers>=5.0`, the `position_ids` non-persistent buffer ends up containing uninitialized garbage, causing `IndexError` crashes or silently wrong embeddings.

Quick explanation: transformers v5 now creates models on the meta device (shapes only, no real storage) before loading checkpoint weights. Non-persistent buffers like `position_ids` are not stored in the checkpoint, so they never receive real values. The `_init_weights` method in this model's modeling script doesn't re-initialize them either, since that wasn't needed before v5.
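The mechanism is easy to see with a few lines of plain PyTorch. The `Emb` module below is a stand-in for this model's embedding class, not the real one:

```python
import torch
import torch.nn as nn

class Emb(nn.Module):
    def __init__(self):
        super().__init__()
        # Same pattern as typical HF embeddings: a non-persistent buffer
        self.register_buffer("position_ids", torch.arange(8), persistent=False)

# Non-persistent buffers are excluded from the state_dict, i.e. the checkpoint:
m = Emb()
assert "position_ids" not in m.state_dict()

# Meta-device construction (what v5 does) allocates shape info only, no values:
with torch.device("meta"):
    m_meta = Emb()
assert m_meta.position_ids.is_meta  # loading weights will never fill this in
```

Since the checkpoint can't supply the buffer and `_init_weights` doesn't rebuild it, it stays uninitialized after loading.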

Native HF models (BERT, etc.) were patched to re-derive these buffers in `_init_weights`; see BERT's fix for an example. This model needs the same treatment.

Reproduces on Python 3.12, 3.13, and 3.14 with `transformers==5.1.0`. Does NOT reproduce with `transformers<5`.

The fix

Add buffer re-initialization using the transformers v5 init helpers:

```python
import torch
import transformers.initialization as init


class MyPreTrainedModel(PreTrainedModel):
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, MyModuleWhereBufferLives):
            # Non-persistent buffers are absent from the checkpoint,
            # so re-derive them here (the size is illustrative).
            init.copy_(module.position_ids, torch.arange(module.position_ids.numel()))
```

User-side workaround (no model code change needed)

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Alibaba-NLP/gte-multilingual-base", trust_remote_code=True
)
model.eval()

# Re-derive the corrupted position_ids buffer after loading. Only the
# buffer's shape survives meta-device loading, so rebuild values from it.
embeddings = model.embeddings
old = embeddings.position_ids
new = torch.arange(old.size(-1), dtype=torch.long).reshape(old.shape)
# persistent=False matches the original buffer, keeping it out of saved checkpoints
embeddings.register_buffer("position_ids", new, persistent=False)
```
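As a sanity check that re-registering actually replaces the bad buffer, here is a self-contained sketch with a toy module (the module and sizes are illustrative, not this model's real embedding class):

```python
import torch
import torch.nn as nn

class Emb(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(8), persistent=False)

# Simulate v5-style construction: the buffer lands on the meta device
with torch.device("meta"):
    m = Emb()
assert m.position_ids.is_meta

# Apply the same fix as above: re-register with freshly derived values
old = m.position_ids
m.register_buffer("position_ids", torch.arange(old.size(-1)), persistent=False)
assert not m.position_ids.is_meta
assert m.position_ids[-1].item() == 7
```

`register_buffer` with an existing name simply replaces the tensor, so no other model surgery is needed.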
