forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess_data.py
381 lines (316 loc) · 14.6 KB
/
preprocess_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing large data for pretraining."""
import argparse
import contextlib
import glob
import gzip
import json
import math
import multiprocessing
import os
import sys
import time

import numpy as np
import torch

try:
    import nltk
    nltk_available = True
except ImportError:
    nltk_available = False

# Add this script's parent directory to the import path (presumably the repo
# root, where the ``megatron`` package imported below lives).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
try:
    from nltk.tokenize.punkt import PunktLanguageVars as _PunktLanguageVarsBase
except ImportError:
    # nltk is optional: it is only needed for --split-sentences. Without this
    # fallback the class definition below would fail at import time whenever
    # nltk is missing, even for runs that never split sentences.
    _PunktLanguageVarsBase = object


# Punkt language variables whose sentence-end regex also matches the
# whitespace that follows a sentence-ending character (the ``\s*`` line), so
# punkt does not eat newlines after sentences. Used for --keep-newlines.
# Adapted from:
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(_PunktLanguageVarsBase):
    # Compiled by punkt in verbose mode, so the layout and ``#`` comments
    # inside the pattern are ignored by the regex engine.
    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""
class IdentitySplitter:
    """No-op stand-in for the NLTK sentence splitter.

    Offers the same ``tokenize`` entry point but performs no splitting: every
    positional argument is returned unchanged, packed in a tuple, so
    ``tokenize(text)`` yields ``(text,)``.
    """

    def tokenize(self, *chunks):
        return chunks
class Encoder(object):
    """Tokenize (and optionally sentence-split) JSON lines inside pool workers.

    An instance is built in the parent process and its bound methods are handed
    to ``multiprocessing.Pool``. :meth:`initializer` runs once per worker and
    stores the tokenizer and sentence splitter as *class* attributes;
    :meth:`split` and :meth:`encode` then read them from the class.
    """

    # Texts are passed to the sentence splitter in slices of at most this many
    # characters. Slicing is by character count only, so a slice boundary can
    # fall in the middle of a sentence.
    _SPLIT_CHUNK_CHARS = 1000000

    def __init__(self, args):
        self.args = args

    def initializer(self):
        """Build this worker's tokenizer and sentence splitter.

        Reads ``args.split_sentences``, ``args.lang`` and ``args.keep_newlines``.
        Exits the process with status 1 if sentence splitting is requested but
        nltk could not be imported.
        """
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)
        if not self.args.split_sentences:
            Encoder.splitter = IdentitySplitter()
            return
        if not nltk_available:
            print("NLTK is not available to split sentences.")
            # sys.exit instead of the interactive-shell exit() builtin, and a
            # non-zero status so the failure is not reported as success.
            sys.exit(1)
        library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
        splitter = nltk.load(library)
        if self.args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=splitter._params,
                lang_vars=CustomLanguageVars())
        else:
            Encoder.splitter = splitter

    def split(self, json_line):
        """Sentence-split every configured field of one JSON line.

        Args:
            json_line: one line of the input file containing a JSON object.

        Returns:
            ``(json_string, len(json_line))``. The JSON object maps each key in
            ``args.json_keys`` to the pieces returned by the splitter,
            concatenated across all slices of the field's text.
        """
        data = json.loads(json_line)
        step = self._SPLIT_CHUNK_CHARS
        output = {}
        for key in self.args.json_keys:
            text = data[key]
            output[key] = [
                piece
                for start in range(0, len(text), step)
                for piece in Encoder.splitter.tokenize(text[start:start + step])
            ]
        return json.dumps(output), len(json_line)

    def encode(self, json_line):
        """Tokenize every configured field of one JSON line.

        A field may be a single string or a list of (already split) sentences.
        Sentences that tokenize to nothing are dropped.

        Args:
            json_line: one line of the input file containing a JSON object.

        Returns:
            ``(ids, lens, len(json_line))``. ``ids[key]`` is the flat token-id
            list for the document and ``lens[key]`` the per-sentence token
            counts. The counts always sum to ``len(ids[key])``, including the
            EOD token when ``args.append_eod`` is set.
        """
        data = json.loads(json_line)
        ids = {}
        lens = {}
        for key in self.args.json_keys:
            text = data[key]
            sentences = text if isinstance(text, list) else [text]
            doc_ids = []
            sentence_lens = []
            for sentence in sentences:
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.extend(sentence_ids)
                    sentence_lens.append(len(sentence_ids))
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids.append(Encoder.tokenizer.eod)
                # Count the EOD token in the last sentence so the lengths
                # still cover every token. doc_ids and sentence_lens are handed
                # together to the dataset builder's add_doc. sentence_lens is
                # non-empty here because doc_ids is.
                sentence_lens[-1] += 1
            ids[key] = doc_ids
            lens[key] = sentence_lens
        return ids, lens, len(json_line)
class Partition(object):
    """Sentence-split or tokenize one input file with a pool of workers.

    Each public method processes a single file using its own
    ``multiprocessing.Pool`` of ``workers`` processes, driven by an
    :class:`Encoder` built from ``args``.
    """

    def __init__(self, args, workers):
        self.args = args
        self.workers = workers

    def print_processing_stats(self, count, proc_start, total_bytes_processed):
        """Print throughput to stderr every ``args.log_interval`` documents.

        Args:
            count: number of documents processed so far (1-based).
            proc_start: ``time.time()`` value taken when processing began.
            total_bytes_processed: running total of the per-line sizes reported
                by the encoder.
        """
        if count % self.args.log_interval == 0:
            elapsed = time.time() - proc_start
            mbs = total_bytes_processed / elapsed / 1024 / 1024
            print(f"Processed {count} documents",
                  f"({count/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    def split_sentences(self, file_name):
        """Write a sentence-split copy of a JSON-lines file.

        Args:
            file_name: ``(input_file_name, output_file_name)`` pair. Each
                output line is the JSON produced by ``Encoder.split`` for the
                corresponding input line. ``imap`` keeps the input order.
        """
        input_file_name, output_file_name = file_name
        print("Opening", input_file_name)
        encoder = Encoder(self.args)
        # The context managers close both files and terminate the pool even if
        # a worker raises.
        with open(input_file_name, 'r', encoding='utf-8') as fin, \
                open(output_file_name, 'w', encoding='utf-8') as fout:
            with multiprocessing.Pool(self.workers,
                                      initializer=encoder.initializer) as pool:
                split_docs = pool.imap(encoder.split, fin, 32)
                proc_start = time.time()
                total_bytes_processed = 0
                for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
                    total_bytes_processed += bytes_processed
                    fout.write(doc + "\n")
                    self.print_processing_stats(i, proc_start,
                                                total_bytes_processed)

    def _make_builders(self, output_prefix, vocab_size):
        """Create one dataset builder per JSON key; return (builders, idx paths).

        File names are ``{output_prefix}_{key}_{level}.bin`` / ``.idx``, where
        ``level`` is "sentence" when ``args.split_sentences`` is set and
        "document" otherwise.
        """
        level = "sentence" if self.args.split_sentences else "document"
        builders = {}
        idx_paths = {}
        for key in self.args.json_keys:
            key_prefix = "{}_{}_{}".format(output_prefix, key, level)
            idx_paths[key] = key_prefix + ".idx"
            builders[key] = indexed_dataset.make_builder(
                key_prefix + ".bin",
                impl=self.args.dataset_impl,
                vocab_size=vocab_size)
        return builders, idx_paths

    def process_json_file(self, file_name):
        """Tokenize a JSON-lines file into one indexed dataset per JSON key.

        Args:
            file_name: ``(input_file_name, output_prefix)`` pair. A ``.bin`` and
                ``.idx`` file is written for every key in ``args.json_keys``.
        """
        input_file_name, output_prefix = file_name
        print("Opening", input_file_name)
        with open(input_file_name, 'r', encoding='utf-8') as fin:
            startup_start = time.time()
            encoder = Encoder(self.args)
            tokenizer = build_tokenizer(self.args)
            with multiprocessing.Pool(self.workers,
                                      initializer=encoder.initializer) as pool:
                encoded_docs = pool.imap(encoder.encode, fin, 32)
                builders, idx_paths = self._make_builders(
                    output_prefix, tokenizer.vocab_size)
                startup_end = time.time()
                proc_start = time.time()
                total_bytes_processed = 0
                print("Time to startup:", startup_end - startup_start)
                for i, (doc, sentence_lens, bytes_processed) in enumerate(
                        encoded_docs, start=1):
                    total_bytes_processed += bytes_processed
                    for key in doc.keys():
                        builders[key].add_doc(doc[key], sentence_lens[key])
                    self.print_processing_stats(i, proc_start,
                                                total_bytes_processed)
        # Every JSON key has its own builder and .idx file; finalize all of them.
        for key, builder in builders.items():
            builder.finalize(idx_paths[key])
def get_args(argv=None):
    """Parse the command-line options of the preprocessing script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is the previous behavior.

    Returns:
        The parsed ``argparse.Namespace``, extended with ``keep_empty=False``
        and fixed default/dummy values used when building the tokenizer
        (``rank``, ``make_vocab_size_divisible_by``,
        ``tensor_model_parallel_size``, ``vocab_extra_ids``).
    """
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space separated list of keys to extract from json')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase', 'BertWordPieceCase',
                                'GPT2BPETokenizer', 'SentencePieceTokenizer',
                                'GPTSentencePieceTokenizer', 'NullTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='YTTM tokenizer model.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    # type=int so a value given on the command line is an int, like the default.
    group.add_argument('--vocab-size', type=int, default=786,
                       help='size of vocab for use with NullTokenizer')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')
    group.add_argument('--lang', type=str, default='english',
                       help='Language to use for NLTK-powered sentence splitting.')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')
    group.add_argument('--dataset-impl', type=str, default='mmap',
                       choices=['lazy', 'cached', 'mmap'])

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, required=True,
                       help=('Number of worker processes to launch. '
                             'A good default for fast pre-processing '
                             'is: (workers * partitions) = available CPU cores.'))
    group.add_argument('--partitions', type=int, default=1,
                       help='Number of file partitions')
    group.add_argument('--log-interval', type=int, default=1000,
                       help='Interval between progress updates')
    args = parser.parse_args(argv)
    args.keep_empty = False

    if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
        print("Are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 1
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    args.vocab_extra_ids = 0

    return args
def get_file_name(args, file_id):
    """Derive the file names used for partition number ``file_id``.

    The partition tag is inserted before the extension of ``args.input``. For
    an input ``data.json`` and ``file_id`` 2 this gives ``data_2.json`` (the
    partition) and ``data_ss_2.json`` (its sentence-split version). The output
    prefix becomes ``<args.output_prefix>_2``.

    Returns:
        dict with the keys ``'partition'``, ``'sentence_split'`` and
        ``'output_prefix'``.
    """
    stem, ext = os.path.splitext(args.input)
    tag = str(file_id)
    names = {}
    names['partition'] = "".join((stem, "_", tag, ext))
    names['sentence_split'] = "".join((stem, "_ss_", tag, ext))
    names['output_prefix'] = "".join((args.output_prefix, "_", tag))
    return names
def check_files_exist(in_ss_out_names, key, num_partitions):
    """Return whether the ``key`` path exists for each of the first partitions.

    Checks ``in_ss_out_names[0..num_partitions-1][key]`` in order and stops at
    the first missing path. Zero partitions trivially count as present.
    """
    return all(os.path.exists(in_ss_out_names[idx][key])
               for idx in range(num_partitions))
def _partition_names(args):
    """Return one name dict per partition.

    Each dict has ``'partition'`` (JSON-lines input of the partition),
    ``'sentence_split'`` (its sentence-split version) and ``'output_prefix'``
    (prefix of the partition's .bin/.idx output). With a single partition the
    input file itself is used directly.
    """
    if args.partitions == 1:
        file_name, extension = os.path.splitext(args.input)
        return [{
            'partition': args.input,
            'sentence_split': file_name + "_ss" + extension,
            'output_prefix': args.output_prefix}]
    return [get_file_name(args, idx) for idx in range(args.partitions)]


def _create_partitions(in_file_names, in_ss_out_names):
    """Distribute the input lines round-robin across the partition files."""
    with contextlib.ExitStack() as stack:
        outputs = [stack.enter_context(open(name['partition'], 'w',
                                            encoding='utf-8'))
                   for name in in_ss_out_names]
        index = 0
        for in_file_name in in_file_names:
            # support for gzip files
            if in_file_name.endswith(".gz"):
                fin = gzip.open(in_file_name, 'rt', encoding='utf-8')
            else:
                fin = open(in_file_name, 'r', encoding='utf-8')
            with fin:
                for line in fin:
                    outputs[index].write(line)
                    index = (index + 1) % len(outputs)


def _run_per_partition(target, call_args):
    """Run ``target(a)`` for every ``a`` in ``call_args``, one process each.

    Waits for all processes and raises RuntimeError if any exited with a
    non-zero status, so later steps do not run on missing or partial output.
    """
    processes = []
    for one_arg in call_args:
        p = multiprocessing.Process(target=target, args=(one_arg,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    exit_codes = [p.exitcode for p in processes if p.exitcode != 0]
    if exit_codes:
        raise RuntimeError(
            f"{len(exit_codes)} partition worker process(es) failed "
            f"(exit codes {exit_codes})")


def _merge_partitions(args, in_ss_out_names):
    """Merge every partition's per-key dataset into ``args.output_prefix``."""
    level = "sentence" if args.split_sentences else "document"
    tokenizer = build_tokenizer(args)
    for key in args.json_keys:
        merged_prefix = "{}_{}_{}".format(args.output_prefix, key, level)
        builder = indexed_dataset.make_builder(merged_prefix + ".bin",
                                               impl=args.dataset_impl,
                                               vocab_size=tokenizer.vocab_size)
        for name in in_ss_out_names:
            builder.merge_file_(
                "{}_{}_{}".format(name['output_prefix'], key, level))
        builder.finalize(merged_prefix + ".idx")


def main():
    """Partition the input, optionally sentence-split, tokenize, then merge.

    Raises:
        Exception: sentence splitting was requested but nltk is unavailable.
        ValueError: ``--partitions`` < 1, or ``--workers`` is not a multiple
            of ``--partitions``.
        FileNotFoundError: partitions must be created but ``--input`` matches
            no files.
        RuntimeError: a per-partition worker process failed.
    """
    args = get_args()

    if args.split_sentences:
        if nltk_available:
            nltk.download("punkt", quiet=True)
        else:
            raise Exception(
                "nltk library required for sentence splitting is not available.")

    if args.partitions < 1:
        raise ValueError(f"--partitions must be >= 1, got {args.partitions}")
    if args.workers % args.partitions != 0:
        raise ValueError(
            f"--workers ({args.workers}) must be a multiple of "
            f"--partitions ({args.partitions})")

    in_ss_out_names = _partition_names(args)

    if args.partitions > 1:
        # Reuse partition files (or their sentence-split versions) left by an
        # earlier run; only create them when neither set exists.
        partitions_present = check_files_exist(in_ss_out_names, 'partition',
                                               args.partitions)
        split_sentences_present = check_files_exist(
            in_ss_out_names, 'sentence_split', args.partitions)
        if not partitions_present and not split_sentences_present:
            in_file_names = glob.glob(args.input)
            if not in_file_names:
                raise FileNotFoundError(
                    f"No input files match {args.input!r}")
            _create_partitions(in_file_names, in_ss_out_names)

    partition = Partition(args, args.workers // args.partitions)

    # check to see if partitions with split sentences already created
    split_sentences_present = check_files_exist(in_ss_out_names,
                                                'sentence_split',
                                                args.partitions)

    # split sentences in partition files
    if args.split_sentences and not split_sentences_present:
        _run_per_partition(
            partition.split_sentences,
            [(name['partition'], name['sentence_split'])
             for name in in_ss_out_names])
        if args.partitions == 1:
            # NOTE(review): with a single partition the run stops after only
            # writing the sentence-split file. Tokenization happens on the next
            # invocation, which finds that file present and skips straight to
            # encoding. Confirm this two-pass behavior is intended.
            return

    # encode partition files in parallel
    input_key = 'sentence_split' if args.split_sentences else 'partition'
    _run_per_partition(
        partition.process_json_file,
        [(name[input_key], name['output_prefix']) for name in in_ss_out_names])

    if args.partitions == 1:
        return

    # merge bin/idx partitions
    _merge_partitions(args, in_ss_out_names)


if __name__ == '__main__':
    main()