import torch


def is_hopper_gpu():
    """Return True if the current CUDA device is a Hopper-generation GPU."""
    if torch.cuda.is_available():
        # Compute capability 9.x corresponds to the Hopper architecture (e.g. H100).
        major, _ = torch.cuda.get_device_capability(0)
        return major == 9
    return False


def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """
    Return recommended num_warps and num_stages for a sparse-attention kernel in Triton.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True for a Hopper GPU, False otherwise (e.g. Ampere).

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    head_large = head_dim > 64
    block_large = block_size > 64

    if is_hopper_gpu:
        # Hopper (sm_90): scale num_warps with tile size; small tiles need
        # fewer warps and a shallower pipeline.
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 4
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    else:
        # Ampere (sm_8x): a single large dimension already warrants the full 8 warps.
        if head_large or block_large:
            num_warps = 8
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    return num_warps, num_stages
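

# Illustrative sketch, assuming Triton is installed: shows how the recommended
# values are forwarded as kernel launch arguments. The trivial copy kernel is a
# hypothetical stand-in for a sparse-attention kernel; only the num_warps /
# num_stages plumbing is the point here.
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program copies one BLOCK-sized tile, masking the ragged tail.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


def launch_with_recommended_config(x, y, head_dim=128, block_size=64):
    # Hypothetical helper: pick launch parameters from the heuristics above and
    # pass them through Triton's standard num_warps / num_stages launch keywords.
    num_warps, num_stages = get_num_warps_stages(head_dim, block_size, is_hopper_gpu())
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    _copy_kernel[grid](
        x, y, n_elements, BLOCK=1024, num_warps=num_warps, num_stages=num_stages
    )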
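

# Minimal usage example (illustrative, not from the original file): print the
# recommended configuration for a few shapes on the current machine.
if __name__ == "__main__":
    for head_dim, block_size in [(64, 64), (128, 64), (128, 128)]:
        num_warps, num_stages = get_num_warps_stages(
            head_dim, block_size, is_hopper_gpu()
        )
        print(
            f"head_dim={head_dim}, block_size={block_size} -> "
            f"num_warps={num_warps}, num_stages={num_stages}"
        )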