import argparse

import gradio as gr


def num_floating_point_operations(args):
    def calculate_layer_counts():
        """Calculate the number of attention, Mamba, and MLP layers."""
        if args.hybrid_override_pattern:
            counts = {"M": 0, "*": 0, "-": 0}
            for layer_type in args.hybrid_override_pattern:
                if layer_type in counts:
                    counts[layer_type] += 1
            return counts["*"], counts["M"], counts["-"]
        else:
            num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
            num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
            num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
            return num_attn_layers, num_mamba_layers, num_mlp_layers

    def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
        """Calculate FLOPs for an MLP layer."""
        scale_factor = 3.0 / 2.0 if swiglu else 1.0
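        # Two GEMMs per MLP (h -> expansion*h and expansion*h -> h), each costing
        # 2 * batch * seq * h * (expansion * h) forward FLOPs, i.e. 4 * expansion
        # * batch * seq * h^2 in total; SwiGLU adds a third (gate) projection,
        # hence the 3/2 scale factor.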
        return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

    def attn_layer_flops(
        batch_size,
        seq_len,
        hidden_size,
        num_heads,
        gqa=True,
        gqa_groups=8,
        kv_channels=None,
    ):
        """Calculate FLOPs for an attention layer."""
        p = (kv_channels * num_heads / hidden_size) if kv_channels else 1
        g = gqa_groups if gqa else num_heads
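        # Term-by-term (forward, with p = kv_channels * num_heads / hidden_size):
        # Q and output projections give 4*B*s*h^2*p, K/V projections give
        # 4*B*s*h^2*p*(g/num_heads), and the causal QK^T and attention-over-V GEMMs
        # give 2*B*s^2*h*p (only half the score matrix is computed), which factors
        # into 4*B*s*h*p*(h + h*(g/num_heads) + s/2).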
        return (
            4
            * batch_size
            * seq_len
            * hidden_size
            * p
            * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2))
        )

    def mamba_layer_flops(
        batch_size, seq_len, hidden_size, state_dim=16, head_dim=64, num_groups=1
    ):
        """Calculate FLOPs for a Mamba layer."""
        # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels,
        # but small percent of overall layer flops
        d_in = 2 * hidden_size
        nheads = d_in // head_dim
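        # d_in = 2 * hidden_size is the usual Mamba expansion factor of 2; the in_proj
        # width (2*d_in + 2*num_groups*state_dim + nheads) covers the x, z, B, C and dt
        # projections, and 7*B*s*d_in*state_dim is a rough per-element estimate for the
        # selective scan (see the note above).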
        return (
            (
                2
                * batch_size
                * seq_len
                * hidden_size
                * (2 * d_in + 2 * num_groups * state_dim + nheads)
            )  # in_proj
            + (7 * batch_size * seq_len * d_in * state_dim)  # scan
            + (2 * batch_size * seq_len * d_in * hidden_size)  # out_proj
        )

    def hybrid_flops(
        batch_size,
        seq_len,
        hidden_size,
        num_attn_layers,
        num_mamba_layers,
        num_mlp_layers,
        mamba_state_dim=128,
        mamba_head_dim=64,
        mamba_num_groups=8,
        num_attn_heads=32,
        gqa=True,
        gqa_groups=8,
        kv_channels=None,
        mlp_expansion=4.0,
        swiglu=False,
        vocab_size=256000,
    ):
        """Calculate total FLOPs for the hybrid model."""
        flops_fwd = (
            num_attn_layers
            * attn_layer_flops(
                batch_size,
                seq_len,
                hidden_size,
                num_attn_heads,
                gqa,
                gqa_groups,
                kv_channels,
            )
            + num_mlp_layers
            * mlp_layer_flops(batch_size, seq_len, hidden_size, mlp_expansion, swiglu)
            + num_mamba_layers
            * mamba_layer_flops(
                batch_size,
                seq_len,
                hidden_size,
                mamba_state_dim,
                mamba_head_dim,
                mamba_num_groups,
            )
            + (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
        )
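        # Each GEMM runs once in the forward pass and twice in the backward pass
        # (wgrad + dgrad), so the total is roughly 3x the forward FLOPs.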
        return flops_fwd * 3

    def transformer_flops():
        """Calculate FLOPs for a standard Transformer model."""
        # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods.
        # Attention projection size.
        query_projection_size = args.kv_channels * args.num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
        # Group Query Attention.
        if not args.group_query_attention:
            args.num_query_groups = args.num_attention_heads
        # MoE.
        if args.num_experts is None:
            # Every Transformer MLP is dense.
            num_dense_layers = args.num_layers
            num_moe_layers = 0
            num_experts_routed_to = 0
            last_layer_is_moe = 0
        else:
            # Calculate the number of dense and MoE Transformer MLPs.
            if isinstance(args.moe_layer_freq, int):
                moe_layer_pattern = [
                    1 if (i % args.moe_layer_freq == 0) else 0
                    for i in range(args.num_layers)
                ]
            elif isinstance(args.moe_layer_freq, list):
                moe_layer_pattern = args.moe_layer_freq
            else:
                raise RuntimeError("Illegal --moe-layer-freq argument provided!")
            assert len(moe_layer_pattern) == args.num_layers, (
                f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
                f"expected {args.num_layers}, "
                f"current moe layer pattern: {args.moe_layer_freq}"
            )
            num_moe_layers = sum(moe_layer_pattern)  # Number of 1s in `moe_layer_pattern`.
            num_dense_layers = args.num_layers - num_moe_layers
            num_experts_routed_to = args.moe_router_topk
            last_layer_is_moe = moe_layer_pattern[-1]

        if args.mtp_num_layers is not None:
            mtp_num_layers = args.mtp_num_layers
            num_moe_layers += last_layer_is_moe * mtp_num_layers
            num_dense_layers += (1 - last_layer_is_moe) * mtp_num_layers
            num_layers = args.num_layers + mtp_num_layers
        else:
            mtp_num_layers = 0
            num_layers = args.num_layers

        moe_ffn_hidden_size = (
            args.moe_ffn_hidden_size
            if args.moe_ffn_hidden_size is not None
            else args.ffn_hidden_size
        )
        shared_expert_ffn_hidden_size = (
            0
            if args.moe_shared_expert_intermediate_size is None
            else args.moe_shared_expert_intermediate_size
        )

        # SwiGLU.
        gated_linear_multiplier = 3 / 2 if args.swiglu else 1

        # The 12x term below comes from the following factors; for more details, see
        # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
        # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
        #       backward wgrad [weight gradient], backward dgrad [data gradient]).
        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer
        #       architectures implemented in this codebase (e.g., the h->ffn_h GEMM and
        #       ffn_h->h GEMM in an MLP layer).
        # - 2x: A GEMM of an m*n tensor with an n*k tensor requires 2mnk floating-point operations.
        expansion_factor = 3 * 2 * 2

        if args.multi_latent_attention:
            assert not args.group_query_attention
| """ | |
| Basic arithmetic | |
| let B is batch size, s is seq_len, h is embedding dim, | |
| for one self_attnetion block (prenorm is not included) | |
| qkv projection: 6Bsh^2 | |
| attn: 2Bs^2h | |
| attn over value: 2Bs^2h | |
| oproj: 2Bsh^2 | |
| references | |
| https://arxiv.org/abs/2305.10403 | |
| https://arxiv.org/abs/2205.05198 | |
| """ | |
            ## MLA
            if args.q_lora_rank is None:
                q_term = (
                    args.hidden_size
                    * args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                )
            else:
                q_term = args.q_lora_rank * (
                    args.hidden_size
                    + args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    + 1
                )
            self_attn_term = (
                3
                * 2  # 3 = fwd (1) + bwd (2); 2 = each multiply-add (FMA) counts as 2 FLOPs
                * num_layers
                * (
                    ## q lora + rope + q norm
                    q_term
                    ## kv lora + rope + kv norm
                    + args.kv_lora_rank
                    * (
                        args.hidden_size
                        + args.num_attention_heads
                        * (args.qk_head_dim + args.v_head_dim)
                        + 1
                    )
                    + args.hidden_size * args.qk_pos_emb_head_dim
                    ## o proj
                    + (args.num_attention_heads * args.v_head_dim) * args.hidden_size
                    ## core attn
                    + args.seq_length
                    * (
                        args.num_attention_heads
                        * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    )
                    / 2
                    + args.seq_length * args.num_attention_heads * args.v_head_dim / 2
                )
            )
        else:
            ## MHA or GQA
            self_attn_term = (
                expansion_factor
                * num_layers
                * args.hidden_size
                * args.hidden_size
                * (
                    (
                        1
                        + (args.num_query_groups / args.num_attention_heads)
                        # Only half of the attention matrix is non-zero and needs to be multiplied with V.
                        + (args.seq_length / args.hidden_size / 2)
                    )
                    * query_projection_to_hidden_size_ratio
                )
            )
        total_floating_point_operations = (
            args.batch_size
            * args.seq_length
            * (
                # MLP.
                expansion_factor
                * num_layers
                * args.hidden_size
                * (
                    # Dense layers (DeepSeek V2/V3 style).
                    (args.ffn_hidden_size * gated_linear_multiplier)
                    * (num_dense_layers / num_layers)
                    # Routed experts.
                    + (
                        moe_ffn_hidden_size
                        * num_experts_routed_to
                        * gated_linear_multiplier
                    )
                    * (num_moe_layers / num_layers)
                    # Shared experts.
                    + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
                    * (num_moe_layers / num_layers)
                )
                # Self attention.
                + self_attn_term
                # MTP norms and projection.
                + 3
                * 2
                * mtp_num_layers
                * (
                    # MTP e/h norms + final norm.
                    3 * args.hidden_size
                    # MTP eh projection.
                    + 2 * args.hidden_size * args.hidden_size
                )
                # Logits.
                + 3
                * 2
                * args.hidden_size
                * args.padded_vocab_size
                * (mtp_num_layers + 1)
            )
        )
        return total_floating_point_operations

    # Main entrypoint for FLOPs calculation.
    if args.is_hybrid_model:
        # Calculate the number of each type of layer.
        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()
        # Compute hybrid model FLOPs.
        return hybrid_flops(
            batch_size=args.batch_size,
            seq_len=args.seq_length,
            hidden_size=args.hidden_size,
            num_attn_layers=num_attn_layers,
            num_mamba_layers=num_mamba_layers,
            num_mlp_layers=num_mlp_layers,
            mamba_state_dim=args.mamba_state_dim,
            mamba_head_dim=args.mamba_head_dim,
            mamba_num_groups=args.mamba_num_groups,
            num_attn_heads=args.num_attention_heads,
            gqa=args.group_query_attention,
            gqa_groups=args.num_query_groups,
            kv_channels=args.kv_channels,
            mlp_expansion=args.ffn_hidden_size / args.hidden_size,
            swiglu=args.swiglu,
            vocab_size=args.padded_vocab_size,
        )
    else:
        # Compute standard Transformer model FLOPs.
        return transformer_flops()


def calculate_flops(args):
    model_flops = num_floating_point_operations(args)
    flops_per_token = model_flops / (args.batch_size * args.seq_length)
    print(f"FLOPs Per Iteration: {model_flops}\nFLOPs Per Token: {flops_per_token}")
    return model_flops


def calculate_mfu(model_flops, *, iter_elapsed_time, num_p800_cards):
    assert (
        model_flops and iter_elapsed_time and num_p800_cards
    ), "Model FLOPs, iter elapsed time, and P800 card count must all be provided"
    mfu = model_flops / (iter_elapsed_time * num_p800_cards * 3.5e14)
    print(f"MFU P800 bf16: {mfu:.2%}")


def calculate_mfu_web(
    is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
    ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
    num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size,
    moe_shared_expert_intermediate_size, multi_latent_attention, q_lora_rank,
    kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
    mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards,
):
    # The Gradio dropdowns deliver the boolean flags as the strings "True"/"False".
    is_hybrid_model = is_hybrid_model == "True"
    group_query_attention = group_query_attention == "True"
    swiglu = swiglu == "True"
    multi_latent_attention = multi_latent_attention == "True"
    # To call the calculate_flops(args)-style interface directly, pack the
    # parameters into an args-like object here.
    class parameter:
        def __init__(self,
                     is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                     ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                     num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size,
                     moe_shared_expert_intermediate_size, multi_latent_attention, q_lora_rank,
                     kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                     mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards,
                     hybrid_override_pattern=None):
            self.is_hybrid_model = is_hybrid_model
            self.group_query_attention = group_query_attention
            self.swiglu = swiglu
            self.num_layers = num_layers
            self.hidden_size = hidden_size
            self.ffn_hidden_size = ffn_hidden_size
            self.padded_vocab_size = padded_vocab_size
            self.num_attention_heads = num_attention_heads
            self.kv_channels = kv_channels
            self.num_experts = num_experts
            self.moe_layer_freq = moe_layer_freq
            self.moe_router_topk = moe_router_topk
            self.moe_ffn_hidden_size = moe_ffn_hidden_size
            self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
            self.multi_latent_attention = multi_latent_attention
            self.q_lora_rank = q_lora_rank
            self.kv_lora_rank = kv_lora_rank
            self.qk_head_dim = qk_head_dim
            self.v_head_dim = v_head_dim
            self.qk_pos_emb_head_dim = qk_pos_emb_head_dim
            self.mtp_num_layers = mtp_num_layers
            self.seq_length = seq_length
            self.batch_size = batch_size
            self.iter_elapsed_time = iter_elapsed_time
            self.num_p800_cards = num_p800_cards
            self.hybrid_override_pattern = hybrid_override_pattern

    mfu_parameter = parameter(is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                              ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                              num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size,
                              moe_shared_expert_intermediate_size, multi_latent_attention, q_lora_rank,
                              kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                              mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards,
                              hybrid_override_pattern=None)

    model_flops = num_floating_point_operations(mfu_parameter)
    flops_per_token = model_flops / (batch_size * seq_length)
    print(f"FLOPs Per Iteration: {model_flops}\nFLOPs Per Token: {flops_per_token}")
    assert (
        model_flops and iter_elapsed_time and num_p800_cards
    ), "Model FLOPs, iter elapsed time, and P800 card count must all be provided"
    mfu = model_flops / (iter_elapsed_time * num_p800_cards * 3.5e14)
    print(f"MFU P800 bf16: {mfu:.2%}")
    return model_flops, flops_per_token, "{:.2f}%".format(mfu * 100)
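

# calculate_mfu_web can also be called directly, outside the Gradio UI; the four boolean
# flags are passed as the "True"/"False" strings the dropdowns produce. Illustrative call
# with values matching the defaults below:
#   calculate_mfu_web("False", "False", "True", 61, 7168, 18432, 100002, 128, 128,
#                     256, 1, 8, 2048, 2048, "True", 1536, 512, 128, 128, 64,
#                     1, 4096, 1024, 100, 512)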


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    # Standard Transformer config.
    args.is_hybrid_model = False
    args.group_query_attention = False
    args.swiglu = True
    args.num_layers = 61
    args.hidden_size = 7168
    args.ffn_hidden_size = 18432
    args.padded_vocab_size = 100002
    args.num_attention_heads = 128
    args.kv_channels = 128
    # MoE config.
    args.num_experts = 256
    args.moe_layer_freq = 1
    args.moe_router_topk = 8
    args.moe_ffn_hidden_size = 2048
    args.moe_shared_expert_intermediate_size = 2048
    # MLA config.
    args.multi_latent_attention = True
    args.q_lora_rank = 1536
    args.kv_lora_rank = 512
    args.qk_head_dim = 128
    args.v_head_dim = 128
    args.qk_pos_emb_head_dim = 64
    # MTP config.
    args.mtp_num_layers = 1
    # Data config.
    args.seq_length = 4096
    args.batch_size = 1024
    # MFU config.
    args.iter_elapsed_time = 100
    args.num_p800_cards = 512
    # calculate_mfu(calculate_flops(args), iter_elapsed_time=args.iter_elapsed_time, num_p800_cards=args.num_p800_cards)

    with gr.Blocks(title="Compute MFU") as demo:
        gr.Markdown("## Compute MFU")
        with gr.Group():
            gr.Markdown("Standard Transformer config:")
            with gr.Row():
                is_hybrid_model = gr.Dropdown(["True", "False"],
                                              label="hybrid model",
                                              value="True" if args.is_hybrid_model else "False")
                group_query_attention = gr.Dropdown(["True", "False"],
                                                    label="group query attention",
                                                    value="True" if args.group_query_attention else "False")
                swiglu = gr.Dropdown(["True", "False"],
                                     label="swiglu",
                                     value="True" if args.swiglu else "False")
                num_layers = gr.Number(label="num layers", value=args.num_layers, precision=0)
                hidden_size = gr.Number(label="hidden size", value=args.hidden_size, precision=0)
                ffn_hidden_size = gr.Number(label="ffn hidden size", value=args.ffn_hidden_size, precision=0)
                padded_vocab_size = gr.Number(label="padded vocab size", value=args.padded_vocab_size, precision=0)
                num_attention_heads = gr.Number(label="num attention heads", value=args.num_attention_heads, precision=0)
                kv_channels = gr.Number(label="kv channels", value=args.kv_channels, precision=0)
        with gr.Group():
            gr.Markdown("MoE config:")
            with gr.Row():
                num_experts = gr.Number(label="num experts", value=args.num_experts, precision=0)
                moe_layer_freq = gr.Number(label="moe layer freq", value=args.moe_layer_freq, precision=0)
                moe_router_topk = gr.Number(label="moe router topk", value=args.moe_router_topk, precision=0)
                moe_ffn_hidden_size = gr.Number(label="moe ffn hidden size", value=args.moe_ffn_hidden_size, precision=0)
                moe_shared_expert_intermediate_size = gr.Number(label="moe shared expert intermediate size", value=args.moe_shared_expert_intermediate_size, precision=0)
        with gr.Group():
            gr.Markdown("MLA config:")
            with gr.Row():
                multi_latent_attention = gr.Dropdown(["True", "False"],
                                                     label="multi_latent_attention",
                                                     value="True" if args.multi_latent_attention else "False")
                q_lora_rank = gr.Number(label="q lora rank", value=args.q_lora_rank, precision=0)
                kv_lora_rank = gr.Number(label="kv lora rank", value=args.kv_lora_rank, precision=0)
                qk_head_dim = gr.Number(label="qk head dim", value=args.qk_head_dim, precision=0)
                v_head_dim = gr.Number(label="v head dim", value=args.v_head_dim, precision=0)
                qk_pos_emb_head_dim = gr.Number(label="qk pos emb head dim", value=args.qk_pos_emb_head_dim, precision=0)
        with gr.Group():
            with gr.Row():
                with gr.Group():
                    gr.Markdown("MTP config:")
                    mtp_num_layers = gr.Number(label="mtp num layers", value=args.mtp_num_layers, precision=0)
                with gr.Group():
                    gr.Markdown("Data config:")
                    with gr.Row():
                        seq_length = gr.Number(label="seq length", value=args.seq_length, precision=0)
                        batch_size = gr.Number(label="batch size", value=args.batch_size, precision=0)
                with gr.Group():
                    gr.Markdown("MFU config:")
                    with gr.Row():
                        iter_elapsed_time = gr.Number(label="iter elapsed time", value=args.iter_elapsed_time, precision=0)
                        num_p800_cards = gr.Number(label="num p800 cards", value=args.num_p800_cards, precision=0)
        # Widgets for displaying the computed results.
        with gr.Group():
            gr.Markdown("Compute results:")
            with gr.Row():
                model_flops = gr.Number(label="model flops", precision=0)
                flops_per_token = gr.Number(label="flops per token", precision=0)
                # mfu = gr.Number(label="mfu", precision=0)
                mfu = gr.Textbox(label="MFU P800 bf16")
        # Calculate button.
        btn = gr.Button("Calculate")
        btn.click(fn=calculate_mfu_web,
                  inputs=[is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                          ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                          num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size,
                          moe_shared_expert_intermediate_size, multi_latent_attention, q_lora_rank,
                          kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                          mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards],
                  outputs=[model_flops, flops_per_token, mfu])

    # Launch the Gradio app.
    demo.launch()