Simon Moisselin simonn94 committed
Commit 95e087e · unverified · 1 Parent(s): 3063d29

models : add ggml_to_pt script (#1042)


* adding ggml_to_pt

* typo sys too many args

* fixing swapped dimensions errors

---------

Co-authored-by: simonMoisselin <[email protected]>

Files changed (1)
  1. models/ggml_to_pt.py +109 -0
models/ggml_to_pt.py ADDED
@@ -0,0 +1,109 @@
+ import struct
+ import torch
+ import numpy as np
+ from collections import OrderedDict
+ from pathlib import Path
+ import sys
+
+ if len(sys.argv) < 3:
+     print("Usage: ggml_to_pt.py model.bin dir-output\n")
+     sys.exit(1)
+
+ fname_inp = Path(sys.argv[1])
+ dir_out = Path(sys.argv[2])
+ fname_out = dir_out / "torch-model.pt"
+
+ # Open the ggml file
+ with open(fname_inp, "rb") as f:
+     # Read magic number and hyperparameters (12 consecutive int32 values)
+     magic_number, n_vocab, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, n_text_ctx, n_text_state, n_text_head, n_text_layer, n_mels, use_f16 = struct.unpack("12i", f.read(48))
+     print(f"Magic number: {magic_number}")
+     print(f"Vocab size: {n_vocab}")
+     print(f"Audio context size: {n_audio_ctx}")
+     print(f"Audio state size: {n_audio_state}")
+     print(f"Audio head size: {n_audio_head}")
+     print(f"Audio layer size: {n_audio_layer}")
+     print(f"Text context size: {n_text_ctx}")
+     print(f"Text state size: {n_text_state}")
+     print(f"Text head size: {n_text_head}")
+     print(f"Text layer size: {n_text_layer}")
+     print(f"Mel size: {n_mels}")
+
+     # Read the mel filterbank: its shape first, then the values
+     filters_shape_0 = struct.unpack("i", f.read(4))[0]
+     print(f"Filters shape 0: {filters_shape_0}")
+     filters_shape_1 = struct.unpack("i", f.read(4))[0]
+     print(f"Filters shape 1: {filters_shape_1}")
+
+     mel_filters = np.zeros((filters_shape_0, filters_shape_1))
+     for i in range(filters_shape_0):
+         for j in range(filters_shape_1):
+             mel_filters[i][j] = struct.unpack("f", f.read(4))[0]
+
+     # Read tokenizer tokens (length-prefixed byte strings); they are not needed
+     # for the PyTorch checkpoint, but must be consumed to reach the tensors
+     bytes_data = f.read(4)
+     num_tokens = struct.unpack("i", bytes_data)[0]
+     tokens = {}
+     for _ in range(num_tokens):
+         token_len = struct.unpack("i", f.read(4))[0]
+         token = f.read(token_len)
+         tokens[token] = {}
+
+     # Read model variables
+     model_state_dict = OrderedDict()
+     while True:
+         try:
+             n_dims, name_length, ftype = struct.unpack("iii", f.read(12))
+         except struct.error:
+             break  # End of file
+         # ggml stores dimensions in reverse order relative to PyTorch, so flip them
+         dims = [struct.unpack("i", f.read(4))[0] for _ in range(n_dims)]
+         dims = dims[::-1]
+         name = f.read(name_length).decode("utf-8")
+         if ftype == 1:  # f16
+             data = np.fromfile(f, dtype=np.float16, count=np.prod(dims)).reshape(dims)
+         else:  # f32
+             data = np.fromfile(f, dtype=np.float32, count=np.prod(dims)).reshape(dims)
+
+         # The conv biases carry an extra dimension in the ggml file;
+         # drop it so they are 1-D, as PyTorch expects
+         if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
+             data = data[:, 0]
+
+         model_state_dict[name] = torch.from_numpy(data)
+
+ # Now you have the model's state_dict stored in model_state_dict
+ # You can load this state_dict into a model with the same architecture
+ from whisper import Whisper, ModelDimensions
+
+ dims = ModelDimensions(
+     n_mels=n_mels,
+     n_audio_ctx=n_audio_ctx,
+     n_audio_state=n_audio_state,
+     n_audio_head=n_audio_head,
+     n_audio_layer=n_audio_layer,
+     n_text_ctx=n_text_ctx,
+     n_text_state=n_text_state,
+     n_text_head=n_text_head,
+     n_text_layer=n_text_layer,
+     n_vocab=n_vocab,
+ )
+ model = Whisper(dims)
+ model.load_state_dict(model_state_dict)
+
+ # Save the model in PyTorch format
+ torch.save(model.state_dict(), fname_out)
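
For reference, a minimal sketch of how the converted checkpoint might be used afterwards. The invocation would look something like `python models/ggml_to_pt.py ggml-base.en.bin out`, which writes `out/torch-model.pt`. Because the script saves only the state_dict, the ModelDimensions must be supplied again when reloading; the file name, output directory, and the dimension values below (taken to be the base.en sizes) are illustrative assumptions, not part of the commit.

    # Hypothetical follow-up: load the converted checkpoint back into openai-whisper.
    # Assumes the script above was run as: python models/ggml_to_pt.py ggml-base.en.bin out
    import torch
    from whisper import Whisper, ModelDimensions

    # Dimensions must match the ggml model that was converted;
    # these are the assumed base.en values.
    dims = ModelDimensions(
        n_mels=80, n_audio_ctx=1500, n_audio_state=512, n_audio_head=8,
        n_audio_layer=6, n_text_ctx=448, n_text_state=512, n_text_head=8,
        n_text_layer=6, n_vocab=51864,
    )
    model = Whisper(dims)
    model.load_state_dict(torch.load("out/torch-model.pt", map_location="cpu"))
    model.eval()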