ffurfaro committed on
Commit
f4eeddd
·
verified ·
1 Parent(s): 8d78a17

Upload folder using huggingface_hub

README.md CHANGED
@@ -6,36 +6,60 @@ tags:
6
  - tptt
7
  - peft
8
  - trust_remote_code
 
9
  base_model: apple/OpenELM-1_1B
10
  datasets:
11
  - yahma/alpaca-cleaned
12
  ---
13
 
14
- # Titans-OpenELM-1_1B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  Titanesque version of `apple/OpenELM-1_1B` with parallel linearized attention (TPTT 😊) and PEFT.
17
 
18
- The model was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
19
 
20
- ## Model Details
21
 
22
- - **Architecture:** TpttModel
23
- - **Base model:** apple/OpenELM-1_1B
24
- - **LiZA config:** operator=delta_rule, mag=0.5
25
- - **LoRA config:** r=8, alpha=16, dropout=0.05
26
- - **torch_dtype:** bfloat16
27
 
28
- ## Usage
 
 
 
 
 
 
 
 
29
 
 
30
 
31
  ```python
32
  from transformers import AutoModelForCausalLM, AutoTokenizer
33
 
34
  model = AutoModelForCausalLM.from_pretrained(
35
- "ffurfaro/Titans-OpenELM-1_1B",
 
36
  trust_remote_code=True
37
  )
38
- tokenizer = AutoTokenizer.from_pretrained("ffurfaro/Titans-OpenELM-1_1B")
39
 
40
  prompt = "Your prompt here"
41
  inputs = tokenizer(prompt, return_tensors="pt")
@@ -44,26 +68,6 @@ print(tokenizer.decode(outputs, skip_special_tokens=True))
44
 
45
  ```
46
 
47
- ## Training
48
-
49
- - **Dataset:** yahma/alpaca-cleaned
50
- - **Platform:** Kaggle
51
- - **Hardware:** NVIDIA 2xT4
52
- - **Batch size:** 3
53
- - **Epochs:** 5.0
54
- - **Learning rate (final):** 1.1904761904761906e-06
55
- - **Loss (final):** 1.3188
56
- - **Training runtime:** 1651.0658 sec
57
- - **Samples per second:** 1.514
58
- - **Steps per second:** 0.254
59
- - **Total FLOPs:** 5852956262400000.0
60
- - **Gradient norm (final):** 0.7039350271224976
61
-
62
- ## Evaluation
63
-
64
- - **Metrics:** Training loss only (no eval yet, table soon : PiQA, ARC, Hella, Wino, GSM8K, MMLU)
65
- - **Results:** Final training loss: 1.3188
66
-
67
 
68
  ## Citation & Contact
69
 
 
6
  - tptt
7
  - peft
8
  - trust_remote_code
9
+ pipeline_tag: text-generation
10
  base_model: apple/OpenELM-1_1B
11
  datasets:
12
  - yahma/alpaca-cleaned
13
  ---
14
 
15
+ # Titanesque-OpenELM-1_1B
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
 
35
  Titanesque version of `apple/OpenELM-1_1B` with parallel linearized attention (TPTT 😊) and PEFT.
36
 
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
 
 
39
 
40
+ ## Model list
 
 
 
 
41
 
42
+ Classic model parameters with LiZA injection:
43
+
44
+ | Subfolder | Max Self Attn Length | Mag Weight | Cross Gate | Max Chunk Size | Bidirectional | LoRA | Description |
45
+ |-------------------------------|----------------------|------------|------------|----------------|---------------|------|-------------------------------------------------------|
46
+ | delta_rule | 8192 (default) | 0.5 | False | 64 | False | Yes | Parallel linearized attention with delta_rule operator|
47
+ | delta_rule_gelu | 8192 (default) | 0.5 | False | 64 | False | Yes | Non-linear operator with gelu activation |
48
+ | delta_product | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with derivative trick |
49
+ | delta_product_r | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with rotative trick |
50
+ | delta_product_c | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with combined trick |
51
 
52
+ ## Usage
53
 
54
  ```python
55
  from transformers import AutoModelForCausalLM, AutoTokenizer
56
 
57
  model = AutoModelForCausalLM.from_pretrained(
58
+ "ffurfaro/Titanesque-OpenELM-1_1B",
59
+ subfolder="tptt_subfolder", # placeholder: pick a subfolder from the repo tree
60
  trust_remote_code=True
61
  )
62
+ tokenizer = AutoTokenizer.from_pretrained("ffurfaro/apple/OpenELM-1_1B")
63
 
64
  prompt = "Your prompt here"
65
  inputs = tokenizer(prompt, return_tensors="pt")
 
68
 
69
  ```
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  ## Citation & Contact
73
 
__init__.py CHANGED
@@ -1,19 +1,26 @@
1
- __version__ = "0.1.0"
 
 
2
 
3
- from .configuration_tptt import TpttConfig, generate_model_card
4
- from .modeling_tptt import (AttentionOperator, LCache, LiZAttention, TpttModel,
5
- get_tptt_model)
6
- from .pipeline_tptt import TpttPipeline
7
- from .train_tptt import AdjustMaGWeightCallback
 
8
 
9
  __all__ = [
10
  "TpttConfig",
11
  "TpttModel",
12
- "TpttPipeline",
13
  "get_tptt_model",
14
- "AdjustMaGWeightCallback",
 
15
  "LCache",
16
- "AttentionOperator",
17
  "LiZAttention",
18
  "generate_model_card",
 
 
 
 
19
  ]
 
1
+ """
2
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
3
+ """
4
 
5
+ from .configuration_tptt import (TpttConfig, generate_model_card,
6
+ parse_mode_name)
7
+ from .modeling_tptt import (LCache, LinearAttention, LinearAttentionOp,
8
+ LiZAttention, TpttModel, get_tptt_model,
9
+ load_tptt_safetensors, save_tptt_safetensors)
10
+ from .train_tptt import LiZACallback, SaveBestModelCallback
11
 
12
  __all__ = [
13
  "TpttConfig",
14
  "TpttModel",
 
15
  "get_tptt_model",
16
+ "LiZACallback",
17
+ "SaveBestModelCallback",
18
  "LCache",
19
+ "LinearAttentionOp",
20
  "LiZAttention",
21
  "generate_model_card",
22
+ "LinearAttention",
23
+ "parse_mode_name",
24
+ "load_tptt_safetensors",
25
+ "save_tptt_safetensors",
26
  ]
configuration_tptt.py CHANGED
@@ -1,23 +1,33 @@
 
1
  """
2
  Author : Fabien FURFARO
3
  """
4
 
 
5
  import os
6
  import re
7
- from typing import List, Optional, Union
 
8
 
 
 
9
  from transformers import AutoConfig, PretrainedConfig
10
 
 
 
 
 
 
11
 
12
  def convert_sets_to_lists(obj):
 
13
  if isinstance(obj, set):
14
  return list(obj)
15
- elif isinstance(obj, dict):
16
  return {k: convert_sets_to_lists(v) for k, v in obj.items()}
17
- elif isinstance(obj, (list, tuple)):
18
  return [convert_sets_to_lists(x) for x in obj]
19
- else:
20
- return obj
21
 
22
 
23
  class TpttConfig(PretrainedConfig):
@@ -33,17 +43,73 @@ class TpttConfig(PretrainedConfig):
33
  }
34
  architectures = ["TpttModel"]
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def __init__(
37
  self,
38
  base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
39
  base_model_name: str = "meta-llama/Llama-3.2-1B",
 
40
  name_or_path: Optional[str] = None,
 
41
  target_modules_names: Optional[List[str]] = None,
42
  operator_mode: str = "delta_rule",
43
- max_self_attn_length: int = 8192,
 
 
 
 
44
  mag_weight: float = 0.5, # if 1.0, use only linear operator
45
- max_chunk_size: int = 64,
 
 
46
  lora_config: Optional[dict] = None, # only serialized accepted
 
 
 
47
  **kwargs,
48
  ):
49
  # If base_model_config is provided, load it and merge with this config
@@ -60,11 +126,16 @@ class TpttConfig(PretrainedConfig):
60
  setattr(self, k, v)
61
 
62
  self.base_model_name = base_model_name
63
- self._name_or_path = (
64
- name_or_path
65
- if name_or_path is not None
66
- else "Titans-" + base_model_name.split("/", 1)[1]
67
- )
 
 
 
 
 
68
 
69
  self.target_modules_names = target_modules_names or [
70
  "attn",
@@ -72,9 +143,28 @@ class TpttConfig(PretrainedConfig):
72
  "attention",
73
  ]
74
  self.operator_mode = operator_mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  self.mag_weight = mag_weight
 
76
  self.max_chunk_size = max_chunk_size
77
  self.max_self_attn_length = max_self_attn_length
 
 
 
78
 
79
  self.lora_config = lora_config
80
  if lora_config is not None:
@@ -82,56 +172,147 @@ class TpttConfig(PretrainedConfig):
82
  self.lora_config["peft_type"] = self.lora_config["peft_type"].value
83
  self.lora_config = convert_sets_to_lists(self.lora_config)
84
 
 
 
 
 
 
 
85
  super().__init__(**kwargs) # flush inconsistent pretrained parameters (?)
86
  # Copy class attributes to instance for serialization (save dict)
87
  self.model_type = self.__class__.model_type
88
  self.auto_map = self.__class__.auto_map
89
  self.architectures = self.__class__.architectures
 
 
 
 
 
 
 
 
 
 
90
 
91
 
92
  TpttConfig.register_for_auto_class()
93
 
94
 
95
- def extract_template_variables(template):
96
- return set(re.findall(r"\{([^{}]+)\}", template))
 
 
 
 
 
 
 
 
97
 
 
 
 
 
 
98
 
99
- def generate_model_card(path: str, config, **kwargs):
100
- """Generate model card from template and training metadata."""
101
- template_path = os.path.join(os.path.dirname(__file__), "model_card_template.md")
102
- with open(template_path, "r", encoding="utf-8") as f:
103
- template = f.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # Flatten config
106
- def flatten_config(config):
107
- result = {}
108
- if hasattr(config, "__dict__"):
109
- config = config.__dict__
110
- for k, v in config.items():
111
- if isinstance(v, dict):
112
- for subk, subv in v.items():
113
- result[f"{k}_{subk}"] = subv
114
- else:
115
- result[k] = v
116
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- variables = flatten_config(config)
119
- variables.update(kwargs)
120
- variables["model_id"] = os.path.basename(path)
121
 
122
- # Extract variables from template
123
- template_vars = extract_template_variables(template)
 
 
 
124
 
125
- # Add default values for missing variables
126
- for var in template_vars:
127
- if var not in variables:
128
- variables[var] = "N/A"
129
 
130
- # Handle list conversion (optional but useful)
131
- for k, v in variables.items():
132
- if isinstance(v, list):
133
- variables[k] = ", ".join(map(str, v))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- model_card_content = template.format(**variables)
136
- with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
137
- f.write(model_card_content)
 
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
  """
3
  Author : Fabien FURFARO
4
  """
5
 
6
+ import logging
7
  import os
8
  import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
 
12
+ import psutil
13
+ import torch
14
  from transformers import AutoConfig, PretrainedConfig
15
 
16
+ logger = logging.getLogger(__name__) # monitoring
17
+
18
+ # Constants
19
+ BYTES_IN_GB = 1024**3
20
+
21
 
22
  def convert_sets_to_lists(obj):
23
+ """Convert sets to list for LoRA serialized config"""
24
  if isinstance(obj, set):
25
  return list(obj)
26
+ if isinstance(obj, dict):
27
  return {k: convert_sets_to_lists(v) for k, v in obj.items()}
28
+ if isinstance(obj, (list, tuple)):
29
  return [convert_sets_to_lists(x) for x in obj]
30
+ return obj
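The helper above is needed because a LoRA config can contain Python sets, which `json` refuses to serialize. The sketch below reuses the function as it appears in the diff. The sample `lora_config` dict is hypothetical and only illustrates the shape of the data.

```python
import json


def convert_sets_to_lists(obj):
    """Recursively replace sets with lists so the result is JSON-serializable."""
    if isinstance(obj, set):
        return list(obj)
    if isinstance(obj, dict):
        return {k: convert_sets_to_lists(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_sets_to_lists(x) for x in obj]
    return obj


# Hypothetical LoRA-style config that holds a set of target modules.
lora_config = {"r": 8, "lora_alpha": 16, "target_modules": {"qkv_proj", "o_proj"}}

try:
    json.dumps(lora_config)
    raw_serializable = True
except TypeError:
    raw_serializable = False  # json cannot encode a set

clean_config = convert_sets_to_lists(lora_config)
serialized = json.dumps(clean_config)  # succeeds once sets are lists
```

Set iteration order is arbitrary, so the resulting list order is not guaranteed. Sort it if a stable `config.json` matters.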
 
31
 
32
 
33
  class TpttConfig(PretrainedConfig):
 
43
  }
44
  architectures = ["TpttModel"]
45
 
46
+ RECURRENT_MODES = {
47
+ "delta_rule": {
48
+ "order": 1,
49
+ "gate_type": "k",
50
+ "linear": True,
51
+ "trick": "derivative",
52
+ },
53
+ "delta_rule_v": {
54
+ "order": 1,
55
+ "gate_type": "v",
56
+ "linear": True,
57
+ "trick": "derivative",
58
+ },
59
+ "delta_rule_kv": {
60
+ "order": 1,
61
+ "gate_type": "kv",
62
+ "linear": True,
63
+ "trick": "derivative",
64
+ },
65
+ "delta_rule_gelu": {
66
+ "order": 1,
67
+ "gate_type": "k",
68
+ "linear": False,
69
+ "trick": "derivative",
70
+ },
71
+ "delta_product": {
72
+ "order": 2,
73
+ "gate_type": "k",
74
+ "linear": True,
75
+ "trick": "derivative",
76
+ },
77
+ "delta_product_r": {
78
+ "order": 2,
79
+ "gate_type": "k",
80
+ "linear": True,
81
+ "trick": "rotative",
82
+ },
83
+ "delta_product_c": {
84
+ "order": 2,
85
+ "gate_type": "k",
86
+ "linear": True,
87
+ "trick": "combined",
88
+ },
89
+ } # Tested modes, see parse_mode_name if you want to add more
90
+
91
  def __init__(
92
  self,
93
  base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
94
  base_model_name: str = "meta-llama/Llama-3.2-1B",
95
+ base_model_subfolder: Optional[str] = None,
96
  name_or_path: Optional[str] = None,
97
+ model_task: str = "causal_lm",
98
  target_modules_names: Optional[List[str]] = None,
99
  operator_mode: str = "delta_rule",
100
+ use_linear_checkpoint: Optional[bool] = None,
101
+ max_self_attn_length: Optional[
102
+ int
103
+ ] = None, # unnecessary with sliding-window attention (SWA); otherwise typically 8192
104
+ base_scale_attn: bool = False,
105
  mag_weight: float = 0.5, # if 1.0, use only linear operator
106
+ cross_gate: bool = False, # non-linear mixing strategy
107
+ max_chunk_size: int = 64, # 128 if adaptive chunking (longest)
108
+ linear_precision: Union[str, torch.dtype] = "float32",
109
  lora_config: Optional[dict] = None, # only serialized accepted
110
+ padding_side: Optional[str] = None, # for tokenizer, default "right"
111
+ bidirectional: bool = False, # if True, use bidirectional attention
112
+ pooling_config: Optional[Dict[str, Any]] = None,
113
  **kwargs,
114
  ):
115
  # If base_model_config is provided, load it and merge with this config
 
126
  setattr(self, k, v)
127
 
128
  self.base_model_name = base_model_name
129
+ self.base_model_subfolder = base_model_subfolder
130
+ self.model_task = model_task
131
+
132
+ if name_or_path is not None:
133
+ self._name_or_path = name_or_path
134
+ else:
135
+ if "/" in base_model_name:
136
+ self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
137
+ else:
138
+ self._name_or_path = "Titans-" + base_model_name
139
 
140
  self.target_modules_names = target_modules_names or [
141
  "attn",
 
143
  "attention",
144
  ]
145
  self.operator_mode = operator_mode
146
+
147
+ # Detect available memory on accelerator device
148
+ if torch.cuda.is_available():
149
+ _, total_mem = torch.cuda.mem_get_info()
150
+ else:
151
+ total_mem = psutil.virtual_memory().total
152
+ total_mem_gb = total_mem / BYTES_IN_GB
153
+
154
+ self.use_linear_checkpoint = (
155
+ total_mem_gb < 16
156
+ if use_linear_checkpoint is None
157
+ else use_linear_checkpoint
158
+ )
159
+
160
+ self.base_scale_attn = base_scale_attn
161
  self.mag_weight = mag_weight
162
+ self.cross_gate = cross_gate
163
  self.max_chunk_size = max_chunk_size
164
  self.max_self_attn_length = max_self_attn_length
165
+ if isinstance(linear_precision, torch.dtype):
166
+ linear_precision = str(linear_precision).replace("torch.", "")
167
+ self.linear_precision = linear_precision
168
 
169
  self.lora_config = lora_config
170
  if lora_config is not None:
 
172
  self.lora_config["peft_type"] = self.lora_config["peft_type"].value
173
  self.lora_config = convert_sets_to_lists(self.lora_config)
174
 
175
+ self.padding_side = padding_side
176
+ self.bidirectional = bidirectional
177
+ if self.bidirectional:
178
+ print("Bidirectional attention is enabled; inputs must be non-causal and unpadded.")
179
+ self.pooling_config = pooling_config
180
+
181
  super().__init__(**kwargs) # flush inconsistent pretrained parameters (?)
182
  # Copy class attributes to instance for serialization (save dict)
183
  self.model_type = self.__class__.model_type
184
  self.auto_map = self.__class__.auto_map
185
  self.architectures = self.__class__.architectures
186
+ # Padding side configuration if not set
187
+ if self.padding_side is None:
188
+ self.padding_side = "right"
189
+ logger.info("Warning: padding_side is None, defaulting to 'right'.")
190
+ # set recurrent configuration from operator mode
191
+ if operator_mode not in self.__class__.RECURRENT_MODES:
192
+ self.recurrent_config = parse_mode_name(operator_mode)
193
+ else:
194
+ self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
195
+ logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
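The constructor above enables `use_linear_checkpoint` by default when the detected memory is below 16 GB, unless the caller passes an explicit value. The sketch below isolates that rule as a pure function. The name `resolve_linear_checkpoint` and the `threshold_gb` parameter are hypothetical. Memory is passed in explicitly rather than queried from `torch` or `psutil`, so the snippet runs without either library.

```python
from typing import Optional

BYTES_IN_GB = 1024**3


def resolve_linear_checkpoint(
    total_mem_bytes: int,
    override: Optional[bool] = None,
    threshold_gb: float = 16,
) -> bool:
    """Return the explicit override if given, else True when memory is below the threshold."""
    if override is not None:
        return override
    return total_mem_bytes / BYTES_IN_GB < threshold_gb


# Illustrative memory sizes (assumed values, not measurements).
small_device = resolve_linear_checkpoint(15 * BYTES_IN_GB)
large_device = resolve_linear_checkpoint(24 * BYTES_IN_GB)
forced_off = resolve_linear_checkpoint(8 * BYTES_IN_GB, override=False)
```

The comparison is strict, so a device reporting exactly 16 GB keeps checkpointing off unless the caller overrides it.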
196
 
197
 
198
  TpttConfig.register_for_auto_class()
199
 
200
 
201
+ def parse_mode_name(name: str) -> dict:
202
+ """Parse mode to recurrent config"""
203
+ if name.startswith("delta_product"):
204
+ parts = name.split("_")
205
+ # Prefix is always two words: 'delta' and 'product'
206
+ base_len = 2
207
+ order = 2
208
+ gate_type = "k"
209
+ linear = True
210
+ trick = "derivative"
211
 
212
+ idx = base_len
213
+ # Check for order (immediately after the prefix)
214
+ if len(parts) > idx and parts[idx].isdigit():
215
+ order = int(parts[idx])
216
+ idx += 1
217
 
218
+ remaining = parts[idx:]
219
+ # Trick (r/c) is always at the far right if present
220
+ if remaining and remaining[-1] in ("r", "c"):
221
+ trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
222
+ remaining = remaining[:-1]
223
+ # 'gelu' comes just before the trick if present
224
+ if remaining and remaining[-1] == "gelu":
225
+ linear = False
226
+ remaining = remaining[:-1]
227
+ # If anything remains, it's the gate_type
228
+ if remaining:
229
+ gate_type = "_".join(remaining)
230
+ return {
231
+ "order": order,
232
+ "gate_type": gate_type,
233
+ "linear": linear,
234
+ "trick": trick,
235
+ }
236
 
237
+ # delta_rule[_gate][_gelu]
238
+ m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
239
+ if m:
240
+ return {
241
+ "order": 1,
242
+ "gate_type": m.group(1) if m.group(1) else "k",
243
+ "linear": not bool(m.group(2)),
244
+ "trick": "derivative",
245
+ }
246
+ raise ValueError(f"Unknown mode: {name}")
247
+
248
+
249
+ def get_mode_name(
250
+ order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
251
+ ) -> str:
252
+ """Get recurrent mode name from parameter"""
253
+ base = (
254
+ "delta_rule"
255
+ if order == 1
256
+ else ("delta_product" if order == 2 else f"delta_product_{order}")
257
+ )
258
+ parts = []
259
+ if gate_type != "k":
260
+ parts.append(gate_type)
261
+ if not linear:
262
+ parts.append("gelu")
263
+ if order >= 2 and trick != "derivative":
264
+ parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
265
+ return base + (("_" + "_".join(parts)) if parts else "")
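The two helpers above are inverses for the tested mode names: parsing a name and rebuilding it from the resulting config gives back the same string. The sketch below copies both functions from the diff with minor condensing and checks the round trip. The name `delta_product_3_kv_gelu_c` is a made-up example of an untested mode, used only to exercise the parser.

```python
import re


def parse_mode_name(name: str) -> dict:
    """Parse an operator mode name into a recurrent config dict."""
    if name.startswith("delta_product"):
        parts = name.split("_")
        order, gate_type, linear, trick = 2, "k", True, "derivative"
        idx = 2  # skip the 'delta' and 'product' prefix
        if len(parts) > idx and parts[idx].isdigit():
            order = int(parts[idx])
            idx += 1
        remaining = parts[idx:]
        # Trick suffix (r/c) is last, 'gelu' comes just before it.
        if remaining and remaining[-1] in ("r", "c"):
            trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
            remaining = remaining[:-1]
        if remaining and remaining[-1] == "gelu":
            linear = False
            remaining = remaining[:-1]
        if remaining:
            gate_type = "_".join(remaining)
        return {"order": order, "gate_type": gate_type, "linear": linear, "trick": trick}

    m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
    if m:
        return {
            "order": 1,
            "gate_type": m.group(1) if m.group(1) else "k",
            "linear": not bool(m.group(2)),
            "trick": "derivative",
        }
    raise ValueError(f"Unknown mode: {name}")


def get_mode_name(order=1, gate_type="k", linear=True, trick="derivative") -> str:
    """Rebuild the mode name from a recurrent config."""
    if order == 1:
        base = "delta_rule"
    elif order == 2:
        base = "delta_product"
    else:
        base = f"delta_product_{order}"
    parts = []
    if gate_type != "k":
        parts.append(gate_type)
    if not linear:
        parts.append("gelu")
    if order >= 2 and trick != "derivative":
        parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
    return base + (("_" + "_".join(parts)) if parts else "")


names = [
    "delta_rule", "delta_rule_v", "delta_rule_kv", "delta_rule_gelu",
    "delta_product", "delta_product_r", "delta_product_c",
    "delta_product_3_kv_gelu_c",  # hypothetical higher-order mode
]
round_trip = {n: get_mode_name(**parse_mode_name(n)) for n in names}
```

The round trip is not guaranteed for every accepted spelling: `delta_rule_k` parses to the default gate and is rebuilt as `delta_rule`.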
266
 
 
 
 
267
 
268
+ def render_template(template_path: str, variables: dict) -> str:
269
+ """Load and render a Jinja2 template from any file path."""
270
+ env = Environment(loader=FileSystemLoader(os.path.dirname(template_path)))
271
+ template = env.get_template(os.path.basename(template_path))
272
+ return template.render(**variables)
273
 
 
 
 
 
274
 
275
+ def write_model_card(output_path: str, content: str):
276
+ """Write the generated content into README.md."""
277
+ os.makedirs(output_path, exist_ok=True)
278
+ readme_path = os.path.join(output_path, "README.md")
279
+ with open(readme_path, "w", encoding="utf-8") as f:
280
+ f.write(content)
281
+
282
+
283
+ def generate_model_card(
284
+ output_path: str,
285
+ config: Union[dict, object],
286
+ template: Optional[
287
+ str
288
+ ], # can be "model_card" OR an absolute/relative path to a .md file
289
+ extra_variables: Optional[Dict] = None,
290
+ ):
291
+ """
292
+ Generate a README.md file from a Jinja2 template and a configuration.
293
+
294
+ - template can be either:
295
+ * a full path to a template file
296
+ * a short name (e.g., "model_card") -> will be looked up inside default_templates_dir
297
+ """
298
+ if template is None:
299
+ template = "model_card_template" # default template name
300
+ # Locate the template
301
+ if os.path.exists(template): # direct file path provided
302
+ template_path = template
303
+ else:
304
+ default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
305
+ template_path = os.path.join(default_templates_dir, f"{template}.md")
306
+
307
+ if not os.path.exists(template_path):
308
+ raise FileNotFoundError(f"Template not found: {template_path}")
309
+
310
+ variables = {
311
+ "model_id": os.path.basename(output_path),
312
+ "config": config,
313
+ }
314
+ if extra_variables:
315
+ variables.update(extra_variables)
316
 
317
+ content = render_template(template_path, variables)
318
+ write_model_card(output_path, content)
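`generate_model_card` above accepts either a path to a template file or a short template name that is looked up in a default directory. The sketch below isolates that lookup with the standard library only. The function name `resolve_template_path` and the explicit `default_dir` argument are hypothetical, and the Jinja2 rendering step is left out.

```python
import os
import tempfile
from typing import Optional


def resolve_template_path(template: Optional[str], default_dir: str) -> str:
    """Resolve a template argument to an existing file path, or raise FileNotFoundError."""
    if template is None:
        template = "model_card_template"  # default short name, as in the diff
    if os.path.exists(template):
        path = template  # a direct file path was provided
    else:
        path = os.path.join(default_dir, f"{template}.md")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Template not found: {path}")
    return path


with tempfile.TemporaryDirectory() as tmp:
    default_path = os.path.join(tmp, "model_card_template.md")
    with open(default_path, "w", encoding="utf-8") as f:
        f.write("# {{ model_id }}\n")

    by_default = resolve_template_path(None, tmp)         # short-name lookup
    by_path = resolve_template_path(default_path, tmp)    # direct path
    try:
        resolve_template_path("missing", tmp)
        missing_raises = False
    except FileNotFoundError:
        missing_raises = True
```

Because the direct-path check runs first, a file or directory in the current working directory whose name matches the short template name takes precedence over the default templates directory.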
 
lora_delta_product_m0.5_constant/README.md ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ pipeline_tag: text-generation
10
+ base_model: apple/OpenELM-1_1B
11
+ datasets:
12
+ - yahma/alpaca-cleaned
13
+ ---
14
+
15
+ # lora_delta_product_m0.5_constant
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
+
35
+ Titanesque version of `apple/OpenELM-1_1B` with parallel linearized attention (TPTT 😊) and PEFT.
36
+
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
+
39
+
40
+ ## Model Details
41
+
42
+ - **Architecture:** TpttModel
43
+ - **Base model:** apple/OpenELM-1_1B
44
+ - **LiZA config:** operator=delta_product, mag=0.5
45
+ - **LoRA config:** r=8, alpha=16, dropout=0.05
46
+ - **torch_dtype:** bfloat16
47
+
48
+ ## Usage
49
+
50
+
51
+ ```python
52
+ from transformers import AutoModelForCausalLM, AutoTokenizer
53
+
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ "ffurfaro/lora_delta_product_m0.5_constant",
56
+ trust_remote_code=True
57
+ )
58
+ tokenizer = AutoTokenizer.from_pretrained("ffurfaro/apple/OpenELM-1_1B")
59
+
60
+ prompt = "Your prompt here"
61
+ inputs = tokenizer(prompt, return_tensors="pt")
62
+ outputs = model.generate(**inputs, max_new_tokens=100)
63
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
64
+
65
+ ```
66
+
67
+ > [!IMPORTANT]
68
+ > You must specify the `subfolder` if the repo contains multiple models, see the homepage for details.
69
+
70
+ ## Training
71
+
72
+ - **Dataset:** yahma/alpaca-cleaned
73
+ - **Platform:** Kaggle
74
+ - **Hardware:** 2xT4
75
+ - **Batch size:** 2
76
+ - **Epochs:** 1.0
77
+ - **Learning rate (final):** N/A
78
+ - **Loss (final):** 2.0623671754250386
79
+ - **Training runtime:** 2735.583 sec
80
+ - **Samples per second:** 0.946
81
+ - **Steps per second:** 0.237
82
+ - **Total FLOPs:** 2018060372803584.0
83
+ - **Gradient norm (final):** N/A
84
+
85
+ ## Evaluation
86
+
87
+ - **Metrics:** Training loss only (no evaluation yet; a results table is planned for PIQA, ARC, HellaSwag, WinoGrande, GSM8K, and MMLU)
88
+ - **Results:** Final training loss: 2.0623671754250386
89
+
90
+
91
+ ## Citation & Contact
92
+
93
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
94
+
95
+
96
+ ---
lora_delta_product_m0.5_constant/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42b7ba7ce3507b8bcab8d45188a94e47e1f84bbe215e62aac9101f182faea530
3
+ size 3919584
lora_delta_product_m0.5_constant/config.json ADDED
@@ -0,0 +1,182 @@
1
+ {
2
+ "activation_fn_name": "swish",
3
+ "architectures": [
4
+ "TpttModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_tptt.TpttConfig",
8
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel"
9
+ },
10
+ "base_model_name": "apple/OpenELM-1_1B",
11
+ "base_model_subfolder": null,
12
+ "base_scale_attn": null,
13
+ "bidirectional": false,
14
+ "cross_gate": false,
15
+ "ffn_dim_divisor": 256,
16
+ "ffn_multipliers": [
17
+ 0.5,
18
+ 0.63,
19
+ 0.76,
20
+ 0.89,
21
+ 1.02,
22
+ 1.15,
23
+ 1.28,
24
+ 1.41,
25
+ 1.54,
26
+ 1.67,
27
+ 1.8,
28
+ 1.93,
29
+ 2.06,
30
+ 2.19,
31
+ 2.31,
32
+ 2.44,
33
+ 2.57,
34
+ 2.7,
35
+ 2.83,
36
+ 2.96,
37
+ 3.09,
38
+ 3.22,
39
+ 3.35,
40
+ 3.48,
41
+ 3.61,
42
+ 3.74,
43
+ 3.87,
44
+ 4.0
45
+ ],
46
+ "ffn_with_glu": true,
47
+ "head_dim": 64,
48
+ "initializer_range": 0.02,
49
+ "linear_precision": "bfloat16",
50
+ "lora_config": {
51
+ "alpha_pattern": {},
52
+ "auto_mapping": null,
53
+ "base_model_name_or_path": null,
54
+ "bias": "none",
55
+ "corda_config": null,
56
+ "eva_config": null,
57
+ "exclude_modules": null,
58
+ "fan_in_fan_out": false,
59
+ "inference_mode": false,
60
+ "init_lora_weights": true,
61
+ "layer_replication": null,
62
+ "layers_pattern": null,
63
+ "layers_to_transform": null,
64
+ "loftq_config": {},
65
+ "lora_alpha": 16,
66
+ "lora_bias": false,
67
+ "lora_dropout": 0.05,
68
+ "megatron_config": null,
69
+ "megatron_core": "megatron.core",
70
+ "modules_to_save": null,
71
+ "peft_type": "LORA",
72
+ "r": 8,
73
+ "rank_pattern": {},
74
+ "revision": null,
75
+ "target_modules": [
76
+ "o_proj",
77
+ "qkv_proj"
78
+ ],
79
+ "task_type": "CAUSAL_LM",
80
+ "trainable_token_indices": null,
81
+ "use_dora": false,
82
+ "use_rslora": false
83
+ },
84
+ "mag_weight": 0.5,
85
+ "max_chunk_size": 64,
86
+ "max_context_length": 2048,
87
+ "max_self_attn_length": null,
88
+ "model_dim": 2048,
89
+ "model_task": "causal_lm",
90
+ "model_type": "tptt",
91
+ "normalization_layer_name": "rms_norm",
92
+ "normalize_qk_projections": true,
93
+ "num_gqa_groups": 4,
94
+ "num_kv_heads": [
95
+ 4,
96
+ 4,
97
+ 4,
98
+ 5,
99
+ 5,
100
+ 5,
101
+ 5,
102
+ 5,
103
+ 5,
104
+ 5,
105
+ 6,
106
+ 6,
107
+ 6,
108
+ 6,
109
+ 6,
110
+ 6,
111
+ 6,
112
+ 6,
113
+ 7,
114
+ 7,
115
+ 7,
116
+ 7,
117
+ 7,
118
+ 7,
119
+ 8,
120
+ 8,
121
+ 8,
122
+ 8
123
+ ],
124
+ "num_query_heads": [
125
+ 16,
126
+ 16,
127
+ 16,
128
+ 20,
129
+ 20,
130
+ 20,
131
+ 20,
132
+ 20,
133
+ 20,
134
+ 20,
135
+ 24,
136
+ 24,
137
+ 24,
138
+ 24,
139
+ 24,
140
+ 24,
141
+ 24,
142
+ 24,
143
+ 28,
144
+ 28,
145
+ 28,
146
+ 28,
147
+ 28,
148
+ 28,
149
+ 32,
150
+ 32,
151
+ 32,
152
+ 32
153
+ ],
154
+ "num_transformer_layers": 28,
155
+ "operator_mode": "delta_product",
156
+ "padding_side": "right",
157
+ "pooling_config": null,
158
+ "qkv_multipliers": [
159
+ 0.5,
160
+ 1.0
161
+ ],
162
+ "recurrent_config": {
163
+ "gate_type": "k",
164
+ "linear": true,
165
+ "order": 2,
166
+ "trick": "derivative"
167
+ },
168
+ "rope_freq_constant": 10000,
169
+ "rope_max_length": 4096,
170
+ "share_input_output_layers": true,
171
+ "target_modules_names": [
172
+ "attn",
173
+ "self_attn",
174
+ "attention"
175
+ ],
176
+ "torch_dtype": "bfloat16",
177
+ "transformers_version": "4.49.0",
178
+ "trust_remote_code": true,
179
+ "use_cache": true,
180
+ "use_linear_checkpoint": false,
181
+ "vocab_size": 32000
182
+ }
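The per-layer head counts in the `config.json` above can be checked for internal consistency. The sketch below copies the values from the file in compressed form. Reading `num_gqa_groups` as the number of query heads per key/value head is an assumption. It is consistent with the listed values but is not stated in the file.

```python
# Values copied from the config.json above, written in run-length form.
num_kv_heads = [4] * 3 + [5] * 7 + [6] * 8 + [7] * 6 + [8] * 4
num_query_heads = [16] * 3 + [20] * 7 + [24] * 8 + [28] * 6 + [32] * 4
num_gqa_groups = 4
num_transformer_layers = 28
head_dim = 64
model_dim = 2048

# Every layer should have one entry in each list.
layer_counts_match = (
    len(num_kv_heads) == num_transformer_layers
    and len(num_query_heads) == num_transformer_layers
)

# Query heads per key/value head, collected over all layers.
ratios = {q // kv for q, kv in zip(num_query_heads, num_kv_heads)}

# Width of the query projection in the first and last layers.
first_q_width = num_query_heads[0] * head_dim
last_q_width = num_query_heads[-1] * head_dim
```

With these values the ratio is the same in every layer, and only the last layers reach a query width equal to `model_dim`.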
lora_delta_product_m0.5_constant/configuration_tptt.py ADDED
@@ -0,0 +1,318 @@
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+ """
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
+
12
+ import psutil
13
+ import torch
14
+ from transformers import AutoConfig, PretrainedConfig
15
+
16
+ logger = logging.getLogger(__name__) # monitoring
17
+
18
+ # Constants
19
+ BYTES_IN_GB = 1024**3
20
+
21
+
22
+ def convert_sets_to_lists(obj):
23
+ """Convert sets to list for LoRA serialized config"""
24
+ if isinstance(obj, set):
25
+ return list(obj)
26
+ if isinstance(obj, dict):
27
+ return {k: convert_sets_to_lists(v) for k, v in obj.items()}
28
+ if isinstance(obj, (list, tuple)):
29
+ return [convert_sets_to_lists(x) for x in obj]
30
+ return obj
31
+
32
+
33
+ class TpttConfig(PretrainedConfig):
34
+ """
35
+ Configuration class for the TPTT model.
36
+ This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
37
+ """
38
+
39
+ model_type = "tptt"
40
+ auto_map = {
41
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel",
42
+ "AutoConfig": "configuration_tptt.TpttConfig",
43
+ }
44
+ architectures = ["TpttModel"]
45
+
46
+ RECURRENT_MODES = {
47
+ "delta_rule": {
48
+ "order": 1,
49
+ "gate_type": "k",
50
+ "linear": True,
51
+ "trick": "derivative",
52
+ },
53
+ "delta_rule_v": {
54
+ "order": 1,
55
+ "gate_type": "v",
56
+ "linear": True,
57
+ "trick": "derivative",
58
+ },
59
+ "delta_rule_kv": {
60
+ "order": 1,
61
+ "gate_type": "kv",
62
+ "linear": True,
63
+ "trick": "derivative",
64
+ },
65
+ "delta_rule_gelu": {
66
+ "order": 1,
67
+ "gate_type": "k",
68
+ "linear": False,
69
+ "trick": "derivative",
70
+ },
71
+ "delta_product": {
72
+ "order": 2,
73
+ "gate_type": "k",
74
+ "linear": True,
75
+ "trick": "derivative",
76
+ },
77
+ "delta_product_r": {
78
+ "order": 2,
79
+ "gate_type": "k",
80
+ "linear": True,
81
+ "trick": "rotative",
82
+ },
83
+ "delta_product_c": {
84
+ "order": 2,
85
+ "gate_type": "k",
86
+ "linear": True,
87
+ "trick": "combined",
88
+ },
89
+ } # Tested modes, see parse_mode_name if you want to add more
90
+
91
+ def __init__(
92
+ self,
93
+ base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
94
+ base_model_name: str = "meta-llama/Llama-3.2-1B",
95
+ base_model_subfolder: Optional[str] = None,
96
+ name_or_path: Optional[str] = None,
97
+ model_task: str = "causal_lm",
98
+ target_modules_names: Optional[List[str]] = None,
99
+ operator_mode: str = "delta_rule",
100
+ use_linear_checkpoint: Optional[bool] = None,
101
+ max_self_attn_length: Optional[
102
+ int
103
+ ] = None, # unnecessary with sliding-window attention (SWA); otherwise typically 8192
104
+ base_scale_attn: bool = False,
105
+ mag_weight: float = 0.5, # if 1.0, use only linear operator
106
+ cross_gate: bool = False, # non-linear mixing strategy
107
+ max_chunk_size: int = 64, # 128 if adaptive chunking (longest)
108
+ linear_precision: Union[str, torch.dtype] = "float32",
109
+ lora_config: Optional[dict] = None, # only serialized accepted
110
+ padding_side: Optional[str] = None, # for tokenizer, default "right"
111
+ bidirectional: bool = False, # if True, use bidirectional attention
112
+ pooling_config: Optional[Dict[str, Any]] = None,
113
+ **kwargs,
114
+ ):
115
+ # If base_model_config is provided, load it and merge with this config
116
+ if base_model_config is not None:
117
+ if isinstance(base_model_config, PretrainedConfig):
118
+ base_model_config = base_model_config.to_dict()
119
+ else:
120
+ # Load config from Hugging Face Hub or a local path
121
+ base_model_config = AutoConfig.from_pretrained(
122
+ base_model_name, **kwargs
123
+ ).to_dict()
124
+ # Merge all backbone fields into this config
125
+ for k, v in base_model_config.items():
126
+ setattr(self, k, v)
127
+
128
+ self.base_model_name = base_model_name
129
+ self.base_model_subfolder = base_model_subfolder
130
+ self.model_task = model_task
131
+
132
+ if name_or_path is not None:
133
+ self._name_or_path = name_or_path
134
+ else:
135
+ if "/" in base_model_name:
136
+ self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
137
+ else:
138
+ self._name_or_path = "Titans-" + base_model_name
139
+
140
+ self.target_modules_names = target_modules_names or [
141
+ "attn",
142
+ "self_attn",
143
+ "attention",
144
+ ]
145
+ self.operator_mode = operator_mode
146
+
147
+ # Detect available memory on accelerator device
148
+ if torch.cuda.is_available():
149
+ _, total_mem = torch.cuda.mem_get_info()
150
+ else:
151
+ total_mem = psutil.virtual_memory().total
152
+ total_mem_gb = total_mem / BYTES_IN_GB
153
+
154
+ self.use_linear_checkpoint = (
155
+ total_mem_gb < 16
156
+ if use_linear_checkpoint is None
157
+ else use_linear_checkpoint
158
+ )
159
+
160
+ self.base_scale_attn = base_scale_attn
161
+ self.mag_weight = mag_weight
162
+ self.cross_gate = cross_gate
163
+ self.max_chunk_size = max_chunk_size
164
+ self.max_self_attn_length = max_self_attn_length
165
+ if isinstance(linear_precision, torch.dtype):
166
+ linear_precision = str(linear_precision).replace("torch.", "")
167
+ self.linear_precision = linear_precision
168
+
169
+ self.lora_config = lora_config
170
+ if lora_config is not None:
171
+ if hasattr(self.lora_config.get("peft_type"), "value"):
172
+ self.lora_config["peft_type"] = self.lora_config["peft_type"].value
173
+ self.lora_config = convert_sets_to_lists(self.lora_config)
174
+
175
+ self.padding_side = padding_side
176
+ self.bidirectional = bidirectional
177
+ if self.bidirectional:
178
+ print("Bidirectional mode is enabled; inputs must be non-causal and unpadded.")
179
+ self.pooling_config = pooling_config
180
+
181
+ super().__init__(**kwargs) # flush inconsistent pretrained parameters (?)
182
+ # Copy class attributes to instance for serialization (save dict)
183
+ self.model_type = self.__class__.model_type
184
+ self.auto_map = self.__class__.auto_map
185
+ self.architectures = self.__class__.architectures
186
+ # Padding side configuration if not set
187
+ if self.padding_side is None:
188
+ self.padding_side = "right"
189
+ logger.info("Warning: padding_side is None, defaulting to 'right'.")
190
+ # set recurrent configuration from operator mode
191
+ if operator_mode not in self.__class__.RECURRENT_MODES:
192
+ self.recurrent_config = parse_mode_name(operator_mode)
193
+ else:
194
+ self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
195
+ logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
196
+
197
+
198
+ TpttConfig.register_for_auto_class()
199
+
200
+
201
+ def parse_mode_name(name: str) -> dict:
202
+ """Parse mode to recurrent config"""
203
+ if name.startswith("delta_product"):
204
+ parts = name.split("_")
205
+ # Prefix is always two words: 'delta' and 'product'
206
+ base_len = 2
207
+ order = 2
208
+ gate_type = "k"
209
+ linear = True
210
+ trick = "derivative"
211
+
212
+ idx = base_len
213
+ # Check for order (immediately after the prefix)
214
+ if len(parts) > idx and parts[idx].isdigit():
215
+ order = int(parts[idx])
216
+ idx += 1
217
+
218
+ remaining = parts[idx:]
219
+ # Trick (r/c) is always at the far right if present
220
+ if remaining and remaining[-1] in ("r", "c"):
221
+ trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
222
+ remaining = remaining[:-1]
223
+ # 'gelu' comes just before the trick if present
224
+ if remaining and remaining[-1] == "gelu":
225
+ linear = False
226
+ remaining = remaining[:-1]
227
+ # If anything remains, it's the gate_type
228
+ if remaining:
229
+ gate_type = "_".join(remaining)
230
+ return {
231
+ "order": order,
232
+ "gate_type": gate_type,
233
+ "linear": linear,
234
+ "trick": trick,
235
+ }
236
+
237
+ # delta_rule[_gate][_gelu]
238
+ m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
239
+ if m:
240
+ return {
241
+ "order": 1,
242
+ "gate_type": m.group(1) if m.group(1) else "k",
243
+ "linear": not bool(m.group(2)),
244
+ "trick": "derivative",
245
+ }
246
+ raise ValueError(f"Unknown mode: {name}")
247
+
248
+
249
+ def get_mode_name(
250
+ order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
251
+ ) -> str:
252
+ """Get recurrent mode name from parameter"""
253
+ base = (
254
+ "delta_rule"
255
+ if order == 1
256
+ else ("delta_product" if order == 2 else f"delta_product_{order}")
257
+ )
258
+ parts = []
259
+ if gate_type != "k":
260
+ parts.append(gate_type)
261
+ if not linear:
262
+ parts.append("gelu")
263
+ if order >= 2 and trick != "derivative":
264
+ parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
265
+ return base + (("_" + "_".join(parts)) if parts else "")
266
+
267
+
268
+ def render_template(template_path: str, variables: dict) -> str:
269
+ """Load and render a Jinja2 template from any file path."""
270
+ env = Environment(loader=FileSystemLoader(os.path.dirname(template_path)))
271
+ template = env.get_template(os.path.basename(template_path))
272
+ return template.render(**variables)
273
+
274
+
275
+ def write_model_card(output_path: str, content: str):
276
+ """Write the generated content into README.md."""
277
+ os.makedirs(output_path, exist_ok=True)
278
+ readme_path = os.path.join(output_path, "README.md")
279
+ with open(readme_path, "w", encoding="utf-8") as f:
280
+ f.write(content)
281
+
282
+
283
+ def generate_model_card(
284
+ output_path: str,
285
+ config: Union[dict, object],
286
+ template: Optional[
287
+ str
288
+ ], # can be "model_card" OR an absolute/relative path to a .md file
289
+ extra_variables: Optional[Dict] = None,
290
+ ):
291
+ """
292
+ Generate a README.md file from a Jinja2 template and a configuration.
293
+
294
+ - template can be either:
295
+ * a full path to a template file
296
+ * a short name (e.g., "model_card") -> will be looked up inside default_templates_dir
297
+ """
298
+ if template is None:
299
+ template = "model_card_template" # default template name
300
+ # Locate the template
301
+ if os.path.exists(template): # direct file path provided
302
+ template_path = template
303
+ else:
304
+ default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
305
+ template_path = os.path.join(default_templates_dir, f"{template}.md")
306
+
307
+ if not os.path.exists(template_path):
308
+ raise FileNotFoundError(f"Template not found: {template_path}")
309
+
310
+ variables = {
311
+ "model_id": os.path.basename(output_path),
312
+ "config": config,
313
+ }
314
+ if extra_variables:
315
+ variables.update(extra_variables)
316
+
317
+ content = render_template(template_path, variables)
318
+ write_model_card(output_path, content)
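
The mode-name helpers in `configuration_tptt.py` above can be exercised on their own. The following is a minimal, self-contained sketch that copies the logic of `parse_mode_name` and `get_mode_name` from the diff (it does not import the TPTT package) and round-trips a few names. The example mode strings are illustrative assumptions, not modes the author reports having tested.

```python
import re


def parse_mode_name(name: str) -> dict:
    """Parse an operator-mode string into a recurrent config (logic copied from the diff)."""
    if name.startswith("delta_product"):
        parts = name.split("_")
        order, gate_type, linear, trick = 2, "k", True, "derivative"
        idx = 2  # skip the two prefix words: 'delta' and 'product'
        # Optional integer order immediately after the prefix
        if len(parts) > idx and parts[idx].isdigit():
            order = int(parts[idx])
            idx += 1
        remaining = parts[idx:]
        # Optional trick suffix (r/c) is always the last token
        if remaining and remaining[-1] in ("r", "c"):
            trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
            remaining = remaining[:-1]
        # Optional 'gelu' token comes just before the trick
        if remaining and remaining[-1] == "gelu":
            linear = False
            remaining = remaining[:-1]
        # Whatever is left is the gate type
        if remaining:
            gate_type = "_".join(remaining)
        return {"order": order, "gate_type": gate_type, "linear": linear, "trick": trick}

    # delta_rule[_gate][_gelu]
    m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
    if m:
        return {
            "order": 1,
            "gate_type": m.group(1) if m.group(1) else "k",
            "linear": not bool(m.group(2)),
            "trick": "derivative",
        }
    raise ValueError(f"Unknown mode: {name}")


def get_mode_name(
    order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
) -> str:
    """Rebuild the mode name from a recurrent config (logic copied from the diff)."""
    base = (
        "delta_rule"
        if order == 1
        else ("delta_product" if order == 2 else f"delta_product_{order}")
    )
    parts = []
    if gate_type != "k":
        parts.append(gate_type)
    if not linear:
        parts.append("gelu")
    if order >= 2 and trick != "derivative":
        parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
    return base + (("_" + "_".join(parts)) if parts else "")


if __name__ == "__main__":
    # Hypothetical mode strings chosen only to exercise each optional token
    for mode in ("delta_rule_v_gelu", "delta_product_r", "delta_product_3_kv_gelu_c"):
        cfg = parse_mode_name(mode)
        print(mode, "->", cfg, "->", get_mode_name(**cfg))
```

Because the default gate (`k`) and the default trick (`derivative`) are omitted when the name is rebuilt, a name such as `delta_rule_k` parses correctly but comes back as `delta_rule`; only names already in canonical form round-trip exactly.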
lora_delta_product_m0.5_constant/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.49.0"
4
+ }
lora_delta_product_m0.5_constant/modeling_tptt.py ADDED
@@ -0,0 +1,1504 @@
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
+ """
4
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
+ Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import os
12
+ from pathlib import Path
13
+ import re
14
+ import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
+ from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import (
28
+ AutoConfig,
29
+ AutoModel,
30
+ AutoModelForCausalLM,
31
+ DynamicCache,
32
+ PreTrainedModel,
33
+ )
34
+ from transformers.configuration_utils import PretrainedConfig
35
+
36
+ from .configuration_tptt import TpttConfig
37
+
38
+ logger = logging.getLogger(__name__) # monitoring
39
+
40
+
41
+ class LCache:
42
+ """Cache for storing intermediate states of linear attention layers."""
43
+
44
+ def __init__(self):
45
+ """Stores per-layer intermediate states: {layer_idx: state_dict}"""
46
+ self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = (
47
+ {}
48
+ ) # recurrent states and qkv buffers
49
+
50
+ def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
51
+ """Retrieve cached state for a given layer, or None if not present"""
52
+ return self.inputs_states.get(layer_idx, None)
53
+
54
+ def update(self, layer_idx: int, **kwargs):
55
+ """Detach all tensors to avoid retaining computation graphs"""
56
+ detached_kwargs = {
57
+ k: v.detach() if isinstance(v, torch.Tensor) else v
58
+ for k, v in kwargs.items()
59
+ }
60
+ # Update or create the state for the specified layer
61
+ if layer_idx in self.inputs_states:
62
+ self.inputs_states[layer_idx].update(detached_kwargs)
63
+ else:
64
+ self.inputs_states[layer_idx] = detached_kwargs
65
+
66
+ def reset(self):
67
+ """Clear all cached states"""
68
+ self.inputs_states.clear()
69
+
70
+
71
+ class CausalAvgPool1d(nn.Module):
72
+ """Causal sliding window average (uniform, no shape loss along sequence)"""
73
+
74
+ def __init__(
75
+ self, output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = "replicate"
76
+ ):
77
+ super().__init__()
78
+ self.offsets = offsets
79
+ self.mode = mode
80
+ self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ """x: [B, S, F] → [B, S, F → output_size]"""
84
+ x_ = x.transpose(1, 2) # [B, F, S]
85
+ idxs = torch.tensor(self.offsets, device=x.device)
86
+ ksize = idxs.max() - idxs.min() + 1
87
+ w = torch.zeros(ksize, device=x.device, dtype=x.dtype)
88
+ w[idxs - idxs.min()] = 1 / len(self.offsets) # Always uniform weights
89
+ kernel = w.repeat(x_.shape[1], 1).reshape(x_.shape[1], 1, ksize)
90
+ pad_left = -idxs.min().item()
91
+ pad_right = (ksize - 1) - pad_left
92
+ x_pad = F.pad(x_, (pad_left, pad_right), mode=self.mode)
93
+ y = F.conv1d(x_pad, kernel, groups=x_.shape[1]) # pylint: disable=not-callable
94
+ return self.pool(y.transpose(1, 2)) # [B, S, F → output_size]
95
+
96
+
97
+ class LinearAttention(nn.Module):
98
+ """
99
+ Linear multi-head attention layer: [B, S, D] -> [B, S, D]
100
+ Projections + gating + efficient linear attention mechanism (TPTT compatible).
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ hidden_dim: int,
106
+ num_heads: int,
107
+ head_dim: Optional[int] = None,
108
+ num_key_value_heads: Optional[int] = None,
109
+ num_key_value_groups: Optional[int] = None,
110
+ bias: bool = True,
111
+ dropout: Optional[float] = None,
112
+ linear_precision: torch.dtype = torch.float32,
113
+ padding_side: str = "right",
114
+ shared_attn: bool = False, # shared attention
115
+ layer_idx: int = 0,
116
+ operator_mode: str = "delta_rule",
117
+ use_linear_checkpoint: bool = False,
118
+ recurrent_config: Optional[Dict[str, Any]] = None,
119
+ linear_cache: Optional[LCache] = None,
120
+ max_chunk_size: int = 64,
121
+ bidirectional: bool = False, # not used if causal
122
+ pooling_config: Optional[Dict[str, Any]] = None,
123
+ ):
124
+ super().__init__()
125
+ if pooling_config is None:
126
+ pooling_config = {
127
+ "offsets": (0, 1, 2),
128
+ "mode": "replicate",
129
+ }
130
+ self.hidden_dim = hidden_dim
131
+ self.num_heads = num_heads
132
+ self.head_dim = head_dim or hidden_dim // num_heads
133
+ self.num_key_value_heads = num_key_value_heads or num_heads
134
+ self.num_key_value_groups = num_key_value_groups or (
135
+ num_heads // (num_key_value_heads or num_heads)
136
+ )
137
+ self.scaling = self.head_dim**-0.5
138
+ self.linear_precision = linear_precision
139
+ self.padding_side = padding_side
140
+
141
+ self.shared_attn = shared_attn
142
+
143
+ if not shared_attn:
144
+ self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
145
+ self.k_proj = nn.Linear(
146
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
147
+ )
148
+ self.v_proj = nn.Linear(
149
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
150
+ )
151
+ self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)
152
+
153
+ self.dropout = nn.Dropout(dropout) if dropout is not None else None
154
+
155
+ self.linear_operator = LinearAttentionOp(
156
+ layer_idx=layer_idx,
157
+ operator_mode=operator_mode,
158
+ use_linear_checkpoint=use_linear_checkpoint,
159
+ recurrent_config=recurrent_config,
160
+ max_chunk_size=max_chunk_size,
161
+ linear_cache=linear_cache,
162
+ linear_precision=linear_precision,
163
+ )
164
+ self.bidirectional = bidirectional
165
+ # Causal average pooling for gating
166
+ self.pooling_config = pooling_config
167
+ self.pool_g = CausalAvgPool1d(
168
+ output_size=self.head_dim * self.num_key_value_heads, **pooling_config
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ x: Union[List[torch.Tensor], torch.Tensor],
174
+ attn_mask: Optional[torch.Tensor] = None,
175
+ out_proj: Optional[nn.Module] = None,
176
+ **kwargs: Any,
177
+ ) -> torch.Tensor:
178
+ """
179
+ Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
180
+ """
181
+
182
+ if not self.shared_attn:
183
+ hidden_states = x[0] if isinstance(x, (list, tuple)) else x
184
+ # Projections
185
+ q = self.q_proj(hidden_states)
186
+ k = self.k_proj(hidden_states)
187
+ v = self.v_proj(hidden_states)
188
+ out_proj = self.out_proj
189
+ else:
190
+ # Shared attention <=> no projections here
191
+ q, k, v = x[0], x[1], x[2]
192
+ out_proj = self.out_proj if out_proj is None else out_proj
193
+
194
+ # get dtype and device
195
+ final_dtype, final_device = q.dtype, q.device
196
+ # Masking if needed
197
+ if attn_mask is not None:
198
+ v = apply_linear_attention_mask(attn_mask, v, self.padding_side)
199
+
200
+ # Forget and write gating for linear attention (loose terminology)
201
+ f_g, w_g = self.pool_g(k), self.pool_g(v)
202
+
203
+ # Reshape for multi-head
204
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
205
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
206
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
207
+
208
+ f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
209
+ w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
210
+
211
+ # Repeat for GQA
212
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
213
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
214
+
215
+ f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
216
+ w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)
217
+
218
+ ## DeltaNet-style: SiLU activation and normalization
219
+ q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
220
+ k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)
221
+
222
+ ## linear stability part
223
+ v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)
224
+
225
+ # Apply sigmoid to forget and write gates
226
+ f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
227
+ w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)
228
+
229
+ # Convert to linear_precision (float32 by default) for numerical stability
230
+ q, k, v, f_g, w_g = (
231
+ x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
232
+ )
233
+ g = (f_g, w_g)
234
+
235
+ # Linear Attention Core, output: [B, H, S, d]
236
+ if self.bidirectional: # Works only with non-causal attention
237
+ # Forward direction
238
+ out_forward = self.linear_operator(q, k, v, g, **kwargs)
239
+ # Backward direction: flip the input sequence on the time dimension (dim=2)
240
+ kwargs_bwd = kwargs.copy()
241
+ kwargs_bwd["use_cache"] = False
242
+ out_backward = self.linear_operator(
243
+ torch.flip(q, dims=[2]),
244
+ torch.flip(k, dims=[2]),
245
+ torch.flip(v, dims=[2]),
246
+ tuple(torch.flip(t, dims=[2]) for t in g),
247
+ **kwargs_bwd,
248
+ )
249
+ # Flip the output back to restore proper order
250
+ out_backward = torch.flip(out_backward, dims=[2])
251
+ # Fusion: here, simple addition
252
+ out = out_forward + out_backward
253
+ else:
254
+ out = self.linear_operator(q, k, v, g, **kwargs)
255
+
256
+ # Merge heads and project: [B, H, S, d] -> [B, S, H*d] -> Out proj
257
+ out = rearrange(out, "b h s d -> b s (h d)")
258
+ # Normalize output (RMS norm). Note: bidirectional compatibility
259
+ out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
260
+ # Ensure dtype and device consistency
261
+ out = out.to(dtype=final_dtype, device=final_device)
262
+ # Apply output projection
263
+ out = out_proj(out) # [B, S, D]
264
+ out = ensure_stability(out, min_val=-1e4, max_val=1e4)
265
+ # Apply dropout if specified
266
+ if self.dropout is not None:
267
+ out = self.dropout(out)
268
+ return out
269
+
270
+
271
+ class LiZAttention(nn.Module):
272
+ """LiZA Linear Attention module, mixing linear and vanilla attention."""
273
+
274
+ def __init__(
275
+ self,
276
+ base_attn: nn.Module,
277
+ layer_idx: int,
278
+ base_config: PretrainedConfig, # Backbone Config
279
+ linear_cache: Optional[LCache] = None,
280
+ operator_mode: str = "delta_rule",
281
+ use_linear_checkpoint: bool = False,
282
+ recurrent_config: Optional[Dict[str, Any]] = None,
283
+ max_self_attn_length: Optional[int] = None, # unnecessary
284
+ base_scale_attn: bool = False,
285
+ mag_weight: float = 0.5,
286
+ cross_gate: bool = False,
287
+ max_chunk_size: int = 64,
288
+ linear_precision: Union[str, torch.dtype] = "float32",
289
+ padding_side: str = "right", # for tokenizer
290
+ disable_linear_attn: bool = False,
291
+ bidirectional: bool = False, # if True, use bidirectional attention
292
+ pooling_config: Optional[Dict[str, Any]] = None,
293
+ ):
294
+ super().__init__()
295
+ if isinstance(linear_precision, str):
296
+ linear_precision = getattr(torch, linear_precision)
297
+ self.linear_precision = linear_precision
298
+ self.base_attn: nn.Module = base_attn
299
+ self.base_config = base_config
300
+ self.layer_idx = layer_idx
301
+ self.max_self_attn_length = max_self_attn_length
302
+ self.base_scale_attn = base_scale_attn
303
+ self.mag_weight = mag_weight
304
+ self.cross_gate = cross_gate
305
+ self.max_chunk_size = max_chunk_size
306
+ self.linear_precision = linear_precision
307
+ self.padding_side = padding_side
308
+ self.disable_linear_attn = disable_linear_attn
309
+
310
+ (
311
+ self.num_heads,
312
+ self.head_dim,
313
+ self.num_key_value_heads,
314
+ self.num_key_value_groups,
315
+ self.hidden_dim,
316
+ ) = self._get_attention_parameters(base_attn, base_config)
317
+ self.scaling = self.head_dim**-0.5
318
+
319
+ self.linear_attn = LinearAttention(
320
+ layer_idx=layer_idx,
321
+ shared_attn=True,
322
+ operator_mode=operator_mode,
323
+ use_linear_checkpoint=use_linear_checkpoint,
324
+ recurrent_config=recurrent_config,
325
+ hidden_dim=self.hidden_dim,
326
+ num_heads=self.num_heads,
327
+ head_dim=self.head_dim,
328
+ num_key_value_heads=self.num_key_value_heads,
329
+ num_key_value_groups=self.num_key_value_groups,
330
+ linear_precision=linear_precision,
331
+ linear_cache=linear_cache,
332
+ max_chunk_size=max_chunk_size,
333
+ padding_side=padding_side,
334
+ bidirectional=bidirectional,
335
+ pooling_config=pooling_config,
336
+ )
337
+
338
+ def _get_attention_parameters(
339
+ self, base_attn: nn.Module, base_config: PretrainedConfig
340
+ ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]]:
341
+ """Retrieve the attention parameters from the base attention module."""
342
+ # Check the base attention module first, then fall back to the config
343
+ num_heads = (
344
+ getattr(base_attn, "num_heads", None)
345
+ or getattr(base_attn, "num_q_heads", None)
346
+ or getattr(base_config, "num_heads", None)
347
+ or getattr(base_config, "num_attention_heads", None)
348
+ )
349
+ head_dim = (
350
+ getattr(base_attn, "head_dim", None)
351
+ or getattr(base_attn, "attention_head_size", None)
352
+ or getattr(base_config, "head_dim", None)
353
+ or (
354
+ getattr(base_config, "hidden_size", None) // num_heads
355
+ if num_heads and getattr(base_config, "hidden_size", None)
356
+ else None
357
+ )
358
+ )
359
+ num_key_value_heads = (
360
+ getattr(base_attn, "num_kv_heads", None)
361
+ or getattr(base_attn, "num_k_heads", None)
362
+ or getattr(base_config, "num_key_value_heads", None)
363
+ or num_heads # fallback
364
+ )
365
+ num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
366
+ num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
367
+ )
368
+ hidden_dim = getattr(base_config, "hidden_size", None) or head_dim * num_heads
369
+ return (
370
+ num_heads,
371
+ head_dim,
372
+ num_key_value_heads,
373
+ num_key_value_groups,
374
+ hidden_dim,
375
+ )
376
+
377
+ def _apply_shared_projections(
378
+ self, hidden_states: torch.Tensor
379
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
380
+ base_attn = self.base_attn
381
+ if hasattr(base_attn, "q_proj"):
382
+ # Llama, OLMo and Mistral style
383
+ q = base_attn.q_proj(hidden_states)
384
+ k = base_attn.k_proj(hidden_states)
385
+ v = base_attn.v_proj(hidden_states)
386
+ out_proj = base_attn.o_proj
387
+ elif hasattr(base_attn, "qkv_proj"):
388
+ # OpenELM and GPT-Neo style : QKV fused, split on the last dimension
389
+ qkv = base_attn.qkv_proj(hidden_states)
390
+ q, k, v = split_qkv(base_attn, qkv)
391
+ out_proj = base_attn.out_proj
392
+ elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
393
+ # GPT-2 style
394
+ qkv = base_attn.c_attn(hidden_states)
395
+ q, k, v = qkv.chunk(3, dim=-1)
396
+ out_proj = base_attn.c_proj
397
+ elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
398
+ # BERT - ViT
399
+ q = base_attn.query(hidden_states)
400
+ k = base_attn.key(hidden_states)
401
+ v = base_attn.value(hidden_states)
402
+ out_proj = getattr(base_attn, "dense", None) # or output.dense
403
+ else:
404
+ raise ValueError("Unsupported attention module: cannot find projections.")
405
+ # Ensure stability
406
+ q = ensure_stability(q, min_val=-1e4, max_val=1e4)
407
+ k = ensure_stability(k, min_val=-1e4, max_val=1e4)
408
+ v = ensure_stability(v, min_val=-1e4, max_val=1e4)
409
+ return q, k, v, out_proj
410
+
411
+ def _process_self_attn(
412
+ self,
413
+ hidden_states: torch.Tensor,
414
+ attention_mask: Optional[torch.Tensor],
415
+ kwargs,
416
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
417
+ """Process the self-attention part (with truncation)."""
418
+ if self.max_self_attn_length: # Not needed with SWA (non-parametric context memory)
419
+ hidden_states, attention_mask = truncate_attention_mask(
420
+ hidden_states, attention_mask, self.max_self_attn_length
421
+ )
422
+
423
+ if kwargs.get("position_embeddings", None) is not None:
424
+ cos, sin = kwargs["position_embeddings"]
425
+ cos = cos[:, -self.max_self_attn_length :]
426
+ sin = sin[:, -self.max_self_attn_length :]
427
+ kwargs["position_embeddings"] = (cos, sin)
428
+
429
+ if isinstance(kwargs.get("past_key_value", None), DynamicCache):
430
+ # cache management
431
+ if (
432
+ len(kwargs["past_key_value"]) > self.layer_idx
433
+ and self.layer_idx == 0
434
+ ):
435
+ kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
436
+
437
+ # Ensure attention mask is of the correct dtype and device
438
+ if attention_mask is not None:
439
+ attention_mask = attention_mask.to(
440
+ dtype=hidden_states.dtype, device=hidden_states.device
441
+ )
442
+ # Standard attention (mask and rotation is applied inside)
443
+ base_attn_outputs = self.base_attn(
444
+ hidden_states,
445
+ attention_mask=attention_mask,
446
+ **kwargs,
447
+ )
448
+
449
+ if isinstance(base_attn_outputs, tuple):
450
+ if len(base_attn_outputs) == 3:
451
+ o_base, attn_weights, present_key_value = base_attn_outputs
452
+ expected_attn_mode = 3
453
+ elif len(base_attn_outputs) == 2:
454
+ o_base, attn_weights = base_attn_outputs
455
+ present_key_value, expected_attn_mode = None, 2
456
+ else:
457
+ raise ValueError(
458
+ f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
459
+ )
460
+ else:
461
+ o_base = base_attn_outputs
462
+ attn_weights, present_key_value, expected_attn_mode = None, None, 1
463
+ # Ensure stability
464
+ o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
465
+ return o_base, attn_weights, present_key_value, expected_attn_mode
466
+
467
+ def _prepare_attn_mixin(
468
+ self,
469
+ o_lin: torch.Tensor,
470
+ o_base: torch.Tensor,
471
+ tensor_dtype: torch.dtype,
472
+ eps: float = 1e-5,
473
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
474
+ """Prepare linear attn for mixing with self attn."""
475
+ # Force cast typing, shape : [b n (h d)]
476
+ o_lin = o_lin.to(tensor_dtype)
477
+ o_base = o_base.to(tensor_dtype)
478
+ # feature scaling
479
+ if self.base_scale_attn:
480
+ scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
481
+ o_lin = scaler * o_lin
482
+ return o_lin, o_base
483
+
484
+ def _apply_mag(
485
+ self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
486
+ ) -> torch.Tensor:
487
+ """Apply the MAG strategy"""
488
+ # Left-Padding management
489
+ if linear_attention.shape[1] != softmax_attention.shape[1]:
490
+ left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
491
+ linear_attention, softmax_attention = (
492
+ linear_attention[:, -left_trunc:],
493
+ softmax_attention[:, -left_trunc:],
494
+ )
495
+ # NAM : Neural Attention Mixer (with graph forcing)
496
+ mag_weight = torch.tensor(
497
+ self.mag_weight,
498
+ dtype=softmax_attention.dtype,
499
+ device=softmax_attention.device,
500
+ )
501
+ softmax_weighted = (1 - mag_weight) * softmax_attention
502
+ linear_weighted = mag_weight * linear_attention
503
+ if self.cross_gate:
504
+ output_attention = (
505
+ softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
506
+ ) # cross-product term (nonlinear interaction)
507
+ else:
508
+ output_attention = softmax_weighted + linear_weighted # classic
509
+
510
+ if torch.allclose(softmax_weighted, output_attention):
511
+ logger.info(
512
+ "[LOG] layer : %s, softmax_weighted and output_attention are close.",
513
+ self.layer_idx,
514
+ )
515
+ # Final output
516
+ return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)
517
+
518
+ def forward(
519
+ self,
520
+ hidden_states: torch.Tensor,
521
+ attention_mask: Optional[torch.Tensor] = None,
522
+ **kwargs,
523
+ ) -> torch.Tensor:
524
+ """Mix linear and self attention forward"""
525
+ device = hidden_states.device
526
+ tensor_dtype = hidden_states.dtype
527
+ self.base_attn.to(device)
528
+
529
+ if self.training:
530
+ kwargs.pop("past_key_value", None)
531
+ kwargs["use_cache"] = False
532
+ elif "use_cache" not in kwargs:
533
+ kwargs.pop("past_key_value", None)
534
+ kwargs["use_cache"] = False
535
+
536
+ kwargs.pop("position_ids", None) # obsolete
537
+
538
+ # Apply shared projections
539
+ q, k, v, out_proj = self._apply_shared_projections(hidden_states)
540
+
541
+ # Apply linear attention to hidden states
542
+ o_lin = self.linear_attn(
543
+ x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
544
+ )
545
+
546
+ # Process self attn with truncation
547
+ o_base, attn_weights, present_key_value, expected_attn_mode = (
548
+ self._process_self_attn(hidden_states, attention_mask, kwargs)
549
+ )
550
+
551
+ # Prepare output mixing
552
+ o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)
553
+
554
+ # Apply Memory as Gate in self-attention (with length management and ablation)
555
+ out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)
556
+
557
+ # Return output following transformer convention
558
+ if expected_attn_mode == 3:
559
+ return out, attn_weights, present_key_value
560
+ if expected_attn_mode == 2:
561
+ return out, attn_weights
562
+ return out
563
+
564
+
565
+ def load_tptt_safetensors(
566
+ repo_or_path: str,
567
+ model: Union[PreTrainedModel, PeftModel],
568
+ subfolder: Optional[str] = None,
569
+ token: Optional[str] = None,
570
+ ) -> Union[PreTrainedModel, PeftModel]:
571
+ """Load Tptt safetensor from LoRA/PEFT weights and adapt keys if needed."""
572
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
573
+ fname = "adapter_model.safetensors"
574
+ # subfolder management
575
+ if subfolder:
576
+ repo_or_path_norm = os.path.normpath(repo_or_path)
577
+ subfolder_norm = os.path.normpath(subfolder)
578
+ if not repo_or_path_norm.endswith(subfolder_norm):
579
+ fname = f"{subfolder}/{fname}"
580
+ # Find file path
581
+ if os.path.isdir(repo_or_path):
582
+ path = os.path.join(repo_or_path, fname)
583
+ if not os.path.exists(path):
584
+ return model
585
+ else:
586
+ if fname not in list_repo_files(repo_or_path, token=token):
587
+ return model
588
+ path = hf_hub_download(repo_or_path, fname, token=token)
589
+
590
+ # Load weights from safetensors
591
+ with safe_open(path, framework="pt") as f:
592
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
593
+
594
+ # Adapt LoRA/Specific keys if needed (add .default if expected by the model)
595
+ def adapt_keys(sd, model):
596
+ model_keys = list(model.state_dict().keys())
597
+ if any(k.startswith("tptt_model.base_model.") for k in model_keys):
598
+ prefix = "tptt_model.base_model."
599
+ elif any(k.startswith("base_model.") for k in model_keys):
600
+ prefix = "base_model."
601
+ else:
602
+ prefix = ""
603
+
604
+ has_base_attn = any(".base_attn." in k for k in model_keys)
605
+
606
+ def adapt_key(k):
607
+ k_ = k if k.startswith(prefix) else prefix + k
608
+ # first, verify and modify base_attn (LiZA)
609
+ if ".base_attn." in k_ and not has_base_attn:
610
+ k_ = k_.replace(".base_attn.", ".")
611
+ # change LoRA if needed
612
+ if (
613
+ k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
614
+ ) and k_.replace(".weight", ".default.weight") in model_keys:
615
+ k_ = k_.replace(".weight", ".default.weight")
616
+ return k_
617
+
618
+ return {adapt_key(k): v for k, v in sd.items()}
619
+
620
+ state_dict = adapt_keys(state_dict, model)
621
+
622
+ # Cast tensors to the expected dtype of the model parameters
623
+ model_state_dict = model.state_dict()
624
+ for k, v in state_dict.items():
625
+ if k in model_state_dict:
626
+ expected_dtype = model_state_dict[k].dtype
627
+ if v.dtype != expected_dtype:
628
+ state_dict[k] = v.to(expected_dtype)
629
+
630
+ logger.info("Input LoRA/Specific keys: %s", list(state_dict.keys()))
631
+
632
+ # Load into model
633
+ missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
634
+ missing_lora = [k for k in missing if "lora" in k]
635
+ if missing_lora:
636
+ logger.warning("Missing keys: %s", missing_lora)
637
+ if unexpected:
638
+ logger.warning("Unexpected keys: %s", unexpected)
639
+ return model
640
+
641
+
642
+ def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
643
+ model: nn.Module,
644
+ base_config: PretrainedConfig, # or LlamaConfig, MistralConfig, etc.
645
+ linear_cache: Optional[LCache] = None,
646
+ liza_attention: nn.Module = LiZAttention,
647
+ target_modules_names: Optional[list[str]] = None,
648
+ operator_mode: str = "delta_rule",
649
+ use_linear_checkpoint: bool = False,
650
+ recurrent_config: Optional[Dict[str, Any]] = None,
651
+ base_scale_attn: bool = False,
652
+ mag_weight: float = 0.5,
653
+ cross_gate: bool = False,
654
+ max_chunk_size: int = 64,
655
+ linear_precision: torch.dtype = torch.float32,
656
+ max_self_attn_length: Optional[int] = None, # unnecessary
657
+ padding_side: str = "right", # for tokenizer
658
+ bidirectional: bool = False, # if True, use bidirectional attention
659
+ pooling_config: Optional[Dict[str, Any]] = None,
660
+ **kwargs, # quickfix unexpected arguments
661
+ ) -> Tuple[PreTrainedModel, LCache]:
662
+ """Replace target modules in a model with LiZAttention."""
663
+ if target_modules_names is None:
664
+ target_modules_names = ["attn", "self_attn", "attention"]
665
+ # Find target modules by suffix (e.g., "attn", "attention")
666
+ target_modules_names = [
667
+ name
668
+ for name, _ in model.named_modules()
669
+ if any(name.endswith(suffix) for suffix in target_modules_names)
670
+ and not any(f".{suffix}." in name for suffix in target_modules_names)
671
+ ]
672
+ if not target_modules_names:
673
+ raise ValueError(
674
+ "No target attention modules found in the model (no module name matched the given suffixes)."
675
+ )
676
+ # Prepare linear cache
677
+ linear_cache = linear_cache or LCache()
678
+ # Inject LiZAttention into the model
679
+ for name, _ in model.named_modules():
680
+ if name in target_modules_names:
681
+ parent = model
682
+ *path, last = name.split(".")
683
+ for p in path:
684
+ parent = getattr(parent, p)
685
+ layer_idx = extract_layer_idx(name)
686
+ setattr(
687
+ parent,
688
+ last,
689
+ liza_attention(
690
+ getattr(parent, last),
691
+ layer_idx=layer_idx,
692
+ base_config=base_config,
693
+ linear_cache=linear_cache,
694
+ operator_mode=operator_mode,
695
+ use_linear_checkpoint=use_linear_checkpoint,
696
+ recurrent_config=recurrent_config,
697
+ max_self_attn_length=max_self_attn_length,
698
+ base_scale_attn=base_scale_attn,
699
+ mag_weight=mag_weight,
700
+ cross_gate=cross_gate,
701
+ max_chunk_size=max_chunk_size,
702
+ linear_precision=linear_precision,
703
+ padding_side=padding_side,
704
+ bidirectional=bidirectional,
705
+ pooling_config=pooling_config,
706
+ ),
707
+ )
708
+ return model, linear_cache
709
+
710
+
711
+ def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
712
+ """Save trainable LoRA/specific weights, adapting key names."""
713
+ # 1. Get the full state_dict
714
+ all_sd = model.state_dict()
715
+
716
+ # 2. Identify trainable parameter names (usually only LoRA/PEFT adapters)
717
+ trainable_keys = [
718
+ name for name, param in model.named_parameters() if param.requires_grad
719
+ ] # Also, you can manually select specific keys in model after load
720
+
721
+ # 3. Filter and adapt the keys (Remove custom model encapsulation info)
722
+ to_save = {
723
+ k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
724
+ for k in trainable_keys
725
+ }
726
+
727
+ # 4. Save the filtered adapters to a safetensors file
728
+ if to_save:
729
+ os.makedirs(path, exist_ok=True)
730
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
731
+ save_file(to_save, os.path.join(path, name))
732
+
733
+
734
+ class TpttModel(PreTrainedModel):
735
+ """
736
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
737
+ Handles only architecture and weights.
738
+ """
739
+
740
+ config_class = TpttConfig
741
+
742
+ def __init__(
743
+ self,
744
+ config: TpttConfig,
745
+ **kwargs,
746
+ ):
747
+ """
748
+ Initialize TpttModel with a given config and backbone.
749
+ Injects LiZA attention modules into the backbone.
750
+ """
751
+ super().__init__(config, **kwargs)
752
+ repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
753
+
754
+ # 1. Load backbone (with subfolder management) :
755
+ kwargs_bb = kwargs.copy()
756
+ if config.base_model_subfolder is not None:
757
+ kwargs_bb["subfolder"] = config.base_model_subfolder
758
+ else:
759
+ kwargs_bb.pop("subfolder", None)
760
+
761
+ if config.model_task == "causal_lm":
762
+ tptt_model = AutoModelForCausalLM.from_pretrained(
763
+ config.base_model_name, **kwargs_bb
764
+ )
765
+ else:
766
+ tptt_model = AutoModel.from_pretrained(config.base_model_name, **kwargs_bb)
767
+
768
+ # 2. Inject LiZA attention
769
+ self.linear_cache = LCache()
770
+ tptt_model, self.linear_cache = get_tptt_model(
771
+ tptt_model, config, self.linear_cache, **config.to_dict()
772
+ )
773
+
774
+ # 3. Apply LoRA/Specific if present and configured
775
+ if config.lora_config is not None:
776
+ lora_config_obj = LoraConfig(**config.lora_config)
777
+ tptt_model = get_peft_model(tptt_model, lora_config_obj)
778
+ else:
779
+ # Does not work if quantization is applied!
780
+ tptt_model = set_trainable_parameters(tptt_model)
781
+
782
+ # 4. Load safetensors if a TPTT/PEFT adapter is in the repo
783
+ if repo_or_path:
784
+ tptt_model = load_tptt_safetensors(
785
+ repo_or_path,
786
+ tptt_model,
787
+ subfolder=kwargs.get("subfolder", None),
788
+ token=kwargs.get("token", None),
789
+ )
790
+ self.tptt_model = tptt_model
791
+
792
+ def forward(
793
+ self,
794
+ input_ids: Optional[torch.LongTensor] = None,
795
+ attention_mask: Optional[torch.Tensor] = None,
796
+ labels: Optional[torch.LongTensor] = None,
797
+ **kwargs,
798
+ ):
799
+ """Forward pass. All arguments are passed to the underlying base model."""
800
+ if self.training:
801
+ kwargs["use_cache"] = False
802
+ kwargs.pop("num_items_in_batch", None)
803
+ elif "use_cache" not in kwargs: # evaluation
804
+ kwargs.pop("num_items_in_batch", None)
805
+ kwargs["use_cache"] = False
806
+ return self.tptt_model(
807
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
808
+ )
809
+
810
+ def generate(self, *args, **kwargs):
811
+ """Delegate the generate call to the backbone model, which supports generation"""
812
+ return self.tptt_model.generate(*args, **kwargs)
813
+
814
+ def save_pretrained(self, path: str, **kwargs):
815
+ """Save model weights, config, and source code to the given path."""
816
+ # 0. Save complete tptt config (with or without LoRA)
817
+ super().save_pretrained(path, **kwargs) # pylint: disable=no-member
818
+ self._adjust_save_strategy(path, **kwargs)
819
+ # 1. Save true weights and adapt keys
820
+ save_tptt_safetensors(self, path)
821
+ # 2. Copy Python files for trust_remote_code
822
+ self._copy_source_files(path, **kwargs)
823
+
824
+ def _adjust_save_strategy(self, path: str, **kwargs):
825
+ """Re-adapt/remove the weight safetensor and saved adapter config"""
826
+ if isinstance(self.tptt_model, PeftModel):
827
+ self.tptt_model.save_pretrained(path, **kwargs)
828
+ safetensor_path = os.path.join(path, "model.safetensors")
829
+ if os.path.exists(safetensor_path):
830
+ os.remove(safetensor_path)
831
+ adapter_path = os.path.join(path, "adapter_config.json")
832
+ if os.path.exists(adapter_path):
833
+ os.remove(adapter_path)
834
+
835
+ def _copy_source_files(self, target_path: str, **kwargs):
836
+ """Copy all .py files from package directory for trust_remote_code."""
837
+ src_dir = os.path.dirname(os.path.abspath(__file__))
838
+ dst_dir = (
839
+ f"./{str(Path(target_path).parts[0])}"
840
+ if kwargs.get("subfolder", False)
841
+ else target_path
842
+ )
843
+ for fname in os.listdir(src_dir):
844
+ if fname.endswith(".py"):
845
+ src = os.path.join(src_dir, fname)
846
+ dst = os.path.join(dst_dir, fname)
847
+ shutil.copy2(src, dst)
848
+
849
+ def retie_lm_after_load(self, **kwargs):
850
+ """Re-link lm_head after loading external weights."""
851
+ embed_lm = find_embedding_lm(self.tptt_model)
852
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
853
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
854
+ self.tptt_model.lm_head = nn.Linear(
855
+ embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
856
+ )
857
+ if kwargs.get("tie_word_embeddings", True):
858
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
859
+ logger.info("Weights of lm_head have been shared with embedding.")
860
+ else:
861
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
862
+ logger.info("Weights of lm_head have been cloned from the embedding.")
863
+
864
+ @classmethod
865
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
866
+ """Custom from_pretrained that accepts the standard positional argument"""
867
+ config = kwargs.pop("config", None)
868
+ repo_or_path = (
869
+ pretrained_model_name_or_path
870
+ or kwargs.pop("pretrained_model_name_or_path", None)
871
+ or kwargs.pop("repo_or_path", None)
872
+ or (getattr(config, "_base_path", None) if config else None)
873
+ or (getattr(config, "_name_or_path", None) if config else None)
874
+ )
875
+
876
+ if config is None and repo_or_path is not None:
877
+ config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
878
+ model = cls(config, *model_args, **kwargs)
879
+ model.retie_lm_after_load(**kwargs)
880
+ return model
881
+
882
+
883
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
884
+
885
+
886
+ class LinearAttentionOp(nn.Module):
887
+ """Base class for linear attention operators."""
888
+
889
+ def __init__(
890
+ self,
891
+ layer_idx: int,
892
+ operator_mode: str = "delta_rule",
893
+ use_linear_checkpoint: bool = False,
894
+ recurrent_config: Optional[dict] = None,
895
+ max_chunk_size: int = 64,
896
+ linear_cache: Optional[LCache] = None,
897
+ linear_precision: torch.dtype = torch.float32,
898
+ ):
899
+ super().__init__()
900
+ self.layer_idx = layer_idx
901
+ if recurrent_config is None:
902
+ operator_mode = "delta_rule" # force default operator mode if no config
903
+ recurrent_config = {
904
+ "order": 1,
905
+ "gate_type": "k",
906
+ "linear": True,
907
+ "trick": "derivative",
908
+ }
909
+ self.operator_mode = operator_mode
910
+ self.use_linear_checkpoint = use_linear_checkpoint
911
+
912
+ self.order = recurrent_config["order"]
913
+ self.gate_type = recurrent_config["gate_type"]
914
+ self.linear = recurrent_config["linear"]
915
+ self.trick = recurrent_config["trick"]
916
+
917
+ self.max_chunk_size = max_chunk_size
918
+ self.linear_cache = linear_cache or LCache()
919
+ self.linear_precision = linear_precision
920
+
921
+ def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
922
+ """
923
+ Compute the gating tensor according to the gate_type.
924
+ """
925
+ if self.gate_type == "k":
926
+ return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
927
+ if self.gate_type == "v":
928
+ return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
929
+ if self.gate_type == "kv":
930
+ return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
931
+ raise ValueError(f"Unsupported gate_type: {self.gate_type}")
932
+
933
+ def get_cache(self, use_cache: bool) -> Tuple[
934
+ Optional[torch.Tensor],
935
+ Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
936
+ ]:
937
+ """
938
+ Retrieve recurrent state and qkv buffers from the cache.
939
+ """
940
+ if not use_cache:
941
+ return None, None
942
+ last_state = self.linear_cache[self.layer_idx]
943
+ if last_state is not None:
944
+ recurrent_state = last_state.get("recurrent_state", None)
945
+ qkv_buffers = last_state.get("qkv", None)
946
+ else:
947
+ recurrent_state = None
948
+ qkv_buffers = None
949
+ return recurrent_state, qkv_buffers
950
+
951
+ def save_cache(
952
+ self,
953
+ use_cache: bool,
954
+ q: torch.Tensor,
955
+ k: torch.Tensor,
956
+ v: torch.Tensor,
957
+ gate: torch.Tensor,
958
+ state: torch.Tensor,
959
+ ) -> None:
960
+ """
961
+ Save the recurrent state and qkv buffers to the cache.
962
+ """
963
+ if not use_cache:
964
+ return
965
+ if self.order > 1:
966
+ qkv_buffers = (
967
+ q[:, :, -(self.order - 1) :, :],
968
+ k[:, :, -(self.order - 1) :, :],
969
+ v[:, :, -(self.order - 1) :, :],
970
+ gate[:, :, -(self.order - 1) :, :],
971
+ )
972
+ else:
973
+ qkv_buffers = None
974
+ self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)
975
+
976
+ def forward(
977
+ self,
978
+ q: torch.Tensor,
979
+ k: torch.Tensor,
980
+ v: torch.Tensor,
981
+ beta: Union[Tuple[torch.Tensor], torch.Tensor],
982
+ **kwargs,
983
+ ) -> torch.Tensor:
984
+ """
985
+ Forward pass for the attention operator.
986
+ """
987
+ # Ensure linear_precision for numerical stability (float32)
988
+ q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
989
+ if isinstance(beta, (tuple, list)):
990
+ beta = tuple(b.to(self.linear_precision) for b in beta)
991
+ else:
992
+ beta = beta.to(self.linear_precision)
993
+
994
+ gate = self.compute_gate(beta)
995
+
996
+ # Retrieve cache if needed
997
+ use_cache = kwargs.get("use_cache", False)
998
+ use_checkpoint = not (use_cache) and self.use_linear_checkpoint
999
+ recurrent_state, qkvb = self.get_cache(use_cache)
1000
+
1001
+ if qkvb is not None and qkvb[0].shape == q.shape:
1002
+ q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
1003
+ k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
1004
+ v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
1005
+ gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
1006
+ self.linear_precision
1007
+ )
1008
+
1009
+ output, state = self.chunk_delta_product_forward(
1010
+ q,
1011
+ k,
1012
+ v,
1013
+ gate,
1014
+ self.max_chunk_size,
1015
+ n=self.order,
1016
+ trick=self.trick,
1017
+ linear=self.linear,
1018
+ initial_state=recurrent_state,
1019
+ use_checkpoint=use_checkpoint,
1020
+ linear_precision=self.linear_precision,
1021
+ )
1022
+
1023
+ # Save cache if needed
1024
+ self.save_cache(use_cache, q, k, v, gate, state)
1025
+
1026
+ return output
1027
+
1028
+ @staticmethod
1029
+ def chunk_delta_product_forward(
1030
+ query: torch.Tensor,
1031
+ key: torch.Tensor,
1032
+ value: torch.Tensor,
1033
+ beta_gate: torch.Tensor,
1034
+ chunk_size: int,
1035
+ n: int = 1,
1036
+ trick: str = "derivative",
1037
+ linear: bool = True,
1038
+ initial_state: Optional[torch.Tensor] = None,
1039
+ use_checkpoint: bool = True,
1040
+ linear_precision: torch.dtype = torch.float32,
1041
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1042
+ """
1043
+ Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
1044
+ For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
1045
+ """
1046
+
1047
+ # --- Main chunk_delta_product_forward logic ---
1048
+
1049
+ batch_size, num_heads, seq_len, head_dim = query.shape
1050
+ chunk_size = get_valid_chunk_size(seq_len, chunk_size)
1051
+ num_chunks = seq_len // chunk_size
1052
+
1053
+ query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
1054
+ key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
1055
+ value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
1056
+ beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)
1057
+
1058
+ q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
1059
+ k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
1060
+ v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
1061
+ beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)
1062
+
1063
+ k_beta = k_chunks * beta_chunks
1064
+ v_beta = v_chunks * beta_chunks
1065
+
1066
+ householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
1067
+ householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)
1068
+
1069
+ # size : N = chunk_size * n
1070
+ inv_hh = fast_invert_matrix(householder, dtype=linear_precision) # [(...),N,N]
1071
+
1072
+ w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
1073
+ u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)
1074
+
1075
+ state_shape = (batch_size, num_heads, n, head_dim, head_dim)
1076
+ if initial_state is not None and initial_state.shape == state_shape:
1077
+ state = initial_state.to(device=query.device, dtype=linear_precision)
1078
+ else:
1079
+ state = torch.full(
1080
+ state_shape,
1081
+ fill_value=1e-6, # stability if nonlinear activation
1082
+ device=query.device,
1083
+ dtype=linear_precision,
1084
+ )
1085
+
1086
+ output, final_state = sequential_delta_product_scan(
1087
+ q_chunks.to(dtype=linear_precision),
1088
+ w.to(dtype=linear_precision),
1089
+ u.to(dtype=linear_precision),
1090
+ n,
1091
+ linear,
1092
+ chunk_size,
1093
+ state.to(dtype=linear_precision),
1094
+ linear_precision=linear_precision,
1095
+ use_checkpoint=use_checkpoint,
1096
+ )
1097
+
1098
+ idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
1099
+ output = output[:, :, :, idx_last_order, :] # [B, H, num_chunks, chunk_size, D]
1100
+ output = output.reshape(batch_size, num_heads, seq_len, head_dim)
1101
+
1102
+ return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
1103
+
1104
+
1105
+ def sequential_delta_product_scan(
1106
+ q_chunks: torch.Tensor,
1107
+ w: torch.Tensor,
1108
+ u: torch.Tensor,
1109
+ n_orders: int,
1110
+ linear_activation: bool,
1111
+ current_chunk_size: int,
1112
+ initial_recurrent_state: torch.Tensor,
1113
+ linear_precision: torch.dtype,
1114
+ use_checkpoint: bool,
1115
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1116
+ """
1117
+ DeltaProduct implementation https://arxiv.org/abs/2502.10297
1118
+ Implements the per-token Householder state updates.
1119
+ """
1120
+ batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
1121
+ output_inner = torch.empty_like(q_chunks)
1122
+ # initial_recurrent_state is H_{last_token_of_prev_chunk, n-1} ([B, H, D, D])
1123
+ h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()
1124
+
1125
+ def process_one_chunk(
1126
+ q_chunk_params: torch.Tensor,
1127
+ w_chunk_params: torch.Tensor,
1128
+ u_chunk_params: torch.Tensor,
1129
+ h_0_base: torch.Tensor,
1130
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1131
+ """
1132
+ Process a single chunk (with per-token state for n_orders > 1).
1133
+ """
1134
+ o_intra_current_chunk = torch.zeros(
1135
+ batch,
1136
+ head,
1137
+ chunk_n_total,
1138
+ dim,
1139
+ device=q_chunk_params.device,
1140
+ dtype=linear_precision,
1141
+ )
1142
+ o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
1143
+ current_accumulated_state_per_token = (
1144
+ h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
1145
+ ) # [B, H, current_chunk_size, D, D]
1146
+
1147
+ for step in range(n_orders):
1148
+ idx_virtual_tokens = (
1149
+ torch.arange(current_chunk_size, device=q_chunk_params.device)
1150
+ * n_orders
1151
+ + step
1152
+ )
1153
+ q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
1154
+ w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
1155
+ u_s = u_chunk_params[:, :, idx_virtual_tokens, :]
1156
+
1157
+ state_input_for_this_step = current_accumulated_state_per_token
1158
+
1159
+ ## BLAS/cuBLAS equivalent of einsum "bhcd,bhcde->bhce"
1160
+ k_trans_h_old = (
1161
+ torch.matmul(
1162
+ w_s.unsqueeze(-2),
1163
+ state_input_for_this_step,
1164
+ )
1165
+ .squeeze(-2)
1166
+ .to(dtype=linear_precision)
1167
+ )
1168
+
1169
+ u_val = u_s - k_trans_h_old
1170
+
1171
+ o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
1172
+ torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
1173
+ .squeeze(-2)
1174
+ .to(dtype=linear_precision)
1175
+ )
1176
+
1177
+ ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
1178
+ o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
1179
+ dtype=linear_precision
1180
+ )
1181
+
1182
+ outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
1183
+ new_state_i_per_token = state_input_for_this_step + outer_product_term
1184
+ current_accumulated_state_per_token = new_state_i_per_token.to(
1185
+ dtype=linear_precision
1186
+ )
1187
+ # Return all needed for next chunk
1188
+ return (
1189
+ o_intra_current_chunk,
1190
+ o_inter_current_chunk,
1191
+ current_accumulated_state_per_token[:, :, -1, :, :], # new h_0_base
1192
+ )
1193
+
1194
+ for chunk_idx_inner in range(num_chunks_inner):
1195
+ q_chunk_params = q_chunks[:, :, chunk_idx_inner]
1196
+ w_chunk_params = w[:, :, chunk_idx_inner]
1197
+ u_chunk_params = u[:, :, chunk_idx_inner]
1198
+
1199
+ # Checkpointed call if training
1200
+ call = (
1201
+ partial(checkpoint, use_reentrant=False)
1202
+ if use_checkpoint
1203
+ else lambda f, *a: f(*a)
1204
+ )
1205
+ o_intra, o_inter, h_0_base = call(
1206
+ process_one_chunk,
1207
+ q_chunk_params,
1208
+ w_chunk_params,
1209
+ u_chunk_params,
1210
+ h_0_base,
1211
+ )
1212
+ if not linear_activation: # nonlinear activation between chunks
1213
+ h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
1214
+ output_inner[:, :, chunk_idx_inner] = o_intra + o_inter
1215
+
1216
+ return output_inner, h_0_base
1217
+
1218
+
1219
+ def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
1220
+ """Nonlinear activation applied between chunks."""
1221
+ x_n = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
1222
+ x_gelu = F.gelu(scale * x / x_n, approximate="tanh") # pylint: disable=not-callable
1223
+ return (x / scale) * x_gelu
1224
+
1225
+
1226
+ def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
1227
+ """Splits [B, H, S, D] to [B, H, num_chunks, chunk_size, D]"""
1228
+ batch_size, num_heads, _, head_dim = x.shape
1229
+ return x.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
1230
+
1231
+
1232
+ def expand_virtual_tokens(
1233
+ x: torch.Tensor, n: int, mode: str = "derivative"
1234
+ ) -> torch.Tensor:
1235
+ """Expand tokens into 'n' virtual tokens using the selected trick."""
1236
+ batch_size, num_heads, seq_len, head_dim = x.shape
1237
+ device, dtype = x.device, x.dtype
1238
+
1239
+ def derivative_expand(x: torch.Tensor) -> torch.Tensor:
1240
+ """Expand tokens using the derivative trick."""
1241
+ x_pad = torch.cat(
1242
+ [
1243
+ torch.zeros(
1244
+ batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
1245
+ ),
1246
+ x,
1247
+ ],
1248
+ dim=2,
1249
+ )
1250
+ coeffs = torch.tensor(
1251
+ [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
1252
+ device=device,
1253
+ dtype=dtype,
1254
+ )
1255
+ coeffs /= coeffs.norm(p=1)
1256
+ return (
1257
+ (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
1258
+ .flip(-1)
1259
+ .permute(0, 1, 2, 4, 3)
1260
+ .reshape(batch_size, num_heads, seq_len * n, head_dim)
1261
+ )
1262
+
1263
+ def rotative_expand(x: torch.Tensor) -> torch.Tensor:
1264
+ """Expand tokens using the rotative trick."""
1265
+ d_parity = head_dim // 2
1266
+ angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
1267
+ cos = torch.cos(angles).view(1, 1, 1, n, 1)
1268
+ sin = torch.sin(angles).view(1, 1, 1, n, 1)
1269
+ if head_dim % 2:
1270
+ x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
1271
+ else:
1272
+ x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
1273
+ x_pairs = x_pairs.unsqueeze(3).expand(
1274
+ batch_size, num_heads, seq_len, n, d_parity, 2
1275
+ )
1276
+ x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
1277
+ x0r = x0 * cos - x1 * sin
1278
+ x1r = x0 * sin + x1 * cos
1279
+ rot = torch.stack([x0r, x1r], -1).reshape(
1280
+ batch_size, num_heads, seq_len, n, d_parity * 2
1281
+ )
1282
+ if head_dim % 2:
1283
+ last = (
1284
+ x[..., -1]
1285
+ .unsqueeze(-1)
1286
+ .unsqueeze(3)
1287
+ .expand(batch_size, num_heads, seq_len, n, 1)
1288
+ )
1289
+ rot = torch.cat([rot, last], -1)
1290
+ return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)
1291
+
1292
+ if mode == "derivative":
1293
+ return derivative_expand(x)
1294
+ if mode == "rotative":
1295
+ return rotative_expand(x)
1296
+ if mode == "combined":
1297
+ return (derivative_expand(x) + rotative_expand(x)) / 2
1298
+ raise ValueError(f"Unknown mode: {mode}")
1299
+
1300
+
1301
+ def extract_layer_idx(module_name: str) -> int:
1302
+ """Extract the layer index from a module name string."""
1303
+ match = re.search(r"\.(\d+)\.", module_name)
1304
+ if match:
1305
+ return int(match.group(1))
1306
+ return -1
1307
+
1308
+
1309
+ def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
1310
+ """Find the embedding weight in a model module."""
1311
+ for _, child in module.named_modules():
1312
+ if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"):
1313
+ return child.embed_tokens
1314
+ if hasattr(child, "token_embeddings") and hasattr(
1315
+ child.token_embeddings, "weight"
1316
+ ):
1317
+ return child.token_embeddings
1318
+ return None
1319
+
1320
+
1321
+ def set_trainable_parameters(
1322
+ model: PreTrainedModel, trainable_patterns: Optional[List[str]] = None
1323
+ ) -> PreTrainedModel:
1324
+ """Freeze model parameters except trainable_patterns."""
1325
+ if trainable_patterns is None:
1326
+ trainable_patterns = [
1327
+ "q_proj",
1328
+ "k_proj",
1329
+ "v_proj",
1330
+ "o_proj",
1331
+ "qkv_proj",
1332
+ "out_proj",
1333
+ "c_attn",
1334
+ "c_proj",
1335
+ "query",
1336
+ "key",
1337
+ "value",
1338
+ ]
1339
+
1340
+ for name, param in model.named_parameters():
1341
+ param.requires_grad = any(pattern in name for pattern in trainable_patterns)
1342
+
1343
+ trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
1344
+ logger.info("Trainable parameters after freeze: %s", trainable_layers)
1345
+ return model
1346
+
1347
+
1348
+ def ensure_stability(
1349
+ tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
1350
+ ) -> torch.Tensor:
1351
+ """Clamp values to [min_val, max_val] and replace NaN/inf for numerical stability."""
1352
+ dtype = tensor.dtype
1353
+ center = (max_val + min_val) / 2
1354
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
1355
+ tensor = torch.nan_to_num(tensor, nan=center, posinf=max_val, neginf=min_val)
1356
+ return tensor.to(dtype=dtype)
1357
+
1358
+
1359
+ def apply_linear_attention_mask(
1360
+ attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
1361
+ ) -> torch.Tensor:
1362
+ """Zero out padded positions of v using the attention mask reduced to [B, S]."""
1363
+ if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1364
+ mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
1365
+ else:
1366
+ mask = attention_mask.squeeze(
1367
+ dim=tuple(
1368
+ i
1369
+ for i in range(1, attention_mask.dim())
1370
+ if attention_mask.shape[i] == 1
1371
+ )
1372
+ )
1373
+ # Ensure cast to the same dtype as v and convert to binary mask
1374
+ if not (
1375
+ mask.dtype == torch.bool
1376
+ or (
1377
+ mask.dtype in [torch.uint8, torch.int32, torch.int64]
1378
+ and mask.max() <= 1
1379
+ and mask.min() >= 0
1380
+ )
1381
+ ):
1382
+ mask = (mask >= 0).to(v.dtype) # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
1383
+ else:
1384
+ mask = mask.to(v.dtype)
1385
+ # mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
1386
+ if padding_side == "left":
1387
+ mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
1388
+ else: # right padding
1389
+ mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
1390
+ return v * mask
1391
+
1392
+
1393
+ def truncate_attention_mask(
1394
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
1395
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1396
+ """Truncate hidden_states and attention_mask to the last window of size max_length"""
1397
+ seq_dim = 1 # convention: (batch, seq, ...)
1398
+ seq_len = hidden_states.shape[seq_dim]
1399
+ if seq_len > max_length:
1400
+ hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
1401
+ if attention_mask is not None:
1402
+ # mask [batch, seq]
1403
+ if attention_mask.dim() == 2:
1404
+ attention_mask = attention_mask[:, -max_length:]
1405
+ # mask [batch, seq, seq]
1406
+ elif attention_mask.dim() == 3:
1407
+ attention_mask = attention_mask[:, -max_length:, -max_length:]
1408
+ # mask [batch, 1, seq, seq]
1409
+ elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1410
+ attention_mask = attention_mask[:, :, -max_length:, -max_length:]
1411
+ else:
1412
+ raise ValueError(
1413
+ "No dimension in attention_mask matches sequence length of hidden_states."
1414
+ )
1415
+ return hidden_states, attention_mask
+
+
+ def fast_invert_matrix(
+     tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+     """Equivalent to vectorized forward substitution applied to the identity matrix."""
+     tri_tensor = tri_tensor.to(dtype=dtype).clone()
+     chunk_size = tri_tensor.shape[-1]
+
+     for i in range(1, chunk_size):
+         tri_tensor[..., i, :i] = tri_tensor[..., i, :i] + (
+             tri_tensor[..., i, :, None].clone() * tri_tensor[..., :, :i].clone()
+         ).sum(-2)
+
+     tri_tensor = tri_tensor + torch.eye(
+         chunk_size, dtype=dtype, device=tri_tensor.device
+     )
+     return tri_tensor.to(dtype=dtype)
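
A note on what the row-by-row recurrence in `fast_invert_matrix` computes: assuming the input `N` is strictly lower triangular (the function itself does not check this), the loop followed by the added identity yields `(I - N)^-1`. The sketch below replays the same recurrence on plain nested lists so it can be checked without PyTorch. The names `invert_unit_lower` and `matmul` and the 3x3 values are illustrative assumptions, not part of the repository.

```python
def invert_unit_lower(n_mat):
    """Return (I - N)^-1 for a strictly lower-triangular N given as nested lists."""
    size = len(n_mat)
    t = [row[:] for row in n_mat]  # working copy; ends up holding (I - N)^-1 - I
    for i in range(1, size):
        row_i = t[i][:]  # row i still equals the original N[i] at this point
        for c in range(i):
            # Terms with j >= i vanish because N is strictly lower triangular,
            # so summing over j < i matches the full sum in the tensor version.
            t[i][c] = row_i[c] + sum(row_i[j] * t[j][c] for j in range(i))
    return [
        [t[r][c] + (1.0 if r == c else 0.0) for c in range(size)]
        for r in range(size)
    ]


def matmul(a, b):
    """Plain square matrix product on nested lists."""
    size = len(a)
    return [
        [sum(a[r][k] * b[k][c] for k in range(size)) for c in range(size)]
        for r in range(size)
    ]


# Illustrative strictly lower-triangular input (assumed values).
n_mat = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
identity = [[1.0 if r == c else 0.0 for c in range(3)] for r in range(3)]
i_minus_n = [[identity[r][c] - n_mat[r][c] for c in range(3)] for r in range(3)]

inverse = invert_unit_lower(n_mat)
assert matmul(i_minus_n, inverse) == identity
print(inverse)
```

Each row only depends on rows above it, which is why the loop cannot be fully vectorized over `i` and its cost grows with the chunk size.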
+
+
+ def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
+     """Return the largest chunk_size <= chunk_size that divides total_l."""
+     for c in range(min(chunk_size, total_l), 0, -1):
+         if total_l % c == 0:
+             return c
+     return 1
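
To make the divisor search in `get_valid_chunk_size` concrete, here is a small standalone sketch. The function body is copied from the diff; the sample sequence lengths are illustrative assumptions. It shows that a length with few divisors can force a much smaller chunk than the one requested.

```python
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    """Return the largest chunk size <= chunk_size that divides total_l."""
    for c in range(min(chunk_size, total_l), 0, -1):
        if total_l % c == 0:
            return c
    return 1  # only reached when total_l < 1 (the range above is then empty)


# Illustrative sequence lengths, all with a requested chunk size of 64.
examples = {length: get_valid_chunk_size(length, 64) for length in (128, 100, 97, 10)}
print(examples)  # {128: 64, 100: 50, 97: 1, 10: 10}
```

A prime length such as 97 degrades to chunks of size 1, so padding inputs to a multiple of the chunk size may be worth considering when throughput matters.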
+
+
+ ## RARELY
+ def split_qkv(
+     base_attn: nn.Module, qkv: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """Split the QKV tensor into separate Q, K, and V tensors."""
+     num_q_heads = getattr(base_attn, "num_q_heads", None)
+     num_k_heads = getattr(base_attn, "num_k_heads", None)
+     num_v_heads = getattr(base_attn, "num_v_heads", None)
+     head_dim = getattr(base_attn, "head_dim", None)
+
+     # head_dim is required as well: without it the products below raise a TypeError
+     if (
+         num_q_heads is None
+         or num_k_heads is None
+         or num_v_heads is None
+         or head_dim is None
+     ):
+         raise ValueError(
+             "Base attention must have num_q_heads, num_k_heads, num_v_heads, and head_dim defined."
+         )
+
+     q_len = num_q_heads * head_dim
+     k_len = num_k_heads * head_dim
+     v_len = num_v_heads * head_dim
+
+     q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
+     return q, k, v
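
The split widths used by `split_qkv` are simply head count times head dimension for each of Q, K and V. The following is a minimal list-based sketch of that arithmetic, assuming a hypothetical grouped-query layout; `qkv_split_sizes` and `split_row` are illustrative helpers, not functions from the repository.

```python
def qkv_split_sizes(num_q_heads: int, num_k_heads: int, num_v_heads: int, head_dim: int) -> list:
    """Widths of the Q, K and V slices along the fused last dimension."""
    return [num_q_heads * head_dim, num_k_heads * head_dim, num_v_heads * head_dim]


def split_row(row: list, sizes: list) -> list:
    """Cut one fused row into consecutive slices of the given widths."""
    pieces, start = [], 0
    for size in sizes:
        pieces.append(row[start : start + size])
        start += size
    return pieces


# Hypothetical grouped-query layout: 8 query heads, 2 key/value heads, head_dim 4.
sizes = qkv_split_sizes(num_q_heads=8, num_k_heads=2, num_v_heads=2, head_dim=4)
fused_row = list(range(sum(sizes)))  # stand-in for one token's fused projection
q, k, v = split_row(fused_row, sizes)
print(sizes, len(q), len(k), len(v))
```

With fewer key/value heads than query heads, the K and V slices are narrower than Q, which is why an even three-way chunk would be wrong for such a layout.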
+
+
+ ## OPTIONAL
+ def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
+     """Match the size of tensor x along dimension dim to target_size by interpolation"""
+     src_size = x.shape[dim]
+     if src_size == target_size:
+         return x
+     x = torch.moveaxis(x, dim, -1)
+     shape = x.shape
+     if src_size < target_size:
+         # Upsample: linear interpolation along the (now last) axis
+         x = x.reshape(-1, 1, src_size)
+         x = F.interpolate(x, size=target_size, mode="linear", align_corners=False)
+         x = x.reshape(*shape[:-1], target_size)
+     else:
+         # Downsample: the rectangular identity keeps the first target_size entries
+         eye = torch.eye(target_size, src_size, device=x.device, dtype=x.dtype)
+         x = F.linear(x, eye)  # pylint: disable=not-callable
+     x = torch.moveaxis(x, -1, dim)
+     return x
+
+
+ def soft_clamp(
+     x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
+ ) -> torch.Tensor:
+     """Differentiable clamping for stability"""
+     dtype = x.dtype
+     scale = (max_val - min_val) / 2
+     center = (max_val + min_val) / 2
+     return (torch.tanh((x - center) / scale) * scale + center).to(dtype=dtype)
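
The tanh-based formula in `soft_clamp` behaves differently from a hard clamp: it is smooth everywhere, but it also reshapes values that are already inside the range. The scalar sketch below uses `math.tanh` in place of `torch.tanh` to show this side by side. `soft_clamp_scalar` and `hard_clamp_scalar` are illustrative names, and the sample inputs are assumptions.

```python
import math


def soft_clamp_scalar(x: float, min_val: float = 1e-6, max_val: float = 1 - 1e-6) -> float:
    """Scalar analogue of the tanh-based soft clamp."""
    scale = (max_val - min_val) / 2
    center = (max_val + min_val) / 2
    return math.tanh((x - center) / scale) * scale + center


def hard_clamp_scalar(x: float, min_val: float = 1e-6, max_val: float = 1 - 1e-6) -> float:
    """Ordinary clamp, for comparison: identity inside the range, flat outside."""
    return min(max(x, min_val), max_val)


for x in (-5.0, 0.25, 0.5, 0.75, 5.0):
    print(f"x={x:+.2f}  soft={soft_clamp_scalar(x):.4f}  hard={hard_clamp_scalar(x):.4f}")
```

Only the center of the range maps to itself. Other in-range inputs are pulled toward the center, which is the price paid for having a nonzero gradient outside the range.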
+
+
+ def describe(x: torch.Tensor, name="tensor") -> None:
+     """Prints the shape, min, max, mean, and std of a tensor."""
+     stats = (x.min(), x.max(), x.mean(), x.std())
+     print(
+         f"{name} shape: {tuple(x.shape)}, "
+         + f"min: {stats[0]:.4g}, max: {stats[1]:.4g}, "
+         + f"mean: {stats[2]:.4g}, std: {stats[3]:.4g}, "
+         + f"dtype: {x.dtype}, device: {x.device}"
+     )
lora_delta_product_m0.5_constant/runs/Aug27_14-19-44_61e44eee8185/events.out.tfevents.1756304386.61e44eee8185.163.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:572bd4fe44a9b92d07140da1219bbc70bd993f74cdde397d60ddb7c3d10b2dad
+ size 69803
lora_delta_product_m0.5_constant/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "</s>",
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
lora_delta_product_m0.5_constant/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
lora_delta_product_m0.5_constant/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
lora_delta_product_m0.5_constant/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "add_prefix_space": null,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "legacy": false,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "</s>",
+   "padding_side": "right",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": "<unk>",
+   "use_default_system_prompt": false
+ }
modeling_tptt.py CHANGED
@@ -1,90 +1,271 @@
  """
  This module implements the TPTT model with linear attention (LiZA) and LoRA support.
  Author : Fabien FURFARO
  """

  import logging
  import os
  import re
  import shutil
- from typing import Dict, List, Optional

  import torch
  import torch.nn.functional as F
  from einops import rearrange
  from huggingface_hub import hf_hub_download, list_repo_files
- from peft import LoraConfig, get_peft_model
  from safetensors import safe_open
  from torch import nn
- from transformers import AutoModelForCausalLM, DynamicCache, PreTrainedModel
  from transformers.configuration_utils import PretrainedConfig

  from .configuration_tptt import TpttConfig


- def import_fla_ops():
-     """flash linear attention"""
-     if torch.cuda.is_available():
-         try:
-             from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla

-             return fused_chunk_gla, fused_recurrent_gla
-         except ImportError:
-             return None, None
-     return None, None


- fused_chunk_gla, fused_recurrent_gla = import_fla_ops()  # TODO: add all ops

- logger = logging.getLogger(__name__)  # monitoring


- class LCache:
      """
-     Cache for storing intermediate states of linear attention layers.
-     Supports a sliding window if max_length is set.
      """

-     def __init__(self):
-         """
-         Initialize the cache.

-         Args:
-             max_length (Optional[int]): Maximum number of tokens to keep per layer (if set).
-         """
-         self.states: List[Dict[str, torch.Tensor]] = []
-         self.seen_tokens = 0

-     def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
-         """
-         Retrieve the state for the given layer index, if it exists.
-         """
-         if layer_idx < len(self.states):
-             return self.states[layer_idx]
-         return None

-     def update(self, layer_idx: int, **kwargs):
          """
-         Update the cache for a given layer.
-         If max_length is set, keep only the last max_length tokens in any sequence state.
          """
-         detached_kwargs = {}
-         for key, value in kwargs.items():
-             if isinstance(value, torch.Tensor):
-                 value = value.detach()
-             detached_kwargs[key] = value
-
-         if len(self.states) <= layer_idx:
-             self.states.append(detached_kwargs)
          else:
-             self.states[layer_idx].update(detached_kwargs)

-     def reset(self):
-         """
-         Reset the cache and token counter.
-         """
-         self.states.clear()
-         self.seen_tokens = 0


  class LiZAttention(nn.Module):
@@ -94,33 +275,69 @@ class LiZAttention(nn.Module):
94
  self,
95
  base_attn: nn.Module,
96
  layer_idx: int,
97
- base_config, # Backbone Config
98
  linear_cache: Optional[LCache] = None,
99
  operator_mode: str = "delta_rule",
100
- max_self_attn_length: int = 2048,
 
 
 
101
  mag_weight: float = 0.5,
 
102
  max_chunk_size: int = 64,
 
 
 
 
 
103
  ):
104
  super().__init__()
105
- self.base_attn = base_attn
 
 
 
106
  self.base_config = base_config
107
  self.layer_idx = layer_idx
108
  self.max_self_attn_length = max_self_attn_length
 
109
  self.mag_weight = mag_weight
 
110
  self.max_chunk_size = max_chunk_size
111
- self.linear_cache = linear_cache or LCache()
 
 
 
112
  (
113
  self.num_heads,
114
  self.head_dim,
115
  self.num_key_value_heads,
116
  self.num_key_value_groups,
 
117
  ) = self._get_attention_parameters(base_attn, base_config)
118
- self.operator = get_attention_operator(operator_mode)
119
- self.pool_g = nn.AdaptiveAvgPool1d(
120
- output_size=self.head_dim * self.num_key_value_heads
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  )
122
 
123
- def _get_attention_parameters(self, base_attn, base_config):
 
 
124
  """Retrieve the attention parameters from the base attention module."""
125
  # first order base attention module and second order config
126
  num_heads = (
@@ -129,8 +346,15 @@ class LiZAttention(nn.Module):
129
  or getattr(base_config, "num_heads", None)
130
  or getattr(base_config, "num_attention_heads", None)
131
  )
132
- head_dim = getattr(base_attn, "head_dim", None) or getattr(
133
- base_config, "head_dim", None
 
 
 
 
 
 
 
134
  )
135
  num_key_value_heads = (
136
  getattr(base_attn, "num_kv_heads", None)
@@ -141,14 +365,18 @@ class LiZAttention(nn.Module):
141
  num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
142
  num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
143
  )
 
144
  return (
145
  num_heads,
146
  head_dim,
147
  num_key_value_heads,
148
  num_key_value_groups,
 
149
  )
150
 
151
- def _apply_projections(self, hidden_states):
 
 
152
  base_attn = self.base_attn
153
  if hasattr(base_attn, "q_proj"):
154
  # LLama, OLMO and Mistral style
@@ -166,89 +394,51 @@ class LiZAttention(nn.Module):
166
  qkv = base_attn.c_attn(hidden_states)
167
  q, k, v = qkv.chunk(3, dim=-1)
168
  out_proj = base_attn.c_proj
 
 
 
 
 
 
169
  else:
170
  raise ValueError("Unsupported attention module: cannot find projections.")
171
  # Ensure stability
172
- q = torch.clamp(q, min=-1e4, max=1e4)
173
- k = torch.clamp(k, min=-1e4, max=1e4)
174
- v = torch.clamp(v, min=-1e4, max=1e4)
175
  return q, k, v, out_proj
176
 
177
- def _prepare_attn_input(self, q, k, v, gate_norm):
178
- # Gating for linear attn
179
- g = self.pool_g(k)
180
-
181
- # Reshape for multi-head
182
- q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
183
- k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
184
- v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
185
- g = rearrange(g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
186
-
187
- # Repeat for GQA
188
- k = repeat_kv(k, self.num_key_value_groups)
189
- v = repeat_kv(v, self.num_key_value_groups)
190
- g = repeat_kv(g, self.num_key_value_groups)
191
-
192
- ## linear part
193
- q = torch.clamp(F.softmax(q, dim=-1), min=1e-6, max=1 - 1e-6)
194
- k = torch.clamp(F.softmax(k, dim=-1), min=1e-6, max=1 - 1e-6)
195
-
196
- g = F.logsigmoid(g) / gate_norm
197
- g = torch.clamp(g, min=-gate_norm, max=gate_norm)
198
-
199
- # Convert to float32 for numerical stability and get model dtype
200
- q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g))
201
-
202
- return q, k, v, g
203
-
204
- def _process_linear_attn(self, q, k, v, g, out_proj, tensor_dtype, kwargs):
205
- # Retrieve recurrent state from cache (inference only)
206
- if kwargs["use_cache"]:
207
- last_state = self.linear_cache[self.layer_idx]
208
- recurrent_state = (
209
- last_state["recurrent_state"]
210
- if last_state is not None and "recurrent_state" in last_state
211
- else None
212
  )
213
- else:
214
- recurrent_state = None
215
-
216
- # Linear attention
217
- o_lin, recurrent_state = self.operator(
218
- q,
219
- k,
220
- v,
221
- beta=g,
222
- chunk_size=self.max_chunk_size,
223
- recurrent_state=recurrent_state,
224
- )
225
- o_lin = rearrange(o_lin, "b h n d -> b n (h d)").to(tensor_dtype)
226
- o_lin = out_proj(o_lin)
227
- # Ensure stability (o_lin = soft_clamp(o_lin) ?)
228
- o_lin = torch.clamp(o_lin, min=-1e4, max=1e4)
229
-
230
- # Save recurrent state
231
- if kwargs["use_cache"]:
232
- self.linear_cache.update(self.layer_idx, recurrent_state=recurrent_state)
233
- return o_lin
234
-
235
- def _process_self_attn(self, hidden_states, attention_mask, kwargs):
236
- # If cache_implementation="static" -> truncated attention
237
- hidden_states, attention_mask = truncate_attention_mask(
238
- hidden_states, attention_mask, self.max_self_attn_length
239
- )
240
-
241
- if kwargs.get("position_embeddings", None) is not None:
242
- cos, sin = kwargs["position_embeddings"]
243
- cos = cos[:, -self.max_self_attn_length :]
244
- sin = sin[:, -self.max_self_attn_length :]
245
- kwargs["position_embeddings"] = (cos, sin)
246
-
247
- if isinstance(kwargs.get("past_key_value", None), DynamicCache):
248
- # cache management
249
- if len(kwargs["past_key_value"]) > self.layer_idx and self.layer_idx == 0:
250
- kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  # Standard attention (mask and rotation is applied inside)
253
  base_attn_outputs = self.base_attn(
254
  hidden_states,
@@ -271,15 +461,67 @@ class LiZAttention(nn.Module):
271
  o_base = base_attn_outputs
272
  attn_weights, present_key_value, expected_attn_mode = None, None, 1
273
  # Ensure stability
274
- o_base = torch.clamp(o_base, min=-1e4, max=1e4)
275
  return o_base, attn_weights, present_key_value, expected_attn_mode
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  def forward(
278
  self,
279
  hidden_states: torch.Tensor,
280
  attention_mask: Optional[torch.Tensor] = None,
281
  **kwargs,
282
- ):
 
283
  device = hidden_states.device
284
  tensor_dtype = hidden_states.dtype
285
  self.base_attn.to(device)
@@ -287,73 +529,155 @@ class LiZAttention(nn.Module):
287
  if self.training:
288
  kwargs.pop("past_key_value", None)
289
  kwargs["use_cache"] = False
290
- else:
291
- # Force evaluation
292
- kwargs["use_cache"] = True
293
 
294
  kwargs.pop("position_ids", None) # obsolete
295
 
296
- # Apply projections to hidden states
297
- q, k, v, out_proj = self._apply_projections(hidden_states)
298
-
299
- # Manage attention mask (with padding)
300
- if attention_mask is not None:
301
- # attention_mask -> [batch, seq], v: [batch, seq, ...]
302
- v = apply_linear_attention_mask(attention_mask, v)
303
 
304
- # Prepare inputs tensor for linear attn
305
- gate_norm = kwargs.get("gate_logit_normalizer", 16)
306
- q, k, v, g = self._prepare_attn_input(q, k, v, gate_norm)
307
-
308
- # Process linear attn from mask
309
- o_lin = self._process_linear_attn(q, k, v, g, out_proj, tensor_dtype, kwargs)
310
 
311
  # Process self attn with truncation
312
  o_base, attn_weights, present_key_value, expected_attn_mode = (
313
  self._process_self_attn(hidden_states, attention_mask, kwargs)
314
  )
315
 
316
- # Force cast typing
317
- o_lin = o_lin.to(tensor_dtype)
318
- o_base = o_base.to(tensor_dtype)
319
 
320
- # Apply Memory as Gate in self-attention (with max length management)
321
- if o_lin.shape[1] > o_base.shape[1]:
322
- o_padding = torch.zeros_like(o_lin).to(tensor_dtype)
323
- o_padding[:, -o_base.shape[1] :] = o_base
324
- o_base = o_padding # Left PAD mask
325
- elif o_lin.shape[1] != o_base.shape[1]: # Abnormality
326
- left_trunc = min(o_lin.shape[1], o_base.shape[1])
327
- o_lin, o_base = o_lin[:, -left_trunc:], o_base[:, -left_trunc:]
328
- out = self.mag_weight * o_lin + (1 - self.mag_weight) * o_base
329
- # Ensure stability
330
- out = torch.clamp(out, min=-1e4, max=1e4)
331
 
332
  # Return output following transformer convention
333
  if expected_attn_mode == 3:
334
  return out, attn_weights, present_key_value
335
- elif expected_attn_mode == 2:
336
  return out, attn_weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  else:
338
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
 
341
  def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
342
  model: nn.Module,
343
  base_config: PretrainedConfig, # ou LlamaConfig, MistralConfig, etc.
344
- liza_attention: LiZAttention,
345
- target_modules: list,
346
  linear_cache: Optional[LCache] = None,
 
 
347
  operator_mode: str = "delta_rule",
 
 
 
348
  mag_weight: float = 0.5,
 
349
  max_chunk_size: int = 64,
350
- max_self_attn_length: int = 2048,
351
- ):
 
 
 
 
 
352
  """Replace target modules in a model with LiZAttention."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  linear_cache = linear_cache or LCache()
354
  # Inject LiZAttention into the model
355
  for name, _ in model.named_modules():
356
- if name in target_modules:
357
  parent = model
358
  *path, last = name.split(".")
359
  for p in path:
@@ -368,14 +692,45 @@ def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-a
368
  base_config=base_config,
369
  linear_cache=linear_cache,
370
  operator_mode=operator_mode,
 
 
371
  max_self_attn_length=max_self_attn_length,
 
372
  mag_weight=mag_weight,
 
373
  max_chunk_size=max_chunk_size,
 
 
 
 
374
  ),
375
  )
376
  return model, linear_cache
377
 
378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  class TpttModel(PreTrainedModel):
380
  """
381
  TPTT model wrapper with linear attention (LiZA) and LoRA support.
@@ -396,312 +751,562 @@ class TpttModel(PreTrainedModel):
396
  super().__init__(config, **kwargs)
397
  repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
398
 
399
- # 1. Load backbone TODO : support no model.safetensors
400
- self.backbone = AutoModelForCausalLM.from_pretrained(
401
- config.base_model_name, **kwargs
402
- )
403
- self._retie_lm_after_load(**kwargs) # Force lm tie weights
 
 
 
 
 
 
 
 
404
 
405
  # 2. Inject LiZA attention
406
  self.linear_cache = LCache()
407
- self.backbone, self.linear_cache = self.inject_liza_attention(
408
- self.backbone, config, self.linear_cache
409
  )
410
- # 3. Apply LoRA if present and configured
 
411
  if config.lora_config is not None:
412
  lora_config_obj = LoraConfig(**config.lora_config)
413
- self.backbone = get_peft_model(self.backbone, lora_config_obj)
414
- if repo_or_path:
415
- self.load_peft_safetensors(
416
- repo_or_path, token=kwargs.get("token", None)
417
- )
418
-
419
- def load_peft_safetensors(self, src, token=None):
420
- # src: local dir or repo_id
421
- fname = "adapter_model.safetensors"
422
- if os.path.isdir(src):
423
- path = os.path.join(src, fname)
424
- if not os.path.exists(path):
425
- return
426
  else:
427
- if fname not in list_repo_files(src, token=token):
428
- return
429
- path = hf_hub_download(src, fname, token=token)
430
- with safe_open(path, framework="pt") as f:
431
- self.backbone.load_state_dict(
432
- {k: f.get_tensor(k) for k in f.keys()}, strict=False
 
 
 
 
433
  )
 
434
 
435
- @staticmethod
436
- def inject_liza_attention(
437
- backbone,
438
- config,
439
- linear_cache,
 
440
  ):
441
- """
442
- Inject LiZAttention into the specified target modules of the base model.
443
- """
444
- # Find target modules by suffix (e.g., "attn", "attention")
445
- target_modules = [
446
- name
447
- for name, _ in backbone.named_modules()
448
- if any(name.endswith(suffix) for suffix in config.target_modules_names)
449
- ]
450
- if not target_modules:
451
- raise ValueError(
452
- f"Target modules '{config.target_modules_names}' not found in the model."
453
- )
454
- # Inject LiZAttention (external function, not shown here)
455
- return get_tptt_model(
456
- backbone,
457
- base_config=backbone.config,
458
- liza_attention=LiZAttention,
459
- target_modules=target_modules,
460
- linear_cache=linear_cache,
461
- operator_mode=config.operator_mode,
462
- max_self_attn_length=config.max_self_attn_length,
463
- mag_weight=config.mag_weight,
464
- max_chunk_size=config.max_chunk_size,
465
- )
466
-
467
- def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
468
- """
469
- Forward pass. All arguments are passed to the underlying base model.
470
- """
471
  if self.training:
472
  kwargs["use_cache"] = False
473
  kwargs.pop("num_items_in_batch", None)
474
- else:
475
- kwargs["use_cache"] = True
476
- return self.backbone(
 
477
  input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
478
  )
479
 
480
  def generate(self, *args, **kwargs):
481
- # Delegate the generate call to the backbone model, which supports generation
482
- return self.backbone.generate(*args, **kwargs)
483
 
484
  def save_pretrained(self, path: str, **kwargs):
485
  """Save model weights, config, and source code to the given path."""
486
- super().save_pretrained(path, **kwargs)
487
-
488
- # 1. Save PEFT weights and clean adapter config
489
- self._save_peft_weights(path, **kwargs)
 
490
  # 2. Copy Python files for trust_remote_code
491
- self._copy_source_files(path)
492
-
493
- def _save_peft_weights(self, path: str, **kwargs):
494
- """Save PEFT weights and remove redundant adapter config."""
495
- self.backbone.save_pretrained(path, **kwargs)
496
- adapter_config_path = os.path.join(path, "adapter_config.json")
497
- if os.path.exists(adapter_config_path):
498
- os.remove(adapter_config_path)
499
-
500
- def _copy_source_files(self, path: str):
 
 
 
 
501
  """Copy all .py files from package directory for trust_remote_code."""
502
  src_dir = os.path.dirname(os.path.abspath(__file__))
 
 
 
 
 
503
  for fname in os.listdir(src_dir):
504
  if fname.endswith(".py"):
505
  src = os.path.join(src_dir, fname)
506
- dst = os.path.join(path, fname)
507
  shutil.copy2(src, dst)
508
 
509
- def _retie_lm_after_load(self, **kwargs):
510
  """Re-link lm_head after loading external weights."""
511
- embed_lm = find_embedding_lm(self.backbone)
512
- if embed_lm is not None and hasattr(self.backbone, "lm_head"):
513
- if self.backbone.lm_head is None: # ensure lm_head exists
514
- self.backbone.lm_head = nn.Linear(
515
  embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
516
  )
517
  if kwargs.get("tie_word_embeddings", True):
518
- self.backbone.lm_head.weight = embed_lm.weight # share weights
519
  logger.info("Weights of lm_head have been shared with embedding.")
520
  else:
521
- self.backbone.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
522
  logger.info("Weights of lm_head have been cloned from the embedding.")
523
 
524
  @classmethod
525
- def from_pretrained(cls, *args, **kwargs):
526
- model = super().from_pretrained(*args, **kwargs)
527
- model._retie_lm_after_load(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
528
  return model
529
 
530
 
531
  TpttModel.register_for_auto_class("AutoModelForCausalLM")
532
 
533
 
534
- class AttentionOperator(nn.Module):
535
  """Base class for linear attention operators."""
536
 
537
- def __init__(self, mode="delta_rule"):
 
 
 
 
 
 
 
 
 
538
  super().__init__()
539
- self.mode = mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
 
541
- def forward(self, q, k, v, **options):
542
- """Forward pass for the attention operator."""
543
- beta = options.get("beta", None)
544
- chunk_size = options.get("chunk_size", 64)
545
- scale = options.get("scale", 1)
546
- recurrent_state = options.get("recurrent_state", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
 
548
- if self.mode == "delta_rule":
549
- return self.chunk_delta_rule_forward(
550
- q, k, v, beta, chunk_size, initial_state=recurrent_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  )
552
- if self.mode == "gla":
553
- return self.gla_forward(q, k, v, beta, scale, initial_state=recurrent_state)
554
- raise ValueError(f"Unknown operator mode: {self.mode}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
 
556
  @staticmethod
557
- def chunk_delta_rule_forward(
558
- query, key, value, beta, chunk_size, initial_state=None
559
- ):
 
 
 
 
 
 
 
 
 
 
560
  """
561
- Implementation of https://arxiv.org/abs/2406.06484
562
- query, key, value, beta: [batch, num_heads, seq_len, head_dim]
563
- chunk_size: int
564
- initial_state: [batch, num_heads, head_dim, head_dim] or None
565
  """
 
 
 
566
  batch_size, num_heads, seq_len, head_dim = query.shape
567
  chunk_size = get_valid_chunk_size(seq_len, chunk_size)
568
  num_chunks = seq_len // chunk_size
569
 
570
- # Reshape for chunking: [batch, num_heads, num_chunks, chunk_size, head_dim]
571
- q_chunks = query.reshape(
572
- batch_size, num_heads, num_chunks, chunk_size, head_dim
573
- )
574
- k_chunks = key.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
575
- v_chunks = value.reshape(
576
- batch_size, num_heads, num_chunks, chunk_size, head_dim
577
- )
578
- beta_chunks = beta.reshape(
579
- batch_size, num_heads, num_chunks, chunk_size, head_dim
580
- )
581
 
582
- # Output buffer
583
- output = torch.empty_like(q_chunks)
584
- # State: [batch, num_heads, head_dim, head_dim]
585
- expect_state_shape = (batch_size, num_heads, head_dim, head_dim)
586
- if initial_state is not None and initial_state.shape == expect_state_shape:
587
- # Use provided initial state
588
- state = initial_state.to(device=query.device, dtype=query.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  else:
590
- state = torch.zeros(
591
- batch_size,
592
- num_heads,
593
- head_dim,
594
- head_dim,
595
  device=query.device,
596
- dtype=query.dtype,
597
  )
598
 
599
- def process_chunk(q, k, v, b, state):
600
- """
601
- q, k, v, b: [batch, num_heads, chunk_size, head_dim]
602
- state: [batch, num_heads, head_dim, head_dim]
603
- Returns: (output_chunk, new_state)
604
- """
605
- # Clamp to avoid numerical instabilities (not in paper)
606
- k = torch.clamp(k, min=-1e4, max=1e4)
607
- v = torch.clamp(v, min=-1e4, max=1e4)
608
- b = torch.clamp(b, min=1e-6, max=1e4)
609
- q = torch.clamp(q, min=-1e4, max=1e4)
610
-
611
- # Eq. (10): β_t * k_t and β_t * v_t
612
- k_beta = k * b
613
- v_beta = v * b
614
-
615
- # Eq. (11): Lower-triangular matrix T (with -KβK^T off-diagonal, 1 on diagonal)
616
- # T = I - tril(KβK^T, -1)
617
- t_matrix = -(k_beta @ k.transpose(-2, -1)).tril(-1)
618
- t_matrix = torch.clamp(t_matrix, min=-1e4, max=1e4)
619
- t_matrix = t_matrix + torch.eye(
620
- q.shape[-2], device=q.device, dtype=q.dtype
621
- ).unsqueeze(0).unsqueeze(0)
622
-
623
- # Eq. (11): W = T Kβ, U = T Vβ
624
- w_matrix = t_matrix @ k_beta
625
- w_matrix = torch.clamp(w_matrix, min=-1e4, max=1e4)
626
-
627
- u_matrix = t_matrix @ v_beta
628
- u_matrix = torch.clamp(u_matrix, min=-1e4, max=1e4)
629
-
630
- # Eq. (12): u_i = U - W S (S = state)
631
- u_i = u_matrix - torch.matmul(w_matrix, state)
632
-
633
- # Eq. (12): inter-chunk output: q S
634
- o_inter = torch.matmul(q, state)
635
-
636
- # Eq. (12): intra-chunk attention: tril(q K^T)
637
- a_i = (q @ k.transpose(-2, -1)).tril()
638
-
639
- # Eq. (12): intra-chunk output: a_i u_i
640
- o_intra = torch.matmul(a_i, u_i)
641
-
642
- # Eq. (12): state update: S_new = S + K^T u_i
643
- new_state = state + torch.matmul(k.transpose(-2, -1), u_i)
644
- new_state = torch.clamp(new_state, min=-1e4, max=1e4)
645
-
646
- # Eq. (12): output = intra + inter
647
- return o_intra + o_inter, new_state
648
-
649
- for chunk_idx in range(num_chunks):
650
- q = q_chunks[:, :, chunk_idx]
651
- k = k_chunks[:, :, chunk_idx]
652
- v = v_chunks[:, :, chunk_idx]
653
- b = beta_chunks[:, :, chunk_idx]
654
-
655
- chunk_out, state = process_chunk(q, k, v, b, state)
656
- output[:, :, chunk_idx] = chunk_out
657
-
658
- # Reshape back to [batch, num_heads, seq_len, head_dim]
659
  output = output.reshape(batch_size, num_heads, seq_len, head_dim)
660
- return output, state
661
 
662
- @staticmethod
663
- def gla_forward(q, k, v, beta, scale, initial_state=None):
664
- """Forward pass for GLA attention operator."""
665
- if fused_chunk_gla is None or fused_recurrent_gla is None:
666
- raise RuntimeError("GLA kernels are not available: CUDA required.")
667
- if q.shape[-2] > 1:
668
- # Training or sequence length > 1
669
- return fused_chunk_gla(
670
- q,
671
- k,
672
- v,
673
- beta,
674
- scale=scale,
675
- initial_state=initial_state,
676
- output_final_state=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
  )
678
- return fused_recurrent_gla(
679
- q,
680
- k,
681
- v,
682
- beta,
683
- scale=scale,
684
- initial_state=initial_state,
685
- output_final_state=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
686
  )
687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
688
 
689
- def get_attention_operator(mode):
690
- """Factory for AttentionOperator."""
691
- return AttentionOperator(mode=mode)
 
 
 
 
692
 
693
 
694
  def extract_layer_idx(module_name: str) -> int:
695
- """
696
- Extract the layer index from a module name string.
697
- """
698
  match = re.search(r"\.(\d+)\.", module_name)
699
  if match:
700
  return int(match.group(1))
701
  return -1
702
 
703
 
704
- def find_embedding_lm(module):
705
  """Find the embedding weight in a model module."""
706
  for _, child in module.named_modules():
707
  if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"):
@@ -713,40 +1318,51 @@ def find_embedding_lm(module):
713
  return None
714
 
715
 
716
- def soft_clamp(x, min_val=-1e4, max_val=1e4):
717
- """Differentiable clamping for stability"""
718
- scale = (max_val - min_val) / 2
719
- center = (max_val + min_val) / 2
720
- return torch.tanh((x - center) / scale) * scale + center
721
-
 
 
 
 
 
 
 
 
 
 
 
 
722
 
723
- def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
724
- """Repeat key/value heads for grouped query attention (GQA)."""
725
- return x.repeat_interleave(n_rep, dim=1)
726
 
 
 
 
727
 
728
- def split_qkv(base_attn, qkv):
729
- """Split the QKV tensor into separate Q, K, and V tensors."""
730
- num_q_heads = getattr(base_attn, "num_q_heads", None)
731
- num_k_heads = getattr(base_attn, "num_k_heads", None)
732
- num_v_heads = getattr(base_attn, "num_v_heads", None)
733
- head_dim = getattr(base_attn, "head_dim", None)
734
-
735
- q_len = num_q_heads * head_dim
736
- k_len = num_k_heads * head_dim
737
- v_len = num_v_heads * head_dim
738
 
739
- q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
740
- return q, k, v
 
 
 
 
 
 
 
741
 
742
 
743
- def apply_linear_attention_mask(attention_mask, v):
744
- # extract (if) padding mask
 
 
745
  if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
746
- # [batch, 1, seq, seq] -> [batch, seq]
747
  mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
748
  else:
749
- # Squeeze all singleton dims except batch (dim=0)
750
  mask = attention_mask.squeeze(
751
  dim=tuple(
752
  i
@@ -754,16 +1370,30 @@ def apply_linear_attention_mask(attention_mask, v):
754
  if attention_mask.shape[i] == 1
755
  )
756
  )
757
- # handle left padding : mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
758
- mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
759
  return v * mask
760
 
761
 
762
- def truncate_attention_mask(hidden_states, attention_mask, max_length):
763
- """
764
- Truncate hidden_states and attention_mask to the last window of size max_length,
765
- matching the sequence dimension of hidden_states.
766
- """
767
  seq_dim = 1 # convention: (batch, seq, ...)
768
  seq_len = hidden_states.shape[seq_dim]
769
  if seq_len > max_length:
@@ -785,22 +1415,58 @@ def truncate_attention_mask(hidden_states, attention_mask, max_length):
785
  return hidden_states, attention_mask
786
 
787
 
788
  def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
789
- """
790
- Return the largest chunk_size <= chunk_size that divides total_l.
791
- If no chunk_size > 1 fits, return 1.
792
- """
793
  for c in range(min(chunk_size, total_l), 0, -1):
794
  if total_l % c == 0:
795
  return c
796
  return 1
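A few worked calls make the divisor search concrete. The sketch below copies the loop as shown in this diff into a standalone function; the input values are arbitrary examples chosen for illustration.

```python
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    """Largest divisor of total_l that does not exceed chunk_size (at least 1)."""
    for c in range(min(chunk_size, total_l), 0, -1):
        if total_l % c == 0:
            return c
    return 1


print(get_valid_chunk_size(128, 64))  # 64: 64 divides 128
print(get_valid_chunk_size(100, 64))  # 50: largest divisor of 100 that is <= 64
print(get_valid_chunk_size(97, 64))   # 1: 97 is prime
print(get_valid_chunk_size(10, 64))   # 10: capped by the sequence length
```

Note that a prime sequence length degrades to a chunk size of 1, which is the slowest case for chunked processing.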
797
 
798
 
799
  def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
800
- """
801
- Match the size of tensor x along dimension dim to target_size by interpolation
802
- or projection.
803
- """
804
  src_size = x.shape[dim]
805
  if src_size == target_size:
806
  return x
@@ -815,3 +1481,24 @@ def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
815
  x = F.linear(x, eye) # pylint: disable=not-callable
816
  x = torch.moveaxis(x, -1, dim)
817
  return x
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
  """
4
  This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
  Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
  """
8
 
9
  import logging
10
+ import math
11
  import os
12
+ from pathlib import Path
13
  import re
14
  import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
 
18
  import torch
19
  import torch.nn.functional as F
20
  from einops import rearrange
21
  from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
  from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
  from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import (
28
+ AutoConfig,
29
+ AutoModel,
30
+ AutoModelForCausalLM,
31
+ DynamicCache,
32
+ PreTrainedModel,
33
+ )
34
  from transformers.configuration_utils import PretrainedConfig
35
 
36
  from .configuration_tptt import TpttConfig
37
 
38
+ logger = logging.getLogger(__name__) # monitoring
39
 
 
 
 
 
 
40
 
41
+ class LCache:
42
+ """Cache for storing intermediate states of linear attention layers."""
43
+
44
+ def __init__(self):
45
+ """Stores per-layer intermediate states: {layer_idx: state_dict}"""
46
+ self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = (
47
+ {}
48
+ ) # recurrent states and qkv buffers
49
 
50
+ def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
51
+ """Retrieve cached state for a given layer, or None if not present"""
52
+ return self.inputs_states.get(layer_idx, None)
53
 
54
+ def update(self, layer_idx: int, **kwargs):
55
+ """Update a layer's state, detaching tensors to avoid retaining computation graphs"""
56
+ detached_kwargs = {
57
+ k: v.detach() if isinstance(v, torch.Tensor) else v
58
+ for k, v in kwargs.items()
59
+ }
60
+ # Update or create the state for the specified layer
61
+ if layer_idx in self.inputs_states:
62
+ self.inputs_states[layer_idx].update(detached_kwargs)
63
+ else:
64
+ self.inputs_states[layer_idx] = detached_kwargs
65
 
66
+ def reset(self):
67
+ """Clear all cached states."""
68
+ self.inputs_states.clear()
69
 
70
 
71
+ class CausalAvgPool1d(nn.Module):
72
+ """Causal sliding window average (uniform, no shape loss along sequence)"""
73
+
74
+ def __init__(
75
+ self, output_size: int, offsets: Tuple[int, ...] = (0, 1, 2), mode: str = "replicate"
76
+ ):
77
+ super().__init__()
78
+ self.offsets = offsets
79
+ self.mode = mode
80
+ self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ """x: [B, S, F] → [B, S, F → output_size]"""
84
+ x_ = x.transpose(1, 2) # [B, F, S]
85
+ idxs = torch.tensor(self.offsets, device=x.device)
86
+ ksize = idxs.max() - idxs.min() + 1
87
+ w = torch.zeros(ksize, device=x.device, dtype=x.dtype)
88
+ w[idxs - idxs.min()] = 1 / len(self.offsets) # Always uniform weights
89
+ kernel = w.repeat(x_.shape[1], 1).reshape(x_.shape[1], 1, ksize)
90
+ pad_left = -idxs.min().item()
91
+ pad_right = (ksize - 1) - pad_left
92
+ x_pad = F.pad(x_, (pad_left, pad_right), mode=self.mode)
93
+ y = F.conv1d(x_pad, kernel, groups=x_.shape[1]) # pylint: disable=not-callable
94
+ return self.pool(y.transpose(1, 2)) # [B, S, F → output_size]
95
+
96
+
97
+ class LinearAttention(nn.Module):
98
  """
99
+ Linear multi-head attention layer: [B, S, D] -> [B, S, D]
100
+ Projections + gating + efficient linear attention mechanism (TPTT compatible).
101
  """
102
 
103
+ def __init__(
104
+ self,
105
+ hidden_dim: int,
106
+ num_heads: int,
107
+ head_dim: Optional[int] = None,
108
+ num_key_value_heads: Optional[int] = None,
109
+ num_key_value_groups: Optional[int] = None,
110
+ bias: bool = True,
111
+ dropout: Optional[float] = None,
112
+ linear_precision: torch.dtype = torch.float32,
113
+ padding_side: str = "right",
114
+ shared_attn: bool = False, # shared attention
115
+ layer_idx: int = 0,
116
+ operator_mode: str = "delta_rule",
117
+ use_linear_checkpoint: bool = False,
118
+ recurrent_config: Optional[Dict[str, Any]] = None,
119
+ linear_cache: Optional[LCache] = None,
120
+ max_chunk_size: int = 64,
121
+ bidirectional: bool = False, # not used if causal
122
+ pooling_config: Optional[Dict[str, Any]] = None,
123
+ ):
124
+ super().__init__()
125
+ if pooling_config is None:
126
+ pooling_config = {
127
+ "offsets": (0, 1, 2),
128
+ "mode": "replicate",
129
+ }
130
+ self.hidden_dim = hidden_dim
131
+ self.num_heads = num_heads
132
+ self.head_dim = head_dim or hidden_dim // num_heads
133
+ self.num_key_value_heads = num_key_value_heads or num_heads
134
+ self.num_key_value_groups = num_key_value_groups or (
135
+ num_heads // (num_key_value_heads or num_heads)
136
+ )
137
+ self.scaling = self.head_dim**-0.5
138
+ self.linear_precision = linear_precision
139
+ self.padding_side = padding_side
140
 
141
+ self.shared_attn = shared_attn
 
 
 
 
142
 
143
+ if not shared_attn:
144
+ self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
145
+ self.k_proj = nn.Linear(
146
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
147
+ )
148
+ self.v_proj = nn.Linear(
149
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
150
+ )
151
+ self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)
152
 
153
+ self.dropout = nn.Dropout(dropout) if dropout is not None else None
154
+
155
+ self.linear_operator = LinearAttentionOp(
156
+ layer_idx=layer_idx,
157
+ operator_mode=operator_mode,
158
+ use_linear_checkpoint=use_linear_checkpoint,
159
+ recurrent_config=recurrent_config,
160
+ max_chunk_size=max_chunk_size,
161
+ linear_cache=linear_cache,
162
+ linear_precision=linear_precision,
163
+ )
164
+ self.bidirectional = bidirectional
165
+ # Causal average pooling for gating
166
+ self.pooling_config = pooling_config
167
+ self.pool_g = CausalAvgPool1d(
168
+ output_size=self.head_dim * self.num_key_value_heads, **pooling_config
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ x: Union[List[torch.Tensor], torch.Tensor],
174
+ attn_mask: Optional[torch.Tensor] = None,
175
+ out_proj: Optional[nn.Module] = None,
176
+ **kwargs: Any,
177
+ ) -> torch.Tensor:
178
  """
179
+ Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
 
180
  """
181
+
182
+ if not self.shared_attn:
183
+ hidden_states = x[0] if isinstance(x, (list, tuple)) else x
184
+ # Projections
185
+ q = self.q_proj(hidden_states)
186
+ k = self.k_proj(hidden_states)
187
+ v = self.v_proj(hidden_states)
188
+ out_proj = self.out_proj
189
  else:
190
+ # Shared attention <=> no projections here
191
+ q, k, v = x[0], x[1], x[2]
192
+ out_proj = self.out_proj if out_proj is None else out_proj
193
 
194
+ # get dtype and device
195
+ final_dtype, final_device = q.dtype, q.device
196
+ # Masking if needed
197
+ if attn_mask is not None:
198
+ v = apply_linear_attention_mask(attn_mask, v, self.padding_side)
199
+
200
+ # Forget and write gating for linear attention (terms used loosely)
201
+ f_g, w_g = self.pool_g(k), self.pool_g(v)
202
+
203
+ # Reshape for multi-head
204
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
205
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
206
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
207
+
208
+ f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
209
+ w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
210
+
211
+ # Repeat for GQA
212
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
213
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
214
+
215
+ f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
216
+ w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)
217
+
218
+ ## DeltaNet-style: SiLU activation and L2 normalization
219
+ q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
220
+ k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)
221
+
222
+ ## linear stability part
223
+ v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)
224
+
225
+ # Apply sigmoid to forget and write gates
226
+ f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
227
+ w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)
228
+
229
+ # Convert to linear_precision (float32 by default) for numerical stability
230
+ q, k, v, f_g, w_g = (
231
+ x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
232
+ )
233
+ g = (f_g, w_g)
234
+
235
+ # Linear Attention Core, output: [B, H, S, d]
236
+ if self.bidirectional: # Only valid with non-causal attention
237
+ # Forward direction
238
+ out_forward = self.linear_operator(q, k, v, g, **kwargs)
239
+ # Backward direction: flip the input sequence on the time dimension (dim=2)
240
+ kwargs_bwd = kwargs.copy()
241
+ kwargs_bwd["use_cache"] = False
242
+ out_backward = self.linear_operator(
243
+ torch.flip(q, dims=[2]),
244
+ torch.flip(k, dims=[2]),
245
+ torch.flip(v, dims=[2]),
246
+ tuple(torch.flip(t, dims=[2]) for t in g),
247
+ **kwargs_bwd,
248
+ )
249
+ # Flip the output back to restore proper order
250
+ out_backward = torch.flip(out_backward, dims=[2])
251
+ # Fusion: here, simple addition
252
+ out = out_forward + out_backward
253
+ else:
254
+ out = self.linear_operator(q, k, v, g, **kwargs)
255
+
256
+ # Merge heads and project: [B, H, S, d] -> [B, S, H*d] -> Out proj
257
+ out = rearrange(out, "b h s d -> b s (h d)")
258
+ # Normalize output (RMS norm). Note: bidirectional compatibility
259
+ out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
260
+ # Ensure dtype and device consistency
261
+ out = out.to(dtype=final_dtype, device=final_device)
262
+ # Apply output projection
263
+ out = out_proj(out) # [B, S, D]
264
+ out = ensure_stability(out, min_val=-1e4, max_val=1e4)
265
+ # Apply dropout if specified
266
+ if self.dropout is not None:
267
+ out = self.dropout(out)
268
+ return out
269
 
270
 
271
  class LiZAttention(nn.Module):
 
275
  self,
276
  base_attn: nn.Module,
277
  layer_idx: int,
278
+ base_config: PretrainedConfig, # Backbone Config
279
  linear_cache: Optional[LCache] = None,
280
  operator_mode: str = "delta_rule",
281
+ use_linear_checkpoint: bool = False,
282
+ recurrent_config: Optional[Dict[str, Any]] = None,
283
+ max_self_attn_length: Optional[int] = None, # unnecessary
284
+ base_scale_attn: bool = False,
285
  mag_weight: float = 0.5,
286
+ cross_gate: bool = False,
287
  max_chunk_size: int = 64,
288
+ linear_precision: Union[str, torch.dtype] = "float32",
289
+ padding_side: str = "right", # for tokenizer
290
+ disable_linear_attn: bool = False,
291
+ bidirectional: bool = False, # if True, use bidirectional attention
292
+ pooling_config: Optional[Dict[str, Any]] = None,
293
  ):
294
  super().__init__()
295
+ if isinstance(linear_precision, str):
296
+ linear_precision = getattr(torch, linear_precision)
297
+ self.linear_precision = linear_precision
298
+ self.base_attn: nn.Module = base_attn
299
  self.base_config = base_config
300
  self.layer_idx = layer_idx
301
  self.max_self_attn_length = max_self_attn_length
302
+ self.base_scale_attn = base_scale_attn
303
  self.mag_weight = mag_weight
304
+ self.cross_gate = cross_gate
305
  self.max_chunk_size = max_chunk_size
306
+ self.linear_precision = linear_precision
307
+ self.padding_side = padding_side
308
+ self.disable_linear_attn = disable_linear_attn
309
+
310
  (
311
  self.num_heads,
312
  self.head_dim,
313
  self.num_key_value_heads,
314
  self.num_key_value_groups,
315
+ self.hidden_dim,
316
  ) = self._get_attention_parameters(base_attn, base_config)
317
+ self.scaling = self.head_dim**-0.5
318
+
319
+ self.linear_attn = LinearAttention(
320
+ layer_idx=layer_idx,
321
+ shared_attn=True,
322
+ operator_mode=operator_mode,
323
+ use_linear_checkpoint=use_linear_checkpoint,
324
+ recurrent_config=recurrent_config,
325
+ hidden_dim=self.hidden_dim,
326
+ num_heads=self.num_heads,
327
+ head_dim=self.head_dim,
328
+ num_key_value_heads=self.num_key_value_heads,
329
+ num_key_value_groups=self.num_key_value_groups,
330
+ linear_precision=linear_precision,
331
+ linear_cache=linear_cache,
332
+ max_chunk_size=max_chunk_size,
333
+ padding_side=padding_side,
334
+ bidirectional=bidirectional,
335
+ pooling_config=pooling_config,
336
  )
337
 
338
+ def _get_attention_parameters(
339
+ self, base_attn: nn.Module, base_config: PretrainedConfig
340
+ ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]]:
341
  """Retrieve the attention parameters from the base attention module."""
342
  # look in the base attention module first, then fall back to the config
343
  num_heads = (
 
346
  or getattr(base_config, "num_heads", None)
347
  or getattr(base_config, "num_attention_heads", None)
348
  )
349
+ head_dim = (
350
+ getattr(base_attn, "head_dim", None)
351
+ or getattr(base_attn, "attention_head_size", None)
352
+ or getattr(base_config, "head_dim", None)
353
+ or (
354
+ getattr(base_config, "hidden_size", None) // num_heads
355
+ if num_heads and getattr(base_config, "hidden_size", None)
356
+ else None
357
+ )
358
  )
359
  num_key_value_heads = (
360
  getattr(base_attn, "num_kv_heads", None)
 
365
  num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
366
  num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
367
  )
368
+ hidden_dim = getattr(base_config, "hidden_size", None) or head_dim * num_heads
369
  return (
370
  num_heads,
371
  head_dim,
372
  num_key_value_heads,
373
  num_key_value_groups,
374
+ hidden_dim,
375
  )
376
 
377
+ def _apply_shared_projections(
378
+ self, hidden_states: torch.Tensor
379
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
380
  base_attn = self.base_attn
381
  if hasattr(base_attn, "q_proj"):
382
  # Llama, OLMo and Mistral style
 
394
  qkv = base_attn.c_attn(hidden_states)
395
  q, k, v = qkv.chunk(3, dim=-1)
396
  out_proj = base_attn.c_proj
397
+ elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
398
+ # BERT - ViT
399
+ q = base_attn.query(hidden_states)
400
+ k = base_attn.key(hidden_states)
401
+ v = base_attn.value(hidden_states)
402
+ out_proj = getattr(base_attn, "dense", None) # or output.dense
403
  else:
404
  raise ValueError("Unsupported attention module: cannot find projections.")
405
  # Ensure stability
406
+ q = ensure_stability(q, min_val=-1e4, max_val=1e4)
407
+ k = ensure_stability(k, min_val=-1e4, max_val=1e4)
408
+ v = ensure_stability(v, min_val=-1e4, max_val=1e4)
409
  return q, k, v, out_proj
410
 
411
+ def _process_self_attn(
412
+ self,
413
+ hidden_states: torch.Tensor,
414
+ attention_mask: Optional[torch.Tensor],
415
+ kwargs,
416
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
417
+ """Process the self-attention part (with truncation)."""
418
+ if self.max_self_attn_length: # Not needed with sliding-window attention (non-parametric context memory)
419
+ hidden_states, attention_mask = truncate_attention_mask(
420
+ hidden_states, attention_mask, self.max_self_attn_length
421
  )
422
 
423
+ if kwargs.get("position_embeddings", None) is not None:
424
+ cos, sin = kwargs["position_embeddings"]
425
+ cos = cos[:, -self.max_self_attn_length :]
426
+ sin = sin[:, -self.max_self_attn_length :]
427
+ kwargs["position_embeddings"] = (cos, sin)
428
+
429
+ if isinstance(kwargs.get("past_key_value", None), DynamicCache):
430
+ # cache management
431
+ if (
432
+ len(kwargs["past_key_value"]) > self.layer_idx
433
+ and self.layer_idx == 0
434
+ ):
435
+ kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
436
+
437
+ # Ensure attention mask is of the correct dtype and device
438
+ if attention_mask is not None:
439
+ attention_mask = attention_mask.to(
440
+ dtype=hidden_states.dtype, device=hidden_states.device
441
+ )
442
  # Standard attention (mask and rotation is applied inside)
443
  base_attn_outputs = self.base_attn(
444
  hidden_states,
 
461
  o_base = base_attn_outputs
462
  attn_weights, present_key_value, expected_attn_mode = None, None, 1
463
  # Ensure stability
464
+ o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
465
  return o_base, attn_weights, present_key_value, expected_attn_mode
466
 
467
+ def _prepare_attn_mixin(
468
+ self,
469
+ o_lin: torch.Tensor,
470
+ o_base: torch.Tensor,
471
+ tensor_dtype: torch.dtype,
472
+ eps: float = 1e-5,
473
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
474
+ """Prepare linear attn for mixing with self attn."""
475
+ # Force cast typing, shape : [b n (h d)]
476
+ o_lin = o_lin.to(tensor_dtype)
477
+ o_base = o_base.to(tensor_dtype)
478
+ # feature scaling
479
+ if self.base_scale_attn:
480
+ scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
481
+ o_lin = scaler * o_lin
482
+ return o_lin, o_base
483
+
484
+ def _apply_mag(
485
+ self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
486
+ ) -> torch.Tensor:
487
+ """Apply the MAG strategy"""
488
+ # Left-Padding management
489
+ if linear_attention.shape[1] != softmax_attention.shape[1]:
490
+ left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
491
+ linear_attention, softmax_attention = (
492
+ linear_attention[:, -left_trunc:],
493
+ softmax_attention[:, -left_trunc:],
494
+ )
495
+ # NAM : Neural Attention Mixer (with graph forcing)
496
+ mag_weight = torch.tensor(
497
+ self.mag_weight,
498
+ dtype=softmax_attention.dtype,
499
+ device=softmax_attention.device,
500
+ )
501
+ softmax_weighted = (1 - mag_weight) * softmax_attention
502
+ linear_weighted = mag_weight * linear_attention
503
+ if self.cross_gate:
504
+ output_attention = (
505
+ softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
506
+ ) # cross term (nonlinear interaction)
507
+ else:
508
+ output_attention = softmax_weighted + linear_weighted # classic
509
+
510
+ if torch.allclose(softmax_weighted, output_attention):
511
+ logger.info(
512
+ "[LOG] layer : %s, softmax_weighted and output_attention are close.",
513
+ self.layer_idx,
514
+ )
515
+ # Final output
516
+ return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)
517
+
518
  def forward(
519
  self,
520
  hidden_states: torch.Tensor,
521
  attention_mask: Optional[torch.Tensor] = None,
522
  **kwargs,
523
+ ) -> torch.Tensor:
524
+ """Forward pass mixing linear attention and self-attention."""
525
  device = hidden_states.device
526
  tensor_dtype = hidden_states.dtype
527
  self.base_attn.to(device)
 
529
  if self.training:
530
  kwargs.pop("past_key_value", None)
531
  kwargs["use_cache"] = False
532
+ elif "use_cache" not in kwargs:
533
+ kwargs.pop("past_key_value", None)
534
+ kwargs["use_cache"] = False
535
 
536
  kwargs.pop("position_ids", None) # obsolete
537
 
538
+ # Apply shared projections
539
+ q, k, v, out_proj = self._apply_shared_projections(hidden_states)
 
 
 
 
 
540
 
541
+ # Apply linear attention to hidden states
542
+ o_lin = self.linear_attn(
543
+ x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
544
+ )
 
 
545
 
546
  # Process self attn with truncation
547
  o_base, attn_weights, present_key_value, expected_attn_mode = (
548
  self._process_self_attn(hidden_states, attention_mask, kwargs)
549
  )
550
 
551
+ # Prepare output mixing
552
+ o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)
 
553
 
554
+ # Apply Memory as Gate in self-attention (with length management and ablation)
555
+ out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)
 
 
 
 
 
 
 
 
 
556
 
557
  # Return output following transformer convention
558
  if expected_attn_mode == 3:
559
  return out, attn_weights, present_key_value
560
+ if expected_attn_mode == 2:
561
  return out, attn_weights
562
+ return out
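The weighting applied in `_apply_mag` can be followed on plain numbers. The sketch below is a simplified scalar version under stated assumptions: `mix_mag` is a hypothetical name, it works on floats rather than tensors, and it leaves out the sequence-length alignment and the final `ensure_stability` call that the method in this diff performs.

```python
def mix_mag(
    linear_out: float,
    softmax_out: float,
    mag_weight: float = 0.5,
    cross_gate: bool = False,
) -> float:
    """Scalar sketch of the MAG mix between linear and softmax attention outputs."""
    softmax_weighted = (1 - mag_weight) * softmax_out
    linear_weighted = mag_weight * linear_out
    if cross_gate:
        # Adds a multiplicative interaction term on top of the weighted sum.
        return softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
    return softmax_weighted + linear_weighted


print(mix_mag(2.0, 4.0))                   # 3.0 = 0.5 * 4.0 + 0.5 * 2.0
print(mix_mag(2.0, 4.0, cross_gate=True))  # 5.0 = 3.0 + (2.0 * 1.0)
print(mix_mag(2.0, 4.0, mag_weight=0.0))   # 4.0: softmax attention only
```

With `mag_weight=0.0` the linear branch contributes nothing, which is the situation the `torch.allclose` log message in the method is meant to flag.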
563
+
564
+
565
+ def load_tptt_safetensors(
566
+ repo_or_path: str,
567
+ model: Union[PreTrainedModel, PeftModel],
568
+ subfolder: Optional[str] = None,
569
+ token: Optional[str] = None,
570
+ ) -> Union[PreTrainedModel, PeftModel]:
571
+ """Load TPTT LoRA/PEFT weights from a safetensors file and adapt keys if needed."""
572
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
573
+ fname = "adapter_model.safetensors"
574
+ # subfolder management
575
+ if subfolder:
576
+ repo_or_path_norm = os.path.normpath(repo_or_path)
577
+ subfolder_norm = os.path.normpath(subfolder)
578
+ if not repo_or_path_norm.endswith(subfolder_norm):
579
+ fname = f"{subfolder}/{fname}"
580
+ # Find file path
581
+ if os.path.isdir(repo_or_path):
582
+ path = os.path.join(repo_or_path, fname)
583
+ if not os.path.exists(path):
584
+ return model
585
+ else:
586
+ if fname not in list_repo_files(repo_or_path, token=token):
587
+ return model
588
+ path = hf_hub_download(repo_or_path, fname, token=token)
589
+
590
+ # Load weights from safetensors
591
+ with safe_open(path, framework="pt") as f:
592
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
593
+
594
+ # Adapt LoRA/Specific keys if needed (add .default if expected by the model)
595
+ def adapt_keys(sd, model):
596
+ model_keys = list(model.state_dict().keys())
597
+ if any(k.startswith("tptt_model.base_model.") for k in model_keys):
598
+ prefix = "tptt_model.base_model."
599
+ elif any(k.startswith("base_model.") for k in model_keys):
600
+ prefix = "base_model."
601
  else:
602
+ prefix = ""
603
+
604
+ has_base_attn = any(".base_attn." in k for k in model_keys)
605
+
606
+ def adapt_key(k):
607
+ k_ = k if k.startswith(prefix) else prefix + k
608
+ # first, verify and modify base_attn (LiZA)
609
+ if ".base_attn." in k_ and not has_base_attn:
610
+ k_ = k_.replace(".base_attn.", ".")
611
+ # change LoRA if needed
612
+ if (
613
+ k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
614
+ ) and k_.replace(".weight", ".default.weight") in model_keys:
615
+ k_ = k_.replace(".weight", ".default.weight")
616
+ return k_
617
+
618
+ return {adapt_key(k): v for k, v in sd.items()}
619
+
620
+ state_dict = adapt_keys(state_dict, model)
621
+
622
+ # Cast tensors to the expected dtype of the model parameters
623
+ model_state_dict = model.state_dict()
624
+ for k, v in state_dict.items():
625
+ if k in model_state_dict:
626
+ expected_dtype = model_state_dict[k].dtype
627
+ if v.dtype != expected_dtype:
628
+ state_dict[k] = v.to(expected_dtype)
629
+
630
+ logger.info("Input LoRA/Specific keys: %s", list(state_dict))
631
+
632
+ # Load into model
633
+ missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
634
+ missing_lora = [k for k in missing if "lora" in k]
635
+ if missing_lora:
636
+ logger.warning("Missing keys: %s", missing_lora)
637
+ if unexpected:
638
+ logger.warning("Unexpected keys: %s", unexpected)
639
+ return model
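The key renaming done by the nested `adapt_keys` helper is plain string manipulation, so it can be tried without any weights. The sketch below mirrors that logic as a standalone function under stated assumptions: it takes a list of model key names instead of a model, and the key names and placeholder value are hypothetical.

```python
from typing import Dict, List


def adapt_keys(saved: Dict[str, object], model_keys: List[str]) -> Dict[str, object]:
    """Rename saved adapter keys so they match the keys the model expects."""
    if any(k.startswith("tptt_model.base_model.") for k in model_keys):
        prefix = "tptt_model.base_model."
    elif any(k.startswith("base_model.") for k in model_keys):
        prefix = "base_model."
    else:
        prefix = ""
    has_base_attn = any(".base_attn." in k for k in model_keys)

    def adapt_key(key: str) -> str:
        new_key = key if key.startswith(prefix) else prefix + key
        # Drop the LiZA wrapper level when the target model does not have it.
        if ".base_attn." in new_key and not has_base_attn:
            new_key = new_key.replace(".base_attn.", ".")
        # Add the PEFT adapter name ("default") when the model expects it.
        if new_key.endswith(("lora_A.weight", "lora_B.weight")):
            candidate = new_key.replace(".weight", ".default.weight")
            if candidate in model_keys:
                new_key = candidate
        return new_key

    return {adapt_key(k): v for k, v in saved.items()}


# Hypothetical key names; the value stands in for a tensor.
model_keys = [
    "tptt_model.base_model.model.layers.0.self_attn.base_attn.q_proj.lora_A.default.weight"
]
saved = {"model.layers.0.self_attn.base_attn.q_proj.lora_A.weight": "tensor"}
adapted = adapt_keys(saved, model_keys)
print(list(adapted))  # the single key from model_keys
```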
640
 
641
 
642
  def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
643
  model: nn.Module,
644
  base_config: PretrainedConfig, # or LlamaConfig, MistralConfig, etc.
 
 
645
  linear_cache: Optional[LCache] = None,
646
+ liza_attention: type[nn.Module] = LiZAttention,
647
+ target_modules_names: Optional[list[str]] = None,
648
  operator_mode: str = "delta_rule",
649
+ use_linear_checkpoint: bool = False,
650
+ recurrent_config: Optional[Dict[str, Any]] = None,
651
+ base_scale_attn: bool = False,
652
  mag_weight: float = 0.5,
653
+ cross_gate: bool = False,
654
  max_chunk_size: int = 64,
655
+ linear_precision: torch.dtype = torch.float32,
656
+ max_self_attn_length: Optional[int] = None, # unnecessary
657
+ padding_side: str = "right", # for tokenizer
658
+ bidirectional: bool = False, # if True, use bidirectional attention
659
+ pooling_config: Optional[Dict[str, Any]] = None,
660
+ **kwargs, # quickfix unexpected arguments
661
+ ) -> Tuple[PreTrainedModel, LCache]:
662
  """Replace target modules in a model with LiZAttention."""
663
+ if target_modules_names is None:
664
+ target_modules_names = ["attn", "self_attn", "attention"]
665
+ # Find target modules by suffix (e.g., "attn", "attention")
666
+ target_modules_names = [
667
+ name
668
+ for name, _ in model.named_modules()
669
+ if any(name.endswith(suffix) for suffix in target_modules_names)
670
+ and not any(f".{suffix}." in name for suffix in target_modules_names)
671
+ ]
672
+ if not target_modules_names:
673
+ raise ValueError(
674
+ "No module matching the target suffixes was found in the model."
675
+ )
676
+ # Prepare the linear cache
677
  linear_cache = linear_cache or LCache()
678
  # Inject LiZAttention into the model
679
  for name, _ in model.named_modules():
680
+ if name in target_modules_names:
681
  parent = model
682
  *path, last = name.split(".")
683
  for p in path:
 
692
  base_config=base_config,
693
  linear_cache=linear_cache,
694
  operator_mode=operator_mode,
695
+ use_linear_checkpoint=use_linear_checkpoint,
696
+ recurrent_config=recurrent_config,
697
  max_self_attn_length=max_self_attn_length,
698
+ base_scale_attn=base_scale_attn,
699
  mag_weight=mag_weight,
700
+ cross_gate=cross_gate,
701
  max_chunk_size=max_chunk_size,
702
+ linear_precision=linear_precision,
703
+ padding_side=padding_side,
704
+ bidirectional=bidirectional,
705
+ pooling_config=pooling_config,
706
  ),
707
  )
708
  return model, linear_cache
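The suffix filter used to pick target modules can be exercised on a list of names alone. The sketch below is an illustration under stated assumptions: `select_target_modules` is a hypothetical helper, and the module names are invented to show both a match and a nested module that gets excluded.

```python
from typing import List


def select_target_modules(module_names: List[str], suffixes: List[str]) -> List[str]:
    """Keep names ending in a target suffix, skipping modules nested under one."""
    return [
        name
        for name in module_names
        if any(name.endswith(suffix) for suffix in suffixes)
        and not any(f".{suffix}." in name for suffix in suffixes)
    ]


# Hypothetical module names, for illustration only.
names = [
    "model.layers.0.self_attn",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp",
    "transformer.h.0.attn",
    "transformer.h.0.attn.attention",
]
selected = select_target_modules(names, ["attn", "self_attn", "attention"])
print(selected)  # ['model.layers.0.self_attn', 'transformer.h.0.attn']
```

The second condition is what prevents an inner module such as `attn.attention` from being wrapped a second time inside an already replaced attention block.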
709
 
710
 
711
+ def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
712
+ """Save trainable LoRA/specific weights, adapting key names."""
713
+ # 1. Get the full state_dict
714
+ all_sd = model.state_dict()
715
+
716
+ # 2. Identify trainable parameter names (usually only LoRA/PEFT adapters)
717
+ trainable_keys = [
718
+ name for name, param in model.named_parameters() if param.requires_grad
719
+ ] # Alternatively, select specific keys manually after loading the model
720
+
721
+ # 3. Filter and adapt the keys (Remove custom model encapsulation info)
722
+ to_save = {
723
+ k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
724
+ for k in trainable_keys
725
+ }
726
+
727
+ # 4. Save the filtered adapters to a safetensors file
728
+ if to_save:
729
+ os.makedirs(path, exist_ok=True)
730
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
731
+ save_file(to_save, os.path.join(path, name))
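The key cleanup applied before saving is two chained `str.replace` calls. The sketch below isolates that step; `strip_wrapper_prefixes` is a hypothetical name and the key is an invented example of what a PEFT-wrapped parameter name might look like.

```python
def strip_wrapper_prefixes(key: str) -> str:
    """Remove the TPTT and PEFT wrapper segments from a parameter name."""
    return key.replace("tptt_model.", "").replace("base_model.", "")


# Hypothetical parameter name, for illustration only.
key = "tptt_model.base_model.model.model.layers.0.self_attn.base_attn.q_proj.lora_A.default.weight"
print(strip_wrapper_prefixes(key))
# model.model.layers.0.self_attn.base_attn.q_proj.lora_A.default.weight
```

Because `str.replace` acts on every occurrence, a key that contained `base_model.` deeper in its path would lose that segment too, not only the leading one.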
732
+
733
+
734
  class TpttModel(PreTrainedModel):
735
  """
736
  TPTT model wrapper with linear attention (LiZA) and LoRA support.
 
751
  super().__init__(config, **kwargs)
752
  repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
753
 
754
+ # 1. Load backbone (with subfolder management) :
755
+ kwargs_bb = kwargs.copy()
756
+ if config.base_model_subfolder is not None:
757
+ kwargs_bb["subfolder"] = config.base_model_subfolder
758
+ else:
759
+ kwargs_bb.pop("subfolder", None)
760
+
761
+ if config.model_task == "causal_lm":
762
+ tptt_model = AutoModelForCausalLM.from_pretrained(
763
+ config.base_model_name, **kwargs_bb
764
+ )
765
+ else:
766
+ tptt_model = AutoModel.from_pretrained(config.base_model_name, **kwargs_bb)
767
 
768
  # 2. Inject LiZA attention
769
  self.linear_cache = LCache()
770
+ tptt_model, self.linear_cache = get_tptt_model(
771
+ tptt_model, config, self.linear_cache, **config.to_dict()
772
  )
773
+
774
+ # 3. Apply LoRA/Specific if present and configured
775
  if config.lora_config is not None:
776
  lora_config_obj = LoraConfig(**config.lora_config)
777
+ tptt_model = get_peft_model(tptt_model, lora_config_obj)
778
  else:
779
+ # Does not work if quantization is applied!
780
+ tptt_model = set_trainable_parameters(tptt_model)
781
+
782
+ # 4. Load safetensors weights if a TPTT/PEFT adapter is present in the repo
783
+ if repo_or_path:
784
+ tptt_model = load_tptt_safetensors(
785
+ repo_or_path,
786
+ tptt_model,
787
+ subfolder=kwargs.get("subfolder", None),
788
+ token=kwargs.get("token", None),
789
  )
790
+ self.tptt_model = tptt_model
791
 
792
+ def forward(
793
+ self,
794
+ input_ids: Optional[torch.LongTensor] = None,
795
+ attention_mask: Optional[torch.Tensor] = None,
796
+ labels: Optional[torch.LongTensor] = None,
797
+ **kwargs,
798
  ):
799
+ """Forward pass. All arguments are passed to the underlying base model."""
800
  if self.training:
801
  kwargs["use_cache"] = False
802
  kwargs.pop("num_items_in_batch", None)
803
+ elif "use_cache" not in kwargs: # evaluation
804
+ kwargs.pop("num_items_in_batch", None)
805
+ kwargs["use_cache"] = False
806
+ return self.tptt_model(
807
  input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
808
  )
809
 
810
  def generate(self, *args, **kwargs):
811
+ """Delegate the generate call to the backbone model, which supports generation"""
812
+ return self.tptt_model.generate(*args, **kwargs)
813
 
814
  def save_pretrained(self, path: str, **kwargs):
815
  """Save model weights, config, and source code to the given path."""
816
+ # 0. Save complete tptt config (with or without LoRA)
817
+ super().save_pretrained(path, **kwargs) # pylint: disable=no-member
818
+ self._adjust_save_strategy(path, **kwargs)
819
+ # 1. Save the trainable weights and adapt keys
820
+ save_tptt_safetensors(self, path)
821
  # 2. Copy Python files for trust_remote_code
822
+ self._copy_source_files(path, **kwargs)
823
+
824
+ def _adjust_save_strategy(self, path: str, **kwargs):
825
+ """Save the PEFT adapter, then remove the model safetensors file and the saved adapter config"""
826
+ if isinstance(self.tptt_model, PeftModel):
827
+ self.tptt_model.save_pretrained(path, **kwargs)
828
+ safetensor_path = os.path.join(path, "model.safetensors")
829
+ if os.path.exists(safetensor_path):
830
+ os.remove(safetensor_path)
831
+ adapter_path = os.path.join(path, "adapter_config.json")
832
+ if os.path.exists(adapter_path):
833
+ os.remove(adapter_path)
834
+
835
+ def _copy_source_files(self, target_path: str, **kwargs):
836
  """Copy all .py files from package directory for trust_remote_code."""
837
  src_dir = os.path.dirname(os.path.abspath(__file__))
838
+ dst_dir = (
839
+ f"./{str(Path(target_path).parts[0])}"
840
+ if kwargs.get("subfolder", False)
841
+ else target_path
842
+ )
843
  for fname in os.listdir(src_dir):
844
  if fname.endswith(".py"):
845
  src = os.path.join(src_dir, fname)
846
+ dst = os.path.join(dst_dir, fname)
847
  shutil.copy2(src, dst)
848
 
849
+ def retie_lm_after_load(self, **kwargs):
850
  """Re-link lm_head after loading external weights."""
851
+ embed_lm = find_embedding_lm(self.tptt_model)
852
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
853
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
854
+ self.tptt_model.lm_head = nn.Linear(
855
  embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
856
  )
857
  if kwargs.get("tie_word_embeddings", True):
858
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
859
  logger.info("Weights of lm_head have been shared with embedding.")
860
  else:
861
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
862
  logger.info("Weights of lm_head have been cloned from the embedding.")
863
 
864
  @classmethod
865
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
866
+ """Custom from_pretrained that accepts the standard positional argument"""
867
+ config = kwargs.pop("config", None)
868
+ repo_or_path = (
869
+ pretrained_model_name_or_path
870
+ or kwargs.pop("pretrained_model_name_or_path", None)
871
+ or kwargs.pop("repo_or_path", None)
872
+ or (getattr(config, "_base_path", None) if config else None)
873
+ or (getattr(config, "_name_or_path", None) if config else None)
874
+ )
875
+
876
+ if config is None and repo_or_path is not None:
877
+ config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
878
+ model = cls(config, *model_args, **kwargs)
879
+ model.retie_lm_after_load(**kwargs)
880
  return model
881
 
882
 
883
  TpttModel.register_for_auto_class("AutoModelForCausalLM")
884
 
885
 
886
+ class LinearAttentionOp(nn.Module):
887
  """Base class for linear attention operators."""
888
 
889
+ def __init__(
890
+ self,
891
+ layer_idx: int,
892
+ operator_mode: str = "delta_rule",
893
+ use_linear_checkpoint: bool = False,
894
+ recurrent_config: Optional[dict] = None,
895
+ max_chunk_size: int = 64,
896
+ linear_cache: Optional[LCache] = None,
897
+ linear_precision: torch.dtype = torch.float32,
898
+ ):
899
  super().__init__()
900
+ self.layer_idx = layer_idx
901
+ if recurrent_config is None:
902
+ operator_mode = "delta_rule" # force default operator mode if no config
903
+ recurrent_config = {
904
+ "order": 1,
905
+ "gate_type": "k",
906
+ "linear": True,
907
+ "trick": "derivative",
908
+ }
909
+ self.operator_mode = operator_mode
910
+ self.use_linear_checkpoint = use_linear_checkpoint
911
+
912
+ self.order = recurrent_config["order"]
913
+ self.gate_type = recurrent_config["gate_type"]
914
+ self.linear = recurrent_config["linear"]
915
+ self.trick = recurrent_config["trick"]
916
+
917
+ self.max_chunk_size = max_chunk_size
918
+ self.linear_cache = linear_cache or LCache()
919
+ self.linear_precision = linear_precision
920
+
921
+ def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
922
+ """
923
+ Compute the gating tensor according to the gate_type.
924
+ """
925
+ if self.gate_type == "k":
926
+ return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
927
+ if self.gate_type == "v":
928
+ return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
929
+ if self.gate_type == "kv":
930
+ return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
931
+ raise ValueError(f"Unsupported gate_type: {self.gate_type}")
932
+
933
+ def get_cache(self, use_cache: bool) -> Tuple[
934
+ Optional[torch.Tensor],
935
+ Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
936
+ ]:
937
+ """
938
+ Retrieve recurrent state and qkv buffers from the cache.
939
+ """
940
+ if not use_cache:
941
+ return None, None
942
+ last_state = self.linear_cache[self.layer_idx]
943
+ if last_state is not None:
944
+ recurrent_state = last_state.get("recurrent_state", None)
945
+ qkv_buffers = last_state.get("qkv", None)
946
+ else:
947
+ recurrent_state = None
948
+ qkv_buffers = None
949
+ return recurrent_state, qkv_buffers
950
 
951
+ def save_cache(
952
+ self,
953
+ use_cache: bool,
954
+ q: torch.Tensor,
955
+ k: torch.Tensor,
956
+ v: torch.Tensor,
957
+ gate: torch.Tensor,
958
+ state: torch.Tensor,
959
+ ) -> None:
960
+ """
961
+ Save the recurrent state and qkv buffers to the cache.
962
+ """
963
+ if not use_cache:
964
+ return
965
+ if self.order > 1:
966
+ qkv_buffers = (
967
+ q[:, :, -(self.order - 1) :, :],
968
+ k[:, :, -(self.order - 1) :, :],
969
+ v[:, :, -(self.order - 1) :, :],
970
+ gate[:, :, -(self.order - 1) :, :],
971
+ )
972
+ else:
973
+ qkv_buffers = None
974
+ self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)
975
 
976
+ def forward(
977
+ self,
978
+ q: torch.Tensor,
979
+ k: torch.Tensor,
980
+ v: torch.Tensor,
981
+ beta: Union[Tuple[torch.Tensor], torch.Tensor],
982
+ **kwargs,
983
+ ) -> torch.Tensor:
984
+ """
985
+ Forward pass for the attention operator.
986
+ """
987
+ # Cast inputs to linear_precision (float32 by default) for numerical stability
988
+ q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
989
+ if isinstance(beta, (tuple, list)):
990
+ beta = tuple(b.to(self.linear_precision) for b in beta)
991
+ else:
992
+ beta = beta.to(self.linear_precision)
993
+
994
+ gate = self.compute_gate(beta)
995
+
996
+ # Retrieve cache if needed
997
+ use_cache = kwargs.get("use_cache", False)
998
+ use_checkpoint = not (use_cache) and self.use_linear_checkpoint
999
+ recurrent_state, qkvb = self.get_cache(use_cache)
1000
+
1001
+ if qkvb is not None and qkvb[0].shape == q.shape:
1002
+ q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
1003
+ k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
1004
+ v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
1005
+ gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
1006
+ self.linear_precision
1007
  )
1008
+
1009
+ output, state = self.chunk_delta_product_forward(
1010
+ q,
1011
+ k,
1012
+ v,
1013
+ gate,
1014
+ self.max_chunk_size,
1015
+ n=self.order,
1016
+ trick=self.trick,
1017
+ linear=self.linear,
1018
+ initial_state=recurrent_state,
1019
+ use_checkpoint=use_checkpoint,
1020
+ linear_precision=self.linear_precision,
1021
+ )
1022
+
1023
+ # Save cache if needed
1024
+ self.save_cache(use_cache, q, k, v, gate, state)
1025
+
1026
+ return output
1027
 
1028
  @staticmethod
1029
+ def chunk_delta_product_forward(
1030
+ query: torch.Tensor,
1031
+ key: torch.Tensor,
1032
+ value: torch.Tensor,
1033
+ beta_gate: torch.Tensor,
1034
+ chunk_size: int,
1035
+ n: int = 1,
1036
+ trick: str = "derivative",
1037
+ linear: bool = True,
1038
+ initial_state: Optional[torch.Tensor] = None,
1039
+ use_checkpoint: bool = True,
1040
+ linear_precision: torch.dtype = torch.float32,
1041
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1042
  """
1043
+ Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
1044
+ For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
 
 
1045
  """
1046
+
1047
+ # --- Main chunk_delta_product_forward logic ---
1048
+
1049
  batch_size, num_heads, seq_len, head_dim = query.shape
1050
  chunk_size = get_valid_chunk_size(seq_len, chunk_size)
1051
  num_chunks = seq_len // chunk_size
1052
 
1053
+ query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
1054
+ key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
1055
+ value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
1056
+ beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)
 
 
 
 
 
 
 
1057
 
1058
+ q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
1059
+ k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
1060
+ v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
1061
+ beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)
1062
+
1063
+ k_beta = k_chunks * beta_chunks
1064
+ v_beta = v_chunks * beta_chunks
1065
+
1066
+ householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
1067
+ householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)
1068
+
1069
+ # size : N = chunk_size * n
1070
+ inv_hh = fast_invert_matrix(householder, dtype=linear_precision) # [(...),N,N]
1071
+
1072
+ w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
1073
+ u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)
1074
+
1075
+ state_shape = (batch_size, num_heads, n, head_dim, head_dim)
1076
+ if initial_state is not None and initial_state.shape == state_shape:
1077
+ state = initial_state.to(device=query.device, dtype=linear_precision)
1078
  else:
1079
+ state = torch.full(
1080
+ state_shape,
1081
+ fill_value=1e-6, # small nonzero init for stability with the nonlinear activation
 
 
1082
  device=query.device,
1083
+ dtype=linear_precision,
1084
  )
1085
 
1086
+ output, final_state = sequential_delta_product_scan(
1087
+ q_chunks.to(dtype=linear_precision),
1088
+ w.to(dtype=linear_precision),
1089
+ u.to(dtype=linear_precision),
1090
+ n,
1091
+ linear,
1092
+ chunk_size,
1093
+ state.to(dtype=linear_precision),
1094
+ linear_precision=linear_precision,
1095
+ use_checkpoint=use_checkpoint,
1096
+ )
1097
+
1098
+ idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
1099
+ output = output[:, :, :, idx_last_order, :] # [B, H, num_chunks, chunk_size, D]
1100
  output = output.reshape(batch_size, num_heads, seq_len, head_dim)
 
1101
 
1102
+ return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
1103
+
1104
+
1105
+ def sequential_delta_product_scan(
1106
+ q_chunks: torch.Tensor,
1107
+ w: torch.Tensor,
1108
+ u: torch.Tensor,
1109
+ n_orders: int,
1110
+ linear_activation: bool,
1111
+ current_chunk_size: int,
1112
+ initial_recurrent_state: torch.Tensor,
1113
+ linear_precision: torch.dtype,
1114
+ use_checkpoint: bool,
1115
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1116
+ """
1117
+ DeltaProduct implementation https://arxiv.org/abs/2502.10297
1118
+ Implements the per-token Householder state updates.
1119
+ """
1120
+ batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
1121
+ output_inner = torch.empty_like(q_chunks)
1122
+ # initial_recurrent_state is [B, H, n, D, D]; keep the last-order slice H_{last_token_of_prev_chunk, n-1} ([B, H, D, D])
1123
+ h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()
1124
+
1125
+ def process_one_chunk(
1126
+ q_chunk_params: torch.Tensor,
1127
+ w_chunk_params: torch.Tensor,
1128
+ u_chunk_params: torch.Tensor,
1129
+ h_0_base: torch.Tensor,
1130
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1131
+ """
1132
+ Process a single chunk (with per-token state for n_orders > 1).
1133
+ """
1134
+ o_intra_current_chunk = torch.zeros(
1135
+ batch,
1136
+ head,
1137
+ chunk_n_total,
1138
+ dim,
1139
+ device=q_chunk_params.device,
1140
+ dtype=linear_precision,
1141
+ )
1142
+ o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
1143
+ current_accumulated_state_per_token = (
1144
+ h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
1145
+ ) # [B, H, current_chunk_size, D, D]
1146
+
1147
+ for step in range(n_orders):
1148
+ idx_virtual_tokens = (
1149
+ torch.arange(current_chunk_size, device=q_chunk_params.device)
1150
+ * n_orders
1151
+ + step
1152
  )
1153
+ q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
1154
+ w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
1155
+ u_s = u_chunk_params[:, :, idx_virtual_tokens, :]
1156
+
1157
+ state_input_for_this_step = current_accumulated_state_per_token
1158
+
1159
+ ## BLAS/cuBLAS equivalent of einsum "bhcd,bhcde->bhce"
1160
+ k_trans_h_old = (
1161
+ torch.matmul(
1162
+ w_s.unsqueeze(-2),
1163
+ state_input_for_this_step,
1164
+ )
1165
+ .squeeze(-2)
1166
+ .to(dtype=linear_precision)
1167
+ )
1168
+
1169
+ u_val = u_s - k_trans_h_old
1170
+
1171
+ o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
1172
+ torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
1173
+ .squeeze(-2)
1174
+ .to(dtype=linear_precision)
1175
+ )
1176
+
1177
+ ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
1178
+ o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
1179
+ dtype=linear_precision
1180
+ )
1181
+
1182
+ outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
1183
+ new_state_i_per_token = state_input_for_this_step + outer_product_term
1184
+ current_accumulated_state_per_token = new_state_i_per_token.to(
1185
+ dtype=linear_precision
1186
+ )
1187
+ # Return all needed for next chunk
1188
+ return (
1189
+ o_intra_current_chunk,
1190
+ o_inter_current_chunk,
1191
+ current_accumulated_state_per_token[:, :, -1, :, :], # new h_0_base
1192
+ )
1193
+
1194
+ for chunk_idx_inner in range(num_chunks_inner):
1195
+ q_chunk_params = q_chunks[:, :, chunk_idx_inner]
1196
+ w_chunk_params = w[:, :, chunk_idx_inner]
1197
+ u_chunk_params = u[:, :, chunk_idx_inner]
1198
+
1199
+ # Use gradient checkpointing when use_checkpoint is set
1200
+ call = (
1201
+ partial(checkpoint, use_reentrant=False)
1202
+ if use_checkpoint
1203
+ else lambda f, *a: f(*a)
1204
+ )
1205
+ o_intra, o_inter, h_0_base = call(
1206
+ process_one_chunk,
1207
+ q_chunk_params,
1208
+ w_chunk_params,
1209
+ u_chunk_params,
1210
+ h_0_base,
1211
+ )
1212
+ if not linear_activation: # nonlinear activation between chunks
1213
+ h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
1214
+ output_inner[:, :, chunk_idx_inner] = o_intra + o_inter
1215
+
1216
+ return output_inner, h_0_base
1217
+
1218
+
1219
+ def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
1220
+ """Nonlinear (norm-scaled GELU) activation applied to the state between chunks."""
1221
+ x_n = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
1222
+ x_gelu = F.gelu(scale * x / x_n, approximate="tanh") # pylint: disable=not-callable
1223
+ return (x / scale) * x_gelu
1224
+
1225
+
1226
+ def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
1227
+ """Splits [B, H, S, D] to [B, H, num_chunks, chunk_size, D]"""
1228
+ batch_size, num_heads, _, head_dim = x.shape
1229
+ return x.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
1230
+
1231
+
1232
+ def expand_virtual_tokens(
1233
+ x: torch.Tensor, n: int, mode: str = "derivative"
1234
+ ) -> torch.Tensor:
1235
+ """Expand tokens into 'n' virtual tokens using the selected trick."""
1236
+ batch_size, num_heads, seq_len, head_dim = x.shape
1237
+ device, dtype = x.device, x.dtype
1238
+
1239
+ def derivative_expand(x: torch.Tensor) -> torch.Tensor:
1240
+ """Expand tokens using the derivative trick."""
1241
+ x_pad = torch.cat(
1242
+ [
1243
+ torch.zeros(
1244
+ batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
1245
+ ),
1246
+ x,
1247
+ ],
1248
+ dim=2,
1249
+ )
1250
+ coeffs = torch.tensor(
1251
+ [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
1252
+ device=device,
1253
+ dtype=dtype,
1254
+ )
1255
+ coeffs /= coeffs.norm(p=1)
1256
+ return (
1257
+ (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
1258
+ .flip(-1)
1259
+ .permute(0, 1, 2, 4, 3)
1260
+ .reshape(batch_size, num_heads, seq_len * n, head_dim)
1261
  )
1262
 
1263
+ def rotative_expand(x: torch.Tensor) -> torch.Tensor:
1264
+ """Expand tokens using the rotative trick."""
1265
+ d_parity = head_dim // 2
1266
+ angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
1267
+ cos = torch.cos(angles).view(1, 1, 1, n, 1)
1268
+ sin = torch.sin(angles).view(1, 1, 1, n, 1)
1269
+ if head_dim % 2:
1270
+ x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
1271
+ else:
1272
+ x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
1273
+ x_pairs = x_pairs.unsqueeze(3).expand(
1274
+ batch_size, num_heads, seq_len, n, d_parity, 2
1275
+ )
1276
+ x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
1277
+ x0r = x0 * cos - x1 * sin
1278
+ x1r = x0 * sin + x1 * cos
1279
+ rot = torch.stack([x0r, x1r], -1).reshape(
1280
+ batch_size, num_heads, seq_len, n, d_parity * 2
1281
+ )
1282
+ if head_dim % 2:
1283
+ last = (
1284
+ x[..., -1]
1285
+ .unsqueeze(-1)
1286
+ .unsqueeze(3)
1287
+ .expand(batch_size, num_heads, seq_len, n, 1)
1288
+ )
1289
+ rot = torch.cat([rot, last], -1)
1290
+ return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)
1291
 
1292
+ if mode == "derivative":
1293
+ return derivative_expand(x)
1294
+ if mode == "rotative":
1295
+ return rotative_expand(x)
1296
+ if mode == "combined":
1297
+ return (derivative_expand(x) + rotative_expand(x)) / 2
1298
+ raise ValueError(f"Unknown mode: {mode}")
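The "derivative" trick above weights `n` consecutive tokens with alternating-sign binomial coefficients, normalized by their L1 norm. The following is a minimal pure-Python sketch of that coefficient computation only; the helper name `derivative_coeffs` is hypothetical and is not part of the module.

```python
import math
from typing import List


def derivative_coeffs(n: int) -> List[float]:
    """Alternating binomial coefficients, normalized so their absolute values sum to 1."""
    # Same formula as in derivative_expand: (-1)^k * C(n-1, k) for k = 0..n-1
    raw = [(-1) ** k * math.comb(n - 1, k) for k in range(n)]
    l1_norm = sum(abs(c) for c in raw)
    return [c / l1_norm for c in raw]


if __name__ == "__main__":
    for order in (1, 2, 3):
        print(order, derivative_coeffs(order))
```

For `n = 1` the single coefficient is 1, so a first-order operator leaves the tokens unchanged, which is consistent with the `n == 1` shortcut in `chunk_delta_product_forward`.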
1299
 
1300
 
1301
  def extract_layer_idx(module_name: str) -> int:
1302
+ """Extract the layer index from a module name string."""
 
 
1303
  match = re.search(r"\.(\d+)\.", module_name)
1304
  if match:
1305
  return int(match.group(1))
1306
  return -1
1307
 
1308
 
1309
+ def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
1310
  """Find the embedding weight in a model module."""
1311
  for _, child in module.named_modules():
1312
  if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"):
 
1318
  return None
1319
 
1320
 
1321
+ def set_trainable_parameters(
1322
+ model: PreTrainedModel, trainable_patterns: Optional[List[str]] = None
1323
+ ) -> PreTrainedModel:
1324
+ """Freeze all parameters except those whose names contain one of trainable_patterns."""
1325
+ if trainable_patterns is None:
1326
+ trainable_patterns = [
1327
+ "q_proj",
1328
+ "k_proj",
1329
+ "v_proj",
1330
+ "o_proj",
1331
+ "qkv_proj",
1332
+ "out_proj",
1333
+ "c_attn",
1334
+ "c_proj",
1335
+ "query",
1336
+ "key",
1337
+ "value",
1338
+ ]
1339
 
1340
+ for name, param in model.named_parameters():
1341
+ param.requires_grad = any(pattern in name for pattern in trainable_patterns)
 
1342
 
1343
+ trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
1344
+ logger.info("Trainable parameters after freeze: %s", trainable_layers)
1345
+ return model
1346
 
 
 
 
 
 
 
 
 
 
 
1347
 
1348
+ def ensure_stability(
1349
+ tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
1350
+ ) -> torch.Tensor:
1351
+ """Clamp values to [min_val, max_val] and replace NaN/inf entries for numerical stability."""
1352
+ dtype = tensor.dtype
1353
+ center = (max_val + min_val) / 2
1354
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
1355
+ tensor = torch.nan_to_num(tensor, nan=center, posinf=max_val, neginf=min_val)
1356
+ return tensor.to(dtype=dtype)
1357
 
1358
 
1359
+ def apply_linear_attention_mask(
1360
+ attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
1361
+ ) -> torch.Tensor:
1362
+ """Reduce attention_mask to a [B, S] padding mask and apply it to v."""
1363
  if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
 
1364
  mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
1365
  else:
 
1366
  mask = attention_mask.squeeze(
1367
  dim=tuple(
1368
  i
 
1370
  if attention_mask.shape[i] == 1
1371
  )
1372
  )
1373
+ # Ensure cast to the same dtype as v and convert to binary mask
1374
+ if not (
1375
+ mask.dtype == torch.bool
1376
+ or (
1377
+ mask.dtype in [torch.uint8, torch.int32, torch.int64]
1378
+ and mask.max() <= 1
1379
+ and mask.min() >= 0
1380
+ )
1381
+ ):
1382
+ mask = (mask >= 0).to(v.dtype) # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
1383
+ else:
1384
+ mask = mask.to(v.dtype)
1385
+ # mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
1386
+ if padding_side == "left":
1387
+ mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
1388
+ else: # right padding
1389
+ mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
1390
  return v * mask
1391
 
1392
 
1393
+ def truncate_attention_mask(
1394
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
1395
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1396
+ """Truncate hidden_states and attention_mask to the last window of size max_length"""
 
1397
  seq_dim = 1 # convention: (batch, seq, ...)
1398
  seq_len = hidden_states.shape[seq_dim]
1399
  if seq_len > max_length:
 
1415
  return hidden_states, attention_mask
1416
 
1417
 
1418
+ def fast_invert_matrix(
1419
+ tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
1420
+ ) -> torch.Tensor:
1421
+ """Invert (I - A) for a strictly lower-triangular A; equivalent to vectorized forward substitution applied to the identity matrix."""
1422
+ tri_tensor = tri_tensor.to(dtype=dtype).clone()
1423
+ chunk_size = tri_tensor.shape[-1]
1424
+
1425
+ for i in range(1, chunk_size):
1426
+ tri_tensor[..., i, :i] = tri_tensor[..., i, :i] + (
1427
+ tri_tensor[..., i, :, None].clone() * tri_tensor[..., :, :i].clone()
1428
+ ).sum(-2)
1429
+
1430
+ tri_tensor = tri_tensor + torch.eye(
1431
+ chunk_size, dtype=dtype, device=tri_tensor.device
1432
+ )
1433
+ return tri_tensor.to(dtype=dtype)
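The row-by-row update in `fast_invert_matrix` can be checked on a small case. The sketch below is a pure-Python re-implementation for a single matrix, assuming the input is strictly lower-triangular (as produced by `.tril(-1)`); the name `invert_unit_lower` is hypothetical and the code is an illustration, not the module's implementation.

```python
from typing import List

Matrix = List[List[float]]


def invert_unit_lower(a: Matrix) -> Matrix:
    """Return (I - a)^-1 for a strictly lower-triangular square matrix a."""
    size = len(a)
    m = [row[:] for row in a]
    for i in range(1, size):
        # Snapshot row i so every column update reads the pre-update row,
        # mirroring the .clone() calls in the tensor version.
        row_i = m[i][:]
        for col in range(i):
            m[i][col] = row_i[col] + sum(row_i[j] * m[j][col] for j in range(size))
    for i in range(size):
        m[i][i] += 1.0  # add the identity at the end
    return m


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain triple-loop matrix product, used only to verify the inverse."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


if __name__ == "__main__":
    a = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
    inv = invert_unit_lower(a)
    identity_minus_a = [[(1.0 if i == j else 0.0) - a[i][j] for j in range(3)] for i in range(3)]
    print(inv)
    print(matmul(identity_minus_a, inv))
```

Multiplying `(I - a)` by the result should give the identity, which is the property the chunkwise delta rule relies on when forming `w` and `u`.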
1434
+
1435
+
1436
  def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
1437
+ """Return the largest divisor of total_l that is <= chunk_size."""
 
 
 
1438
  for c in range(min(chunk_size, total_l), 0, -1):
1439
  if total_l % c == 0:
1440
  return c
1441
  return 1
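As a quick illustration of how the chunk size is chosen, the sketch below repeats the function body shown above in a standalone form and prints a few cases. The sequence lengths are arbitrary examples.

```python
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    """Return the largest divisor of total_l that is <= chunk_size."""
    for c in range(min(chunk_size, total_l), 0, -1):
        if total_l % c == 0:
            return c
    return 1


if __name__ == "__main__":
    # (sequence length, requested max chunk size)
    for seq_len, max_chunk in [(128, 64), (100, 64), (7, 64), (13, 4)]:
        print(seq_len, max_chunk, get_valid_chunk_size(seq_len, max_chunk))
```

A prime sequence length longer than the requested chunk size falls back to chunks of size 1, so the chunkwise path then degenerates to one chunk per token.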
1442
 
1443
 
1444
+ ## RARELY USED
1445
+ def split_qkv(
1446
+ base_attn: nn.Module, qkv: torch.Tensor
1447
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1448
+ """Split the QKV tensor into separate Q, K, and V tensors."""
1449
+ num_q_heads = getattr(base_attn, "num_q_heads", None)
1450
+ num_k_heads = getattr(base_attn, "num_k_heads", None)
1451
+ num_v_heads = getattr(base_attn, "num_v_heads", None)
1452
+ head_dim = getattr(base_attn, "head_dim", None)
1453
+
1454
+ if num_q_heads is None or num_k_heads is None or num_v_heads is None:
1455
+ raise ValueError(
1456
+ "Base attention must have num_q_heads, num_k_heads, and num_v_heads defined."
1457
+ )
1458
+
1459
+ q_len = num_q_heads * head_dim
1460
+ k_len = num_k_heads * head_dim
1461
+ v_len = num_v_heads * head_dim
1462
+
1463
+ q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
1464
+ return q, k, v
1465
+
1466
+
1467
+ ## OPTIONAL
1468
  def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
1469
+ """Match the size of tensor x along dimension dim to target_size by interpolation"""
 
 
 
1470
  src_size = x.shape[dim]
1471
  if src_size == target_size:
1472
  return x
 
1481
  x = F.linear(x, eye) # pylint: disable=not-callable
1482
  x = torch.moveaxis(x, -1, dim)
1483
  return x
1484
+
1485
+
1486
+ def soft_clamp(
1487
+ x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
1488
+ ) -> torch.Tensor:
1489
+ """Differentiable clamping for stability"""
1490
+ dtype = x.dtype
1491
+ scale = (max_val - min_val) / 2
1492
+ center = (max_val + min_val) / 2
1493
+ return (torch.tanh((x - center) / scale) * scale + center).to(dtype=dtype)
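To see what `soft_clamp` does to individual values, here is a scalar sketch of the same formula using `math.tanh` instead of `torch.tanh`. The name `soft_clamp_scalar` is hypothetical; the default bounds are copied from the function above.

```python
import math


def soft_clamp_scalar(x: float, min_val: float = 1e-6, max_val: float = 1 - 1e-6) -> float:
    """Smoothly squash x into (min_val, max_val) with a rescaled tanh."""
    scale = (max_val - min_val) / 2
    center = (max_val + min_val) / 2
    return math.tanh((x - center) / scale) * scale + center


if __name__ == "__main__":
    for value in (-100.0, 0.2, 0.5, 0.8, 100.0):
        print(value, soft_clamp_scalar(value))
```

Unlike `torch.clamp`, the mapping has a nonzero gradient everywhere, at the cost of also compressing values that are already inside the interval.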
1494
+
1495
+
1496
+ def describe(x: torch.Tensor, name="tensor") -> None:
1497
+ """Prints the shape, min, max, mean, and std of a tensor."""
1498
+ stats = (x.min(), x.max(), x.mean(), x.std())
1499
+ print(
1500
+ f"{name} shape: {tuple(x.shape)}, "
1501
+ + f"min: {stats[0]:.4g}, max: {stats[1]:.4g}, "
1502
+ + f"mean: {stats[2]:.4g}, std: {stats[3]:.4g}, "
1503
+ + f"dtype: {x.dtype}, device: {x.device}"
1504
+ )
train_tptt.py CHANGED
@@ -1,20 +1,40 @@
 
 
1
  """
2
  Author : Fabien FURFARO
3
  """
4
 
5
- from transformers import TrainerCallback
 
 
6
 
7
  from .modeling_tptt import LiZAttention
8
 
9
 
10
- class AdjustMaGWeightCallback(TrainerCallback):
11
- """TrainerCallback to schedule mag_weight during training."""
 
 
 
 
 
 
 
12
 
13
  def __init__(
14
- self, model, initial_weight=0.01, final_weight=0.5, transition_step=500
 
 
 
 
 
 
 
15
  ):
16
  self.model = model
17
- # Ensure weights are always float scalars, not tuples/lists
 
 
18
  if isinstance(initial_weight, (tuple, list)):
19
  initial_weight = initial_weight[0]
20
  if isinstance(final_weight, (tuple, list)):
@@ -22,42 +42,101 @@ class AdjustMaGWeightCallback(TrainerCallback):
22
  self.initial_weight = float(initial_weight)
23
  self.final_weight = float(final_weight)
24
 
25
- if isinstance(transition_step, (tuple, list)):
26
- transition_step = transition_step[0]
27
- self.transition_step = int(transition_step)
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def on_step_end(self, args, state, control, **kwargs):
30
  current_step = state.global_step
31
  transition_step = self.transition_step
32
 
33
- # Ensure both are plain ints (not tuple, list, tensor, numpy, etc.)
34
- if isinstance(current_step, (tuple, list)):
35
- current_step = current_step[0]
36
- if hasattr(current_step, "item"):
37
- current_step = int(current_step.item())
38
- else:
39
- current_step = int(current_step)
40
 
41
- if isinstance(transition_step, (tuple, list)):
42
- transition_step = transition_step[0]
43
- if hasattr(transition_step, "item"):
44
- transition_step = int(transition_step.item())
45
- else:
46
- transition_step = int(transition_step)
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- if current_step < transition_step:
49
- weight = self.initial_weight + (self.final_weight - self.initial_weight) * (
50
- current_step / transition_step
51
- )
52
  for _, module in self.model.named_modules():
53
  if isinstance(module, LiZAttention):
54
  module.mag_weight = weight
55
 
 
 
 
 
 
 
 
 
 
 
56
  def on_log(self, args, state, control, logs=None, **kwargs):
57
  mag_weight = None
 
 
58
  for _, module in self.model.named_modules():
59
  if isinstance(module, LiZAttention):
60
  mag_weight = getattr(module, "mag_weight", None)
 
61
  break
62
  if mag_weight is not None and logs is not None:
63
  logs["mag_weight"] = float(mag_weight)
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments
2
+
3
  """
4
  Author : Fabien FURFARO
5
  """
6
 
7
+ from typing import Optional, Union
8
+
9
+ from transformers import PreTrainedModel, TrainerCallback
10
 
11
  from .modeling_tptt import LiZAttention
12
 
13
 
14
+ class LiZACallback(TrainerCallback):
15
+ """
16
+ TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
17
+
18
+ Modes:
19
+ - "gradual": linear interpolation from initial_weight to final_weight.
20
+ - "cyclic": alternate between values in weight_list at each step.
21
+ - "switch": alternately enable/disable linear attention every switch_period steps.
+ - "constant": keep mag_weight fixed at final_weight.
22
+ """
23
 
24
  def __init__(
25
+ self,
26
+ model: PreTrainedModel,
27
+ mode: str = "gradual",
28
+ initial_weight: float = 0.0,
29
+ final_weight: float = 0.5,
30
+ transition_step: Union[int, tuple, list] = 100,
31
+ weight_list: Optional[list] = None,
32
+ switch_period: int = 1, # period for switching
33
  ):
34
  self.model = model
35
+ self.mode = mode
36
+
37
+ # Ensure initial_weight and final_weight are float scalars, not tuples/lists
38
  if isinstance(initial_weight, (tuple, list)):
39
  initial_weight = initial_weight[0]
40
  if isinstance(final_weight, (tuple, list)):
 
42
  self.initial_weight = float(initial_weight)
43
  self.final_weight = float(final_weight)
44
 
45
+ # Ensure transition_step is an int scalar, not tuple/list
46
+ self.transition_step = ensure_int(transition_step)
47
+ if self.mode == "constant":
48
+ # For constant mode, transition_step is not used
49
+ self.initial_weight = self.final_weight
50
+ # For cyclic mode: ensure all weights are float scalars
51
+ if weight_list is not None:
52
+ self.weight_list = [
53
+ float(w[0]) if isinstance(w, (tuple, list)) else float(w)
54
+ for w in weight_list
55
+ ]
56
+ else:
57
+ self.weight_list = [self.initial_weight, self.final_weight]
58
+
59
+ # For switch_alternate mode
60
+ self.switch_period = int(switch_period)
61
 
62
  def on_step_end(self, args, state, control, **kwargs):
63
  current_step = state.global_step
64
  transition_step = self.transition_step
65
 
66
+ # Ensure current_step and transition_step are plain ints
67
+ current_step = ensure_int(current_step)
68
+ transition_step = ensure_int(transition_step)
 
 
 
 
69
 
70
+ # Select mag_weight or enable/disable linear attention according to mode
71
+ if self.mode == "constant":
72
+ # Set mag_weight to final_weight for constant mode
73
+ weight = self.final_weight
74
+ for _, module in self.model.named_modules():
75
+ if isinstance(module, LiZAttention):
76
+ module.mag_weight = weight
77
+
78
+ elif self.mode == "gradual":
79
+ if current_step <= transition_step:
80
+ weight = self.initial_weight + (
81
+ self.final_weight - self.initial_weight
82
+ ) * (current_step / transition_step)
83
+ else:
84
+ weight = self.final_weight
85
+ for _, module in self.model.named_modules():
86
+ if isinstance(module, LiZAttention):
87
+ module.mag_weight = weight
88
 
89
+ elif self.mode == "cyclic":
90
+ idx = current_step % len(self.weight_list)
91
+ weight = self.weight_list[idx]
 
92
  for _, module in self.model.named_modules():
93
  if isinstance(module, LiZAttention):
94
  module.mag_weight = weight
95
 
96
+ elif self.mode == "switch":
97
+ # Alternately enable/disable linear attention every switch_period steps
98
+ disable = (current_step // self.switch_period) % 2 == 0
99
+ for _, module in self.model.named_modules():
100
+ if isinstance(module, LiZAttention):
101
+ module.disable_linear_attn = disable
102
+
103
+ else:
104
+ raise ValueError(f"Unknown mode: {self.mode}")
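The scheduling logic in `on_step_end` can be expressed as two small pure functions, which makes the behaviour of each mode easy to inspect without a `Trainer`. This is a sketch under stated assumptions: the names `scheduled_mag_weight` and `linear_attn_disabled` are hypothetical, and the guard against `transition_step == 0` is an addition that the callback itself does not have.

```python
from typing import List, Optional


def scheduled_mag_weight(
    mode: str,
    step: int,
    initial_weight: float = 0.0,
    final_weight: float = 0.5,
    transition_step: int = 100,
    weight_list: Optional[List[float]] = None,
) -> float:
    """Return the mag_weight the callback would set at a given global step."""
    if mode == "constant":
        return final_weight
    if mode == "gradual":
        # Assumption: guard against division by zero (not in the original code).
        if transition_step > 0 and step <= transition_step:
            return initial_weight + (final_weight - initial_weight) * (step / transition_step)
        return final_weight
    if mode == "cyclic":
        weights = weight_list or [initial_weight, final_weight]
        return weights[step % len(weights)]
    raise ValueError(f"Unknown mode: {mode}")


def linear_attn_disabled(step: int, switch_period: int = 1) -> bool:
    """Return the disable_linear_attn flag used in "switch" mode."""
    return (step // switch_period) % 2 == 0


if __name__ == "__main__":
    print([scheduled_mag_weight("gradual", s) for s in (0, 50, 100, 200)])
    print([scheduled_mag_weight("cyclic", s) for s in range(4)])
    print([linear_attn_disabled(s, switch_period=2) for s in range(6)])
```

With the default `switch_period` of 1, linear attention is disabled on even steps and enabled on odd steps.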
105
+
106
  def on_log(self, args, state, control, logs=None, **kwargs):
107
  mag_weight = None
108
+ disable_linear_attn = None
109
+ # Log the current mag_weight and disable_linear_attn
110
  for _, module in self.model.named_modules():
111
  if isinstance(module, LiZAttention):
112
  mag_weight = getattr(module, "mag_weight", None)
113
+ disable_linear_attn = getattr(module, "disable_linear_attn", None)
114
  break
115
  if mag_weight is not None and logs is not None:
116
  logs["mag_weight"] = float(mag_weight)
117
+ if disable_linear_attn is not None and logs is not None:
118
+ logs["disable_linear_attn"] = bool(disable_linear_attn)
119
+
120
+
121
+ def ensure_int(value: Union[int, tuple, list]) -> int:
122
+ """Ensure the value is a plain integer."""
123
+ if isinstance(value, (tuple, list)):
124
+ value = int(value[0])
125
+ if hasattr(value, "item"):
126
+ value = int(value.item())
127
+ return int(value)
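A standalone sketch of the normalization performed by `ensure_int` is shown below. It always finishes with an `int(...)` cast so that floats are also converted; the name `ensure_int_sketch` and the `FakeScalar` stand-in for a tensor-like object are hypothetical and only serve the illustration.

```python
from typing import Any


def ensure_int_sketch(value: Any) -> int:
    """Normalize a tuple/list, tensor-like scalar, or number to a plain int."""
    if isinstance(value, (tuple, list)):
        value = value[0]  # keep only the first element
    if hasattr(value, "item"):
        value = value.item()  # e.g. a 0-dim tensor or NumPy scalar
    return int(value)


class FakeScalar:
    """Stand-in for an object exposing .item(), such as a 0-dim tensor."""

    def __init__(self, number: float) -> None:
        self._number = number

    def item(self) -> float:
        return self._number


if __name__ == "__main__":
    print(ensure_int_sketch((5, 6)))
    print(ensure_int_sketch([FakeScalar(9.0)]))
    print(ensure_int_sketch(3.0))
```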
128
+
129
+
130
+ class SaveBestModelCallback(TrainerCallback):
131
+ """TrainerCallback to save the best model based on evaluation loss."""
132
+
133
+ def __init__(self):
134
+ self.best_metric = float("inf")
135
+
136
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
137
+ if metrics is not None and "eval_loss" in metrics:
138
+ if metrics["eval_loss"] < self.best_metric:
139
+ self.best_metric = metrics["eval_loss"]
140
+ control.should_save = True # Trigger save
141
+ else:
142
+ control.should_save = False # Skip save