Fix: Align mask preparation logic for Eager attention to prevent corrupted outputs

#4
by haibo8 - opened

This PR fixes a bug in the eager_attention_forward path where the attention mask was processed incorrectly, producing garbled model outputs (nonsense text).

Issue:
The model's generation logic defines the attention mask so that 0 means "no mask" and -inf means "mask". However, the eager implementation currently uses _prepare_4d_causal_attention_mask, a utility designed to convert binary masks (1 = keep, 0 = mask) into additive masks (0 = keep, -inf = mask).

Since the input mask is already in the additive format (0/-inf), passing it through this function inverts the mask logic and maps values incorrectly. Because the subsequent eager_attention_forward applies the mask additively (attn_weights + attention_mask), this mismatch breaks the attention computation.
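To make the failure mode concrete, here is a minimal toy sketch (not the actual transformers utility) of a binary-to-additive converter applied to a mask that is already additive. The converter treats nonzero entries as "keep", so -inf (truthy) becomes 0 and 0 (falsy) becomes -inf, i.e. the mask is inverted:

```python
import torch

# Toy stand-in for a binary -> additive mask converter: it expects a
# binary mask (1 = keep, 0 = mask) and emits an additive one
# (0 = keep, -inf = mask).
def binary_to_additive(mask: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(mask, dtype=torch.float32)
    out[~mask.bool()] = float("-inf")  # falsy positions get masked
    return out

# The model already supplies an *additive* mask: 0 = keep, -inf = mask.
additive_mask = torch.tensor([0.0, 0.0, float("-inf")])

# Running it through the binary converter inverts the logic: the two
# positions meant to be kept (0.0 -> falsy) become -inf, and the masked
# position (-inf -> truthy/nonzero) becomes 0.
corrupted = binary_to_additive(additive_mask)
print(corrupted)  # tensor([-inf, -inf, 0.])
```

This inverted mask is then added to the attention scores, which is exactly the corruption the garbled outputs trace back to.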

Changes:

  • Replaced _prepare_4d_causal_attention_mask with _prepare_4d_causal_attention_mask_for_sdpa in the forward pass.
  • This ensures consistency across eager, sdpa, and flex attention backends, as all of them expect additive mask logic.
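The additive convention all three backends share can be sketched as follows. This is a simplified illustration with hypothetical shapes, not the model's actual forward code: a 0 entry leaves a position's score unchanged, while -inf drives its softmax probability to zero.

```python
import torch

# Simplified eager attention using additive mask logic:
# attn_weights + attention_mask, then softmax.
def eager_attention(q, k, v, attention_mask):
    attn_weights = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    attn_weights = attn_weights + attention_mask  # additive: 0 keep, -inf mask
    attn_probs = torch.softmax(attn_weights, dim=-1)
    return attn_probs @ v

q = k = v = torch.randn(1, 4, 8)

# Causal additive mask: strictly-upper-triangular entries are -inf,
# so each query only attends to itself and earlier positions.
mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
out = eager_attention(q, k, v, mask)

# The first query can only attend to the first key, so its output
# row reproduces the first value vector exactly.
```

Because the mask is consumed as-is by the addition, any preprocessing step must hand over a mask already in this 0/-inf form, which is what the change above restores for the eager path.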

Impact:
Restores correct text generation when using the eager attention backend.

utdawn changed pull request status to merged
