# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# Adopted from https://github.com/zhuzilin/ring-flash-attention.
# Implementation refers to Ring Attention Paper: https://arxiv.org/abs/2310.01889
import torch
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward

from .utils import RingComm, update_out_and_lse

try:
    from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse
except ImportError:
    # Fall back to the pure-PyTorch helpers when the Triton kernels are unavailable.
    from .utils import flatten_varlen_lse, unflatten_varlen_lse
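
# Ring attention (varlen) in brief: each rank keeps its local Q shard fixed
# while the K/V shards rotate around the ring via point-to-point send/recv
# (RingComm). At every step a rank attends its Q against the K/V block it
# currently holds, skipping blocks excluded by the causal mask, and merges the
# partial results with the running log-sum-exp in update_out_and_lse.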


def ring_flash_attn_varlen_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens,
    max_seqlen,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    comm = RingComm(process_group)

    out = None
    lse = None
    next_k, next_v = None, None

    for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            # Post the async send/recv that rotates the K/V shard to the next rank.
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()
        if not causal or step <= comm.rank:
            # Attend local Q against the K/V block currently held by this rank;
            # only step 0 (the rank's own block) needs the causal mask.
            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
                q,
                k,
                v,
                cu_seqlens,
                cu_seqlens,
                max_seqlen,
                max_seqlen,
                dropout_p,
                softmax_scale,
                causal=causal and step == 0,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                return_softmax=True and dropout_p > 0,
                block_table=None,
            )
            block_lse = flatten_varlen_lse(block_lse, cu_seqlens=cu_seqlens)
            # Merge the partial output into the running result via log-sum-exp.
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    out = out.to(q.dtype)
    lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
    return out, lse


def ring_flash_attn_varlen_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    cu_seqlens,
    max_seqlen,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None

    block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
    next_k, next_v = None, None

    for step in range(kv_comm.world_size):
        if step + 1 != kv_comm.world_size:
            # Rotate the K/V shard while the local gradients are being computed.
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()
        if step <= kv_comm.rank or not causal:
            bwd_causal = causal and step == 0
            _flash_attn_varlen_backward(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                block_dq_buffer,
                block_dk_buffer,
                block_dv_buffer,
                cu_seqlens,
                cu_seqlens,
                max_seqlen,
                max_seqlen,
                dropout_p,
                softmax_scale,
                bwd_causal,
                window_size,
                alibi_slopes,
                deterministic,
                rng_state=None,
            )
            if dq is None:
                dq = block_dq_buffer.to(torch.float32)
                dk = block_dk_buffer.to(torch.float32)
                dv = block_dv_buffer.to(torch.float32)
            else:
                dq += block_dq_buffer
                d_kv_comm.wait()
                dk = block_dk_buffer + next_dk
                dv = block_dv_buffer + next_dv
        elif step != 0:
            # This K/V block contributes nothing to the local queries under the
            # causal mask; just relay the incoming dK/dV unchanged.
            d_kv_comm.wait()
            dk = next_dk
            dv = next_dv
        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        # Forward the accumulated dK/dV along the ring so each shard's gradient
        # ends up back at the rank that owns the corresponding K/V.
        next_dk = d_kv_comm.send_recv(dk)
        next_dv = d_kv_comm.send_recv(dv)
        d_kv_comm.commit()

    d_kv_comm.wait()

    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)


class RingFlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        assert alibi_slopes is None
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = ring_flash_attn_varlen_forward(
            group,
            q,
            k,
            v,
            cu_seqlens,
            max_seqlen,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=False,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        return out if not return_softmax else (out, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors
        dq, dk, dv = ring_flash_attn_varlen_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens,
            ctx.max_seqlen,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
        )
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None


def ring_flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    return RingFlashAttnVarlenFunc.apply(
        qkv[:, 0],
        qkv[:, 1],
        qkv[:, 2],
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )


def ring_flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    return RingFlashAttnVarlenFunc.apply(
        q,
        kv[:, 0],
        kv[:, 1],
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )


def ring_flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    return RingFlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )
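

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream module). Assumes a
# NCCL setup with one process per GPU, launched e.g. via
#   torchrun --nproc_per_node=2 -m <package>.ring_flash_attn_varlen
# The shapes and sequence lengths below are arbitrary assumptions; each rank
# holds only its own shard of the packed sequences, and cu_seqlens describes
# that local shard.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    nheads, headdim = 8, 64
    # Two packed sequences of 512 and 256 tokens on this rank (assumed split).
    cu_seqlens = torch.tensor([0, 512, 768], dtype=torch.int32, device="cuda")
    max_seqlen = 512
    total_tokens = int(cu_seqlens[-1])

    q = torch.randn(total_tokens, nheads, headdim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    # group=None falls back to the default process group.
    out = ring_flash_attn_varlen_func(q, k, v, cu_seqlens, max_seqlen, causal=True)
    out.sum().backward()

    dist.destroy_process_group()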