Update modeling_minicpm.py
modeling_minicpm.py  CHANGED  (+204 -8)
@@ -48,7 +48,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
-
+try:
+    from .configuration_minicpm import MiniCPMConfig
+except ImportError:
+    from configuration_minicpm import MiniCPMConfig
 import re
 
 try:
@@ -283,9 +286,183 @@ class MiniCPMMLP(nn.Module):
         return down_proj
 
 
+# ============================================================================
+# SparseMixer v2 Routing Implementation
+# Based on https://github.com/fairinternal/SparseMixer
+# ============================================================================
+
+class SparseMixerCore(torch.autograd.Function):
+    """
+    Custom autograd function for SparseMixer v2 core operation.
+    Implements Heun's third-order method for gradient computation.
+    """
+    @staticmethod
+    def forward(
+        ctx,
+        scores: torch.Tensor,
+        multiplier: torch.Tensor,
+        selected_experts: torch.Tensor,
+        masked_gates: torch.Tensor,
+        mask_for_one: torch.Tensor,
+    ):
+        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
+        return multiplier * mask_for_one
+
+    @staticmethod
+    def backward(
+        ctx,
+        grad_at_output: torch.Tensor,
+    ):
+        multiplier, selected_experts, masked_gates = ctx.saved_tensors
+
+        grad_at_output = grad_at_output * multiplier
+
+        grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
+        grad_at_scores_expaned.scatter_add_(
+            dim=-1,
+            index=selected_experts,
+            src=grad_at_output,
+        )
+
+        return (
+            grad_at_scores_expaned,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def select_single_expert(
+    scores: torch.Tensor,
+    masked_scores: torch.Tensor,
+    jitter_eps: float,
+    training: bool,
+):
+    """
+    Select a single expert using SparseMixer v2 logic.
+
+    Args:
+        scores: Original scores (for threshold computation)
+        masked_scores: Scores with already-selected experts masked out
+        jitter_eps: Jitter epsilon for sparsity mask
+        training: Whether in training mode
+
+    Returns:
+        multiplier: Weight for the selected expert
+        selected_expert: Index of selected expert
+    """
+    with torch.no_grad():
+        # Compute mask for sparsity
+        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
+        factor = scores.abs().clamp(min=mask_logits_threshold)
+        mask_logits_threshold = (
+            (mask_logits_threshold - scores) / factor
+        ) > (2 * jitter_eps)
+
+    # Apply mask
+    masked_gates = masked_scores.masked_fill(mask_logits_threshold, float('-inf'))
+
+    if training:
+        # Gumbel sampling for robustness
+        selected_expert = (
+            masked_gates - torch.empty_like(
+                masked_gates,
+                memory_format=torch.legacy_contiguous_format
+            ).exponential_().log()
+        ).max(dim=-1)[1].unsqueeze(-1)
+    else:
+        selected_expert = max_ind
+
+    # Compute scores for gradients
+    masked_gates = torch.softmax(masked_gates, dim=-1)
+
+    # Compute midpoint mask using Heun's third-order method
+    max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
+    mask_for_one = torch.logical_or(
+        selected_expert == max_ind,
+        torch.rand_like(max_scores) > 0.75  # f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
+    )
+    # Map: 1 -> 1.0 & 0 -> 1./3
+    mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
+
+    # Get multiplier
+    multiplier_o = masked_gates.gather(dim=-1, index=selected_expert)
+    multiplier = SparseMixerCore.apply(
+        scores,
+        multiplier_o,
+        selected_expert,
+        masked_gates,
+        mask_for_one,
+    )
+
+    return multiplier, selected_expert
+
+
+def sparsemixer_topk_routing(
+    scores: torch.Tensor,
+    top_k: int,
+    jitter_eps: float,
+    training: bool
+):
+    """
+    SparseMixer v2 routing extended to arbitrary top-k.
+
+    Args:
+        scores: Router logits of shape (batch_size, num_experts)
+        top_k: Number of experts to select
+        jitter_eps: Jitter epsilon for sparsity control
+        training: Whether in training mode
+
+    Returns:
+        multiplier: Weights for selected experts, shape (batch_size, top_k)
+        original_gates: Original softmax gates, shape (batch_size, num_experts)
+        selected_experts: Indices of selected experts, shape (batch_size, top_k)
+    """
+    original_gates = torch.softmax(scores, dim=-1)
+
+    all_multipliers = []
+    all_selected_experts = []
+
+    # Start with unmasked scores
+    masked_scores = scores.clone()
+
+    for k in range(top_k):
+        # Select k-th expert
+        multiplier, selected_expert = select_single_expert(
+            scores=scores,
+            masked_scores=masked_scores,
+            jitter_eps=jitter_eps,
+            training=training,
+        )
+
+        all_multipliers.append(multiplier)
+        all_selected_experts.append(selected_expert)
+
+        # Mask out the selected expert for next iteration
+        if k < top_k - 1:  # Don't need to mask on last iteration
+            masked_scores = torch.scatter(
+                masked_scores,
+                -1,
+                selected_expert,
+                float('-inf'),
+            )
+
+    # Concatenate all results
+    multiplier = torch.cat(all_multipliers, dim=-1)  # (batch_size, top_k)
+    selected_experts = torch.cat(all_selected_experts, dim=-1)  # (batch_size, top_k)
+
+    return multiplier, original_gates, selected_experts
+
+
+# ============================================================================
+# End of SparseMixer v2 Implementation
+# ============================================================================
+
+
 class AddAuxiliaryLoss(torch.autograd.Function):
     """
-    The trick function of adding auxiliary (aux) loss,
+    The trick function of adding auxiliary (aux) loss,
     which includes the gradient of the aux loss during backpropagation.
     """
     @staticmethod
@@ -304,7 +481,7 @@ class AddAuxiliaryLoss(torch.autograd.Function):
 
 
 class MiniCPMMoE(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, jitter_eps=0.1):
         super().__init__()
         self.config = config
         self.num_experts = config.num_experts
@@ -314,16 +491,34 @@ class MiniCPMMoE(nn.Module):
         )
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
         self.intermediate_size = config.intermediate_size
