99import yaml
1010from tqdm import tqdm
1111
12+ from tools .binarize_util import load_wav , get_curves
1213from tools .config_utils import load_yaml
1314from tools .dataset import IndexedDatasetBuilder
1415from tools .encoder import UnitsEncoder
1516from tools .get_melspec import MelSpecExtractor
16- from tools .load_wav import load_wav
1717from tools .multiprocess_utils import chunked_multiprocess_run
1818
1919unitsEncoder = None
@@ -56,10 +56,11 @@ def __init__(self, binary_config):
5656 self .max_length = binary_config ['max_length' ]
5757 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
5858
59+ self .hop_size = self .melspec_config ["hop_size" ]
60+ self .window_size = self .melspec_config ["window_size" ]
5961 self .sample_rate = self .melspec_config ["sample_rate" ]
60- self .frame_length = self .melspec_config [ "hop_length" ] / self .sample_rate
62+ self .frame_length = self .hop_size / self .sample_rate
6163
62- self .hop_size = binary_config ['melspec_config' ]["hop_length" ]
6364 self .hubert_channel = binary_config ['hubert_config' ]["channel" ]
6465
6566 def get_vocab (self ):
@@ -85,8 +86,6 @@ def get_vocab(self):
8586 dict_phonemes .append (ph )
8687
8788 for dataset in self .datasets :
88- if dataset .get ("label_type" , "blank" ) == "blank" :
89- continue
9089 language = dataset .get ("language" , "blank" )
9190 raw_data_dir = dataset ["raw_data_dir" ]
9291
@@ -195,103 +194,65 @@ def process(self):
195194 # binarize train set
196195 self .binarize ("train" , meta_data_train , self .binary_folder )
197196
198- def make_ph_data (self , vocab , T , label_type_id , raw_ph_id_seq , raw_ph_dur ):
199- if label_type_id == 0 :
200- # ph_seq: [S]
201- ph_id_seq = np .array ([]).astype ("int32" )
197+ def make_ph_data (self , vocab , T , raw_ph_id_seq , raw_ph_dur ):
198+ # ph_seq: [S]
199+ ph_id_seq = np .array (raw_ph_id_seq ).astype ("int32" )
200+ not_sp_idx = ph_id_seq != 0
201+ ph_id_seq = ph_id_seq [not_sp_idx ]
202202
203- # ph_edge: [T]
204- ph_edge = np .zeros ([T ], dtype = "float32" )
205-
206- # ph_frame: [T]
207- ph_frame = np .zeros (T , dtype = "int32" )
208-
209- # ph_time: [T]
210- ph_time = np .zeros (T , dtype = "float32" )
211-
212- # ph_mask: [vocab_size]
213- ph_mask = np .ones (vocab ["vocab_size" ], dtype = "int32" )
214- elif label_type_id == 1 :
215- # ph_seq: [S]
216- ph_id_seq = np .array (raw_ph_id_seq ).astype ("int32" )
217- ph_id_seq = ph_id_seq [ph_id_seq != 0 ]
218-
219- if len (ph_id_seq ) <= 0 :
220- return None , None , None , None , None
221-
222- # ph_edge: [T]
223- ph_edge = np .zeros ([T ], dtype = "float32" )
203+ # ph_edge: [T]
204+ ph_dur = np .array (raw_ph_dur ).astype ("float32" )
205+ ph_time = np .array (np .concatenate (([0 ], ph_dur ))).cumsum ()
206+ ph_frame = ph_time / self .frame_length
207+ ph_interval = np .stack ((ph_frame [:- 1 ], ph_frame [1 :]))
208+ ph_time = ph_time [:- 1 ]
209+ ph_time = ph_time [not_sp_idx ]
224210
225- # ph_frame: [T]
226- ph_frame = np .zeros (T , dtype = "int32" )
211+ ph_interval = ph_interval [:, not_sp_idx ]
212+ ph_id_seq = ph_id_seq
213+ ph_frame = np .unique (ph_interval .flatten ())
214+ if ph_frame [- 1 ] >= T :
215+ ph_frame = ph_frame [:- 1 ]
227216
228- # ph_time: [T]
229- ph_time = np . zeros ( T , dtype = "float32" )
217+ if len ( ph_id_seq ) <= 0 :
218+ return None , None , None , None , None
230219
231- # ph_mask: [vocab_size]
232- ph_mask = np .zeros (vocab ["vocab_size" ], dtype = "int32" )
233- ph_mask [ph_id_seq ] = 1
234- ph_mask [0 ] = 1
235- elif label_type_id >= 2 :
236- # ph_seq: [S]
237- ph_id_seq = np .array (raw_ph_id_seq ).astype ("int32" )
238- not_sp_idx = ph_id_seq != 0
239- ph_id_seq = ph_id_seq [not_sp_idx ]
240-
241- # ph_edge: [T]
242- ph_dur = np .array (raw_ph_dur ).astype ("float32" )
243- ph_time = np .array (np .concatenate (([0 ], ph_dur ))).cumsum ()
244- ph_frame = ph_time / self .frame_length
245- ph_interval = np .stack ((ph_frame [:- 1 ], ph_frame [1 :]))
246- ph_time = ph_time [:- 1 ]
247- ph_time = ph_time [not_sp_idx ]
248-
249- ph_interval = ph_interval [:, not_sp_idx ]
250- ph_id_seq = ph_id_seq
251- ph_frame = np .unique (ph_interval .flatten ())
252- if ph_frame [- 1 ] >= T :
220+ ph_edge = np .zeros ([T ], dtype = "float32" )
221+ if len (ph_id_seq ) > 0 :
222+ if ph_frame [- 1 ] + 0.5 > T :
253223 ph_frame = ph_frame [:- 1 ]
224+ if ph_frame [0 ] - 0.5 < 0 :
225+ ph_frame = ph_frame [1 :]
226+ ph_time_int = np .round (ph_frame ).astype ("int32" )
227+ ph_time_fractional = ph_frame - ph_time_int
228+
229+ ph_edge [ph_time_int ] = 0.5 + ph_time_fractional
230+ ph_edge [ph_time_int - 1 ] = 0.5 - ph_time_fractional
231+ ph_edge = ph_edge * 0.8 + 0.1
232+
233+ # ph_frame: [T]
234+ ph_frame = np .zeros (T , dtype = "int32" )
235+ if len (ph_id_seq ) > 0 :
236+ for ph_id , st , ed in zip (
237+ ph_id_seq , ph_interval [0 ], ph_interval [1 ]
238+ ):
239+ if st < 0 :
240+ st = 0
241+ if ed > T :
242+ ed = T
243+ ph_frame [int (np .round (st )): int (np .round (ed ))] = ph_id
244+
245+ # ph_mask: [vocab_size]
246+ ph_mask = np .zeros (vocab ["vocab_size" ], dtype = "int32" )
247+ if len (ph_id_seq ) > 0 :
248+ ph_mask [ph_id_seq ] = 1
249+ ph_mask [0 ] = 1
254250
255- if len (ph_id_seq ) <= 0 :
256- return None , None , None , None , None
257-
258- ph_edge = np .zeros ([T ], dtype = "float32" )
259- if len (ph_id_seq ) > 0 :
260- if ph_frame [- 1 ] + 0.5 > T :
261- ph_frame = ph_frame [:- 1 ]
262- if ph_frame [0 ] - 0.5 < 0 :
263- ph_frame = ph_frame [1 :]
264- ph_time_int = np .round (ph_frame ).astype ("int32" )
265- ph_time_fractional = ph_frame - ph_time_int
266-
267- ph_edge [ph_time_int ] = 0.5 + ph_time_fractional
268- ph_edge [ph_time_int - 1 ] = 0.5 - ph_time_fractional
269- ph_edge = ph_edge * 0.8 + 0.1
270-
271- # ph_frame: [T]
272- ph_frame = np .zeros (T , dtype = "int32" )
273- if len (ph_id_seq ) > 0 :
274- for ph_id , st , ed in zip (
275- ph_id_seq , ph_interval [0 ], ph_interval [1 ]
276- ):
277- if st < 0 :
278- st = 0
279- if ed > T :
280- ed = T
281- ph_frame [int (np .round (st )): int (np .round (ed ))] = ph_id
282-
283- # ph_mask: [vocab_size]
284- ph_mask = np .zeros (vocab ["vocab_size" ], dtype = "int32" )
285- if len (ph_id_seq ) > 0 :
286- ph_mask [ph_id_seq ] = 1
287- ph_mask [0 ] = 1
288- else :
289- return None , None , None , None , None
290251 return ph_id_seq , ph_edge , ph_frame , ph_mask , ph_time
291252
292253 def make_non_speech_ph_data (self , T , ph_id_seq , ph_duration ):
293254 if len (ph_id_seq ) == 0 :
294- return None , None
255+ return np . zeros (( len ( self . vocab . keys ()) + 1 , T ), dtype = "int32" ), []
295256
296257 ph_id_seq = np .array (ph_id_seq , dtype = "int32" )
297258 ph_dur = np .array (ph_duration , dtype = "float32" )
@@ -376,22 +337,18 @@ def process_item(self, _item, export_mel=False):
376337 print (f"Skipping { wav_path } , because it doesn't exist" )
377338 return None
378339
379- waveform = load_wav (wav_path , self .device , self .sample_rate ) # (L,)
380- wav_length = len ( waveform ) / self .sample_rate # seconds
340+ waveform , wav_length , n_frames = load_wav (wav_path , self .sample_rate , self .hop_size ,
341+ self .device ) # (L,) seconds
381342 if wav_length > self .max_length :
382- print (
383- f"Item { wav_path } has a length of { wav_length } s, which is too long, skip it."
384- )
343+ print (f"Item { wav_path } has a length of { wav_length } s, which is too long, skip it." )
385344 return None
386- n_frames = waveform .size (- 1 ) // self .hop_size + 1
387345
388- label_type_id = {"blank" : 0 , "weak" : 1 , "full" : 2 , "evaluate" : 3 }[_item .label_type ]
389- if label_type_id >= 2 :
390- if len (_item .ph_dur ) != len (_item .ph_id_seq ): label_type_id = 1
391- if not _item .ph_id_seq : label_type_id = 0
346+ curves = get_curves (waveform , n_frames , self .window_size , self .hop_size , device = self .device ) # [B, C, T]
392347
348+ if len (_item .ph_id_seq ) == 0 or len (_item .ph_dur ) != len (_item .ph_id_seq ):
349+ return None
393350 ph_id_seq , ph_edge , ph_frame , ph_mask , ph_time = self .make_ph_data (
394- self .vocab , n_frames , label_type_id , _item .ph_id_seq , _item .ph_dur
351+ self .vocab , n_frames , _item .ph_id_seq , _item .ph_dur
395352 )
396353 if ph_id_seq is None :
397354 print (f"Skipping { wav_path } , make ph data failed." )
@@ -425,6 +382,7 @@ def process_item(self, _item, export_mel=False):
425382 return {
426383 'name' : str (_item ["name" ]),
427384 'input_feature' : units .cpu ().numpy ().astype ("float32" ),
385+ 'curves' : curves .cpu ().numpy ().astype ("float32" ),
428386 'melspec' : melspec .cpu ().numpy ().astype ("float32" ) if export_mel else np .array ([0 ]),
429387 'ph_id_seq' : ph_id_seq .astype ("int32" ),
430388 'ph_edge' : ph_edge .astype ("float32" ),
@@ -434,7 +392,6 @@ def process_item(self, _item, export_mel=False):
434392 'ph_time_raw' : np .concatenate (([0 ], _item .ph_dur )).cumsum ()[:- 1 ].astype ("float32" ),
435393 'ph_seq_raw' : _item .ph_seq ,
436394 'ph_seq' : [ph for ph in _item .ph_seq if self .vocab ["vocab" ][ph ] != 0 ],
437- "label_type" : label_type_id ,
438395 "non_speech_target" : non_speech_target .astype ("int32" ),
439396 "non_speech_intervals" : non_speech_intervals .astype ("int32" ),
440397 "wav_length" : wav_length
@@ -454,34 +411,25 @@ def get_meta_data(self):
454411 test_prefixes = dataset .get ("test_prefixes" , [])
455412
456413 assert raw_data_dir .exists (), f"{ raw_data_dir } does not exist."
457- assert label_type in ["full" , "weak" , "evaluate" , "blank" ], \
458- f"{ label_type } not in ['full', 'weak', 'evaluate','blank]."
459- if label_type == "blank" :
460- df = pd .DataFrame (
461- columns = ["name" , "ph_seq" , "ph_id_seq" , "label_type" , "wav_length" , "validation" ])
462- wavs_path = [i for i in raw_data_dir .rglob ("*.wav" )]
463- df ["wav_path" ] = wavs_path
464- df ["name" ] = df ["wav_path" ].apply (lambda wav_path : os .path .splitext (os .path .basename (wav_path )))
465- df ["wav_length" ] = 0
466- df ["validation" ] = False
467- else :
468- tuple_prefixes = tuple ([x for x in test_prefixes if x ] if test_prefixes is not None else [])
414+ assert label_type in ["full" , "evaluate" ], \
415+ f"{ label_type } not in ['full','evaluate']."
416+
417+ tuple_prefixes = tuple ([x for x in test_prefixes if x ] if test_prefixes is not None else [])
469418
470- csv_path = raw_data_dir / "transcriptions.csv"
471- wav_folder = raw_data_dir / "wavs"
472- assert csv_path .exists () and wav_folder .exists (), f"{ csv_path .absolute ()} or { wav_folder .absolute ()} does not exist."
419+ csv_path = raw_data_dir / "transcriptions.csv"
420+ wav_folder = raw_data_dir / "wavs"
421+ assert csv_path .exists () and wav_folder .exists (), f"{ csv_path .absolute ()} or { wav_folder .absolute ()} does not exist."
473422
474- df = pd .read_csv (csv_path , dtype = str )
475- assert "ph_seq" in df .columns , f"{ csv_path .absolute ()} does not contain 'ph_seq'."
476- if label_type == "full" :
477- assert "ph_dur" in df .columns , f"full label csv: { csv_path .absolute ()} does not contain 'ph_dur'."
423+ df = pd .read_csv (csv_path , dtype = str )
424+ assert "ph_seq" in df .columns , f"{ csv_path .absolute ()} does not contain 'ph_seq'."
425+ assert "ph_dur" in df .columns , f"full label csv: { csv_path .absolute ()} does not contain 'ph_dur'."
478426
479- if len (tuple_prefixes ) > 0 :
480- df ["validation" ] = df ["name" ].apply (lambda name : name .startswith (tuple_prefixes ))
481- else :
482- df ["validation" ] = False
427+ if len (tuple_prefixes ) > 0 :
428+ df ["validation" ] = df ["name" ].apply (lambda name : name .startswith (tuple_prefixes ))
429+ else :
430+ df ["validation" ] = False
483431
484- df ["wav_path" ] = df ["name" ].apply (lambda name : str (wav_folder / (str (name ) + ".wav" )))
432+ df ["wav_path" ] = df ["name" ].apply (lambda name : str (wav_folder / (str (name ) + ".wav" )))
485433
486434 df ["label_type" ] = label_type
487435 df ["ph_seq" ] = df ["ph_seq" ].apply (
0 commit comments