autoprogrammer committed (verified)
Commit f7b44bd · Parent: 84b28e1

Update modeling_minicpm.py

Files changed (1):
  1. modeling_minicpm.py +204 -8
modeling_minicpm.py CHANGED
@@ -48,7 +48,10 @@ from transformers.utils import (
    replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
-from .configuration_minicpm import MiniCPMConfig
+try:
+    from .configuration_minicpm import MiniCPMConfig
+except ImportError:
+    from configuration_minicpm import MiniCPMConfig
import re

try:
@@ -283,9 +286,183 @@ class MiniCPMMLP(nn.Module):
        return down_proj


+# ============================================================================
+# SparseMixer v2 Routing Implementation
+# Based on https://github.com/fairinternal/SparseMixer
+# ============================================================================
+
+class SparseMixerCore(torch.autograd.Function):
+    """
+    Custom autograd function for SparseMixer v2 core operation.
+    Implements Heun's third-order method for gradient computation.
+    """
+    @staticmethod
+    def forward(
+        ctx,
+        scores: torch.Tensor,
+        multiplier: torch.Tensor,
+        selected_experts: torch.Tensor,
+        masked_gates: torch.Tensor,
+        mask_for_one: torch.Tensor,
+    ):
+        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
+        return multiplier * mask_for_one
+
+    @staticmethod
+    def backward(
+        ctx,
+        grad_at_output: torch.Tensor,
+    ):
+        multiplier, selected_experts, masked_gates = ctx.saved_tensors
+
+        grad_at_output = grad_at_output * multiplier
+
+        grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
+        grad_at_scores_expaned.scatter_add_(
+            dim=-1,
+            index=selected_experts,
+            src=grad_at_output,
+        )
+
+        return (
+            grad_at_scores_expaned,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def select_single_expert(
+    scores: torch.Tensor,
+    masked_scores: torch.Tensor,
+    jitter_eps: float,
+    training: bool,
+):
+    """
+    Select a single expert using SparseMixer v2 logic.
+
+    Args:
+        scores: Original scores (for threshold computation)
+        masked_scores: Scores with already-selected experts masked out
+        jitter_eps: Jitter epsilon for sparsity mask
+        training: Whether in training mode
+
+    Returns:
+        multiplier: Weight for the selected expert
+        selected_expert: Index of selected expert
+    """
+    with torch.no_grad():
+        # Compute mask for sparsity
+        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
+        factor = scores.abs().clamp(min=mask_logits_threshold)
+        mask_logits_threshold = (
+            (mask_logits_threshold - scores) / factor
+        ) > (2 * jitter_eps)
+
+    # Apply mask
+    masked_gates = masked_scores.masked_fill(mask_logits_threshold, float('-inf'))
+
+    if training:
+        # Gumbel sampling for robustness
+        selected_expert = (
+            masked_gates - torch.empty_like(
+                masked_gates,
+                memory_format=torch.legacy_contiguous_format
+            ).exponential_().log()
+        ).max(dim=-1)[1].unsqueeze(-1)
+    else:
+        selected_expert = max_ind
+
+    # Compute scores for gradients
+    masked_gates = torch.softmax(masked_gates, dim=-1)
+
+    # Compute midpoint mask using Heun's third-order method
+    max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
+    mask_for_one = torch.logical_or(
+        selected_expert == max_ind,
+        torch.rand_like(max_scores) > 0.75  # f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
+    )
+    # Map: 1 -> 1.0 & 0 -> 1./3
+    mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
+
+    # Get multiplier
+    multiplier_o = masked_gates.gather(dim=-1, index=selected_expert)
+    multiplier = SparseMixerCore.apply(
+        scores,
+        multiplier_o,
+        selected_expert,
+        masked_gates,
+        mask_for_one,
+    )
+
+    return multiplier, selected_expert
+
+
+def sparsemixer_topk_routing(
+    scores: torch.Tensor,
+    top_k: int,
+    jitter_eps: float,
+    training: bool
+):
+    """
+    SparseMixer v2 routing extended to arbitrary top-k.
+
+    Args:
+        scores: Router logits of shape (batch_size, num_experts)
+        top_k: Number of experts to select
+        jitter_eps: Jitter epsilon for sparsity control
+        training: Whether in training mode
+
+    Returns:
+        multiplier: Weights for selected experts, shape (batch_size, top_k)
+        original_gates: Original softmax gates, shape (batch_size, num_experts)
+        selected_experts: Indices of selected experts, shape (batch_size, top_k)
+    """
+    original_gates = torch.softmax(scores, dim=-1)
+
+    all_multipliers = []
+    all_selected_experts = []
+
+    # Start with unmasked scores
+    masked_scores = scores.clone()
+
+    for k in range(top_k):
+        # Select k-th expert
+        multiplier, selected_expert = select_single_expert(
+            scores=scores,
+            masked_scores=masked_scores,
+            jitter_eps=jitter_eps,
+            training=training,
+        )
+
+        all_multipliers.append(multiplier)
+        all_selected_experts.append(selected_expert)
+
+        # Mask out the selected expert for next iteration
+        if k < top_k - 1:  # Don't need to mask on last iteration
+            masked_scores = torch.scatter(
+                masked_scores,
+                -1,
+                selected_expert,
+                float('-inf'),
+            )
+
+    # Concatenate all results
+    multiplier = torch.cat(all_multipliers, dim=-1)  # (batch_size, top_k)
+    selected_experts = torch.cat(all_selected_experts, dim=-1)  # (batch_size, top_k)
+
+    return multiplier, original_gates, selected_experts
+
+
+# ============================================================================
+# End of SparseMixer v2 Implementation
+# ============================================================================
+
+
class AddAuxiliaryLoss(torch.autograd.Function):
    """
-    The trick function of adding auxiliary (aux) loss,
+    The trick function of adding auxiliary (aux) loss,
    which includes the gradient of the aux loss during backpropagation.
    """
    @staticmethod
@@ -304,7 +481,7 @@ class AddAuxiliaryLoss(torch.autograd.Function):


class MiniCPMMoE(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, jitter_eps=0.1):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
@@ -314,16 +491,34 @@ class MiniCPMMoE(nn.Module):
        )
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.intermediate_size = config.intermediate_size
-
+
+        # SparseMixer specific parameter
+        self.jitter_eps = jitter_eps
+
    def forward(self, hidden_states):
        orig_shape = hidden_states.shape
        orig_dtype = hidden_states.dtype
        hidden_states = hidden_states.view(-1, orig_shape[-1])
        token_num = hidden_states.shape[0]
+
+        # Compute router logits
        scores = self.gate(hidden_states)
-        scores_prob = F.softmax(scores, dim=-1, dtype=torch.float32)
-        expert_weights, expert_indices = torch.topk(scores_prob, self.num_experts_per_tok, dim=-1)
-        expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
+
+        # ===== SparseMixer v2 Routing =====
+        # Use SparseMixer v2 routing for expert selection
+        expert_weights, scores_prob, expert_indices = sparsemixer_topk_routing(
+            scores=scores,
+            top_k=self.num_experts_per_tok,
+            jitter_eps=self.jitter_eps,
+            training=self.training
+        )
+        # expert_weights: (token_num, top_k) - SparseMixer weights
+        # scores_prob: (token_num, num_experts) - Original softmax for loss computation
+        # expert_indices: (token_num, top_k) - Selected expert indices
+
+        # Normalize weights if needed (SparseMixer already provides normalized weights)
+        # expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
+
        topk_idx_flat = expert_indices.view(-1)
        expert_weights = expert_weights.to(orig_dtype)

@@ -333,8 +528,9 @@ class MiniCPMMoE(nn.Module):
        for i in range(self.num_experts):
            y[topk_idx_flat == i] = self.experts[i](hidden_states[topk_idx_flat == i])
        y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
-        y = y.view(*orig_shape)
+        y = y.view(*orig_shape)

+        # Load balancing loss (using original softmax probabilities)
        load = expert_indices.view(-1).bincount(minlength=self.num_experts)
        load_mean = load / (token_num * self.num_experts_per_tok)
        importance_mean = scores_prob.mean(dim=0)
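
A minimal usage sketch, separate from the commit itself: the new routing entry point is sparsemixer_topk_routing, which replaces the old softmax + torch.topk selection in MiniCPMMoE.forward. The line torch.add(0.3333, mask_for_one, alpha=0.6667) maps the boolean midpoint mask to 1.0 when the sampled expert is the argmax (or the 25% random branch fires) and to roughly 1/3 otherwise (0.3333 + 0.6667·1 = 1.0, 0.3333 + 0.6667·0 ≈ 1/3), which is the step weighting behind the Heun's-method comment. The snippet below only checks the returned shapes and the gradient path through SparseMixerCore; the tensor sizes are made up, and it assumes modeling_minicpm.py and configuration_minicpm.py sit on the import path so the new ImportError fallback resolves.

import torch

# Assumption: modeling_minicpm.py is importable as a plain module, relying on the
# "except ImportError" fallback for configuration_minicpm added in this commit.
from modeling_minicpm import sparsemixer_topk_routing

token_num, num_experts, top_k = 4, 8, 2  # dummy sizes, illustrative only
router_logits = torch.randn(token_num, num_experts, requires_grad=True)

weights, gates, experts = sparsemixer_topk_routing(
    scores=router_logits,
    top_k=top_k,
    jitter_eps=0.1,   # matches the MiniCPMMoE default introduced in this commit
    training=True,    # enables the Gumbel-perturbed expert sampling branch
)

print(weights.shape)   # torch.Size([4, 2])  per-token multipliers for the chosen experts
print(gates.shape)     # torch.Size([4, 8])  plain softmax gates, kept for the aux loss
print(experts.shape)   # torch.Size([4, 2])  indices of the selected experts

# The multipliers stay differentiable w.r.t. the router logits through
# SparseMixerCore's hand-written backward, so a loss built on them trains the gate.
weights.sum().backward()
print(router_logits.grad.shape)  # torch.Size([4, 8])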