-
+
+        # SparseMixer specific parameter
+        self.jitter_eps = jitter_eps
+
     def forward(self, hidden_states):
         orig_shape = hidden_states.shape
         orig_dtype = hidden_states.dtype
         hidden_states = hidden_states.view(-1, orig_shape[-1])
         token_num = hidden_states.shape[0]
+
+        # Compute router logits
         scores = self.gate(hidden_states)
-
-
-
+
+        # ===== SparseMixer v2 Routing =====
+        # Use SparseMixer v2 routing for expert selection
+        expert_weights, scores_prob, expert_indices = sparsemixer_topk_routing(
+            scores=scores,
+            top_k=self.num_experts_per_tok,
+            jitter_eps=self.jitter_eps,
+            training=self.training
+        )
+        # expert_weights: (token_num, top_k) - SparseMixer weights
+        # scores_prob: (token_num, num_experts) - Original softmax for loss computation
+        # expert_indices: (token_num, top_k) - Selected expert indices
+
+        # Normalize weights if needed (SparseMixer already provides normalized weights)
+        # expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
+
         topk_idx_flat = expert_indices.view(-1)
         expert_weights = expert_weights.to(orig_dtype)
 
@@ -333,8 +528,9 @@ class MiniCPMMoE(nn.Module):
         for i in range(self.num_experts):
             y[topk_idx_flat == i] = self.experts[i](hidden_states[topk_idx_flat == i])
         y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
-        y =
+        y = y.view(*orig_shape)
 
+        # Load balancing loss (using original softmax probabilities)
         load = expert_indices.view(-1).bincount(minlength=self.num_experts)
         load_mean = load / (token_num * self.num_experts_per_tok)
         importance_mean = scores_prob.mean(dim=0)