Skip to content

Commit 6e8c864

Browse files
authored
Merge pull request #3 from wolfgitpr/power
Power
2 parents a13bf84 + cf72445 commit 6e8c864

19 files changed

Lines changed: 700 additions & 624 deletions

binarize.py

Lines changed: 76 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import yaml
1010
from tqdm import tqdm
1111

12+
from tools.binarize_util import load_wav, get_curves
1213
from tools.config_utils import load_yaml
1314
from tools.dataset import IndexedDatasetBuilder
1415
from tools.encoder import UnitsEncoder
1516
from tools.get_melspec import MelSpecExtractor
16-
from tools.load_wav import load_wav
1717
from tools.multiprocess_utils import chunked_multiprocess_run
1818

1919
unitsEncoder = None
@@ -56,10 +56,11 @@ def __init__(self, binary_config):
5656
self.max_length = binary_config['max_length']
5757
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5858

59+
self.hop_size = self.melspec_config["hop_size"]
60+
self.window_size = self.melspec_config["window_size"]
5961
self.sample_rate = self.melspec_config["sample_rate"]
60-
self.frame_length = self.melspec_config["hop_length"] / self.sample_rate
62+
self.frame_length = self.hop_size / self.sample_rate
6163

62-
self.hop_size = binary_config['melspec_config']["hop_length"]
6364
self.hubert_channel = binary_config['hubert_config']["channel"]
6465

6566
def get_vocab(self):
@@ -85,8 +86,6 @@ def get_vocab(self):
8586
dict_phonemes.append(ph)
8687

8788
for dataset in self.datasets:
88-
if dataset.get("label_type", "blank") == "blank":
89-
continue
9089
language = dataset.get("language", "blank")
9190
raw_data_dir = dataset["raw_data_dir"]
9291

@@ -195,103 +194,65 @@ def process(self):
195194
# binarize train set
196195
self.binarize("train", meta_data_train, self.binary_folder)
197196

198-
def make_ph_data(self, vocab, T, label_type_id, raw_ph_id_seq, raw_ph_dur):
199-
if label_type_id == 0:
200-
# ph_seq: [S]
201-
ph_id_seq = np.array([]).astype("int32")
197+
def make_ph_data(self, vocab, T, raw_ph_id_seq, raw_ph_dur):
198+
# ph_seq: [S]
199+
ph_id_seq = np.array(raw_ph_id_seq).astype("int32")
200+
not_sp_idx = ph_id_seq != 0
201+
ph_id_seq = ph_id_seq[not_sp_idx]
202202

203-
# ph_edge: [T]
204-
ph_edge = np.zeros([T], dtype="float32")
205-
206-
# ph_frame: [T]
207-
ph_frame = np.zeros(T, dtype="int32")
208-
209-
# ph_time: [T]
210-
ph_time = np.zeros(T, dtype="float32")
211-
212-
# ph_mask: [vocab_size]
213-
ph_mask = np.ones(vocab["vocab_size"], dtype="int32")
214-
elif label_type_id == 1:
215-
# ph_seq: [S]
216-
ph_id_seq = np.array(raw_ph_id_seq).astype("int32")
217-
ph_id_seq = ph_id_seq[ph_id_seq != 0]
218-
219-
if len(ph_id_seq) <= 0:
220-
return None, None, None, None, None
221-
222-
# ph_edge: [T]
223-
ph_edge = np.zeros([T], dtype="float32")
203+
# ph_edge: [T]
204+
ph_dur = np.array(raw_ph_dur).astype("float32")
205+
ph_time = np.array(np.concatenate(([0], ph_dur))).cumsum()
206+
ph_frame = ph_time / self.frame_length
207+
ph_interval = np.stack((ph_frame[:-1], ph_frame[1:]))
208+
ph_time = ph_time[:-1]
209+
ph_time = ph_time[not_sp_idx]
224210

225-
# ph_frame: [T]
226-
ph_frame = np.zeros(T, dtype="int32")
211+
ph_interval = ph_interval[:, not_sp_idx]
212+
ph_id_seq = ph_id_seq
213+
ph_frame = np.unique(ph_interval.flatten())
214+
if ph_frame[-1] >= T:
215+
ph_frame = ph_frame[:-1]
227216

228-
# ph_time: [T]
229-
ph_time = np.zeros(T, dtype="float32")
217+
if len(ph_id_seq) <= 0:
218+
return None, None, None, None, None
230219

231-
# ph_mask: [vocab_size]
232-
ph_mask = np.zeros(vocab["vocab_size"], dtype="int32")
233-
ph_mask[ph_id_seq] = 1
234-
ph_mask[0] = 1
235-
elif label_type_id >= 2:
236-
# ph_seq: [S]
237-
ph_id_seq = np.array(raw_ph_id_seq).astype("int32")
238-
not_sp_idx = ph_id_seq != 0
239-
ph_id_seq = ph_id_seq[not_sp_idx]
240-
241-
# ph_edge: [T]
242-
ph_dur = np.array(raw_ph_dur).astype("float32")
243-
ph_time = np.array(np.concatenate(([0], ph_dur))).cumsum()
244-
ph_frame = ph_time / self.frame_length
245-
ph_interval = np.stack((ph_frame[:-1], ph_frame[1:]))
246-
ph_time = ph_time[:-1]
247-
ph_time = ph_time[not_sp_idx]
248-
249-
ph_interval = ph_interval[:, not_sp_idx]
250-
ph_id_seq = ph_id_seq
251-
ph_frame = np.unique(ph_interval.flatten())
252-
if ph_frame[-1] >= T:
220+
ph_edge = np.zeros([T], dtype="float32")
221+
if len(ph_id_seq) > 0:
222+
if ph_frame[-1] + 0.5 > T:
253223
ph_frame = ph_frame[:-1]
224+
if ph_frame[0] - 0.5 < 0:
225+
ph_frame = ph_frame[1:]
226+
ph_time_int = np.round(ph_frame).astype("int32")
227+
ph_time_fractional = ph_frame - ph_time_int
228+
229+
ph_edge[ph_time_int] = 0.5 + ph_time_fractional
230+
ph_edge[ph_time_int - 1] = 0.5 - ph_time_fractional
231+
ph_edge = ph_edge * 0.8 + 0.1
232+
233+
# ph_frame: [T]
234+
ph_frame = np.zeros(T, dtype="int32")
235+
if len(ph_id_seq) > 0:
236+
for ph_id, st, ed in zip(
237+
ph_id_seq, ph_interval[0], ph_interval[1]
238+
):
239+
if st < 0:
240+
st = 0
241+
if ed > T:
242+
ed = T
243+
ph_frame[int(np.round(st)): int(np.round(ed))] = ph_id
244+
245+
# ph_mask: [vocab_size]
246+
ph_mask = np.zeros(vocab["vocab_size"], dtype="int32")
247+
if len(ph_id_seq) > 0:
248+
ph_mask[ph_id_seq] = 1
249+
ph_mask[0] = 1
254250

255-
if len(ph_id_seq) <= 0:
256-
return None, None, None, None, None
257-
258-
ph_edge = np.zeros([T], dtype="float32")
259-
if len(ph_id_seq) > 0:
260-
if ph_frame[-1] + 0.5 > T:
261-
ph_frame = ph_frame[:-1]
262-
if ph_frame[0] - 0.5 < 0:
263-
ph_frame = ph_frame[1:]
264-
ph_time_int = np.round(ph_frame).astype("int32")
265-
ph_time_fractional = ph_frame - ph_time_int
266-
267-
ph_edge[ph_time_int] = 0.5 + ph_time_fractional
268-
ph_edge[ph_time_int - 1] = 0.5 - ph_time_fractional
269-
ph_edge = ph_edge * 0.8 + 0.1
270-
271-
# ph_frame: [T]
272-
ph_frame = np.zeros(T, dtype="int32")
273-
if len(ph_id_seq) > 0:
274-
for ph_id, st, ed in zip(
275-
ph_id_seq, ph_interval[0], ph_interval[1]
276-
):
277-
if st < 0:
278-
st = 0
279-
if ed > T:
280-
ed = T
281-
ph_frame[int(np.round(st)): int(np.round(ed))] = ph_id
282-
283-
# ph_mask: [vocab_size]
284-
ph_mask = np.zeros(vocab["vocab_size"], dtype="int32")
285-
if len(ph_id_seq) > 0:
286-
ph_mask[ph_id_seq] = 1
287-
ph_mask[0] = 1
288-
else:
289-
return None, None, None, None, None
290251
return ph_id_seq, ph_edge, ph_frame, ph_mask, ph_time
291252

292253
def make_non_speech_ph_data(self, T, ph_id_seq, ph_duration):
293254
if len(ph_id_seq) == 0:
294-
return None, None
255+
return np.zeros((len(self.vocab.keys()) + 1, T), dtype="int32"), []
295256

296257
ph_id_seq = np.array(ph_id_seq, dtype="int32")
297258
ph_dur = np.array(ph_duration, dtype="float32")
@@ -376,22 +337,18 @@ def process_item(self, _item, export_mel=False):
376337
print(f"Skipping {wav_path}, because it doesn't exist")
377338
return None
378339

379-
waveform = load_wav(wav_path, self.device, self.sample_rate) # (L,)
380-
wav_length = len(waveform) / self.sample_rate # seconds
340+
waveform, wav_length, n_frames = load_wav(wav_path, self.sample_rate, self.hop_size,
341+
self.device) # (L,) seconds
381342
if wav_length > self.max_length:
382-
print(
383-
f"Item {wav_path} has a length of {wav_length}s, which is too long, skip it."
384-
)
343+
print(f"Item {wav_path} has a length of {wav_length}s, which is too long, skip it.")
385344
return None
386-
n_frames = waveform.size(-1) // self.hop_size + 1
387345

388-
label_type_id = {"blank": 0, "weak": 1, "full": 2, "evaluate": 3}[_item.label_type]
389-
if label_type_id >= 2:
390-
if len(_item.ph_dur) != len(_item.ph_id_seq): label_type_id = 1
391-
if not _item.ph_id_seq: label_type_id = 0
346+
curves = get_curves(waveform, n_frames, self.window_size, self.hop_size, device=self.device) # [B, C, T]
392347

348+
if len(_item.ph_id_seq) == 0 or len(_item.ph_dur) != len(_item.ph_id_seq):
349+
return None
393350
ph_id_seq, ph_edge, ph_frame, ph_mask, ph_time = self.make_ph_data(
394-
self.vocab, n_frames, label_type_id, _item.ph_id_seq, _item.ph_dur
351+
self.vocab, n_frames, _item.ph_id_seq, _item.ph_dur
395352
)
396353
if ph_id_seq is None:
397354
print(f"Skipping {wav_path}, make ph data failed.")
@@ -425,6 +382,7 @@ def process_item(self, _item, export_mel=False):
425382
return {
426383
'name': str(_item["name"]),
427384
'input_feature': units.cpu().numpy().astype("float32"),
385+
'curves': curves.cpu().numpy().astype("float32"),
428386
'melspec': melspec.cpu().numpy().astype("float32") if export_mel else np.array([0]),
429387
'ph_id_seq': ph_id_seq.astype("int32"),
430388
'ph_edge': ph_edge.astype("float32"),
@@ -434,7 +392,6 @@ def process_item(self, _item, export_mel=False):
434392
'ph_time_raw': np.concatenate(([0], _item.ph_dur)).cumsum()[:-1].astype("float32"),
435393
'ph_seq_raw': _item.ph_seq,
436394
'ph_seq': [ph for ph in _item.ph_seq if self.vocab["vocab"][ph] != 0],
437-
"label_type": label_type_id,
438395
"non_speech_target": non_speech_target.astype("int32"),
439396
"non_speech_intervals": non_speech_intervals.astype("int32"),
440397
"wav_length": wav_length
@@ -454,34 +411,25 @@ def get_meta_data(self):
454411
test_prefixes = dataset.get("test_prefixes", [])
455412

456413
assert raw_data_dir.exists(), f"{raw_data_dir} does not exist."
457-
assert label_type in ["full", "weak", "evaluate", "blank"], \
458-
f"{label_type} not in ['full', 'weak', 'evaluate','blank]."
459-
if label_type == "blank":
460-
df = pd.DataFrame(
461-
columns=["name", "ph_seq", "ph_id_seq", "label_type", "wav_length", "validation"])
462-
wavs_path = [i for i in raw_data_dir.rglob("*.wav")]
463-
df["wav_path"] = wavs_path
464-
df["name"] = df["wav_path"].apply(lambda wav_path: os.path.splitext(os.path.basename(wav_path)))
465-
df["wav_length"] = 0
466-
df["validation"] = False
467-
else:
468-
tuple_prefixes = tuple([x for x in test_prefixes if x] if test_prefixes is not None else [])
414+
assert label_type in ["full", "evaluate"], \
415+
f"{label_type} not in ['full','evaluate']."
416+
417+
tuple_prefixes = tuple([x for x in test_prefixes if x] if test_prefixes is not None else [])
469418

470-
csv_path = raw_data_dir / "transcriptions.csv"
471-
wav_folder = raw_data_dir / "wavs"
472-
assert csv_path.exists() and wav_folder.exists(), f"{csv_path.absolute()} or {wav_folder.absolute()} does not exist."
419+
csv_path = raw_data_dir / "transcriptions.csv"
420+
wav_folder = raw_data_dir / "wavs"
421+
assert csv_path.exists() and wav_folder.exists(), f"{csv_path.absolute()} or {wav_folder.absolute()} does not exist."
473422

474-
df = pd.read_csv(csv_path, dtype=str)
475-
assert "ph_seq" in df.columns, f"{csv_path.absolute()} does not contain 'ph_seq'."
476-
if label_type == "full":
477-
assert "ph_dur" in df.columns, f"full label csv: {csv_path.absolute()} does not contain 'ph_dur'."
423+
df = pd.read_csv(csv_path, dtype=str)
424+
assert "ph_seq" in df.columns, f"{csv_path.absolute()} does not contain 'ph_seq'."
425+
assert "ph_dur" in df.columns, f"full label csv: {csv_path.absolute()} does not contain 'ph_dur'."
478426

479-
if len(tuple_prefixes) > 0:
480-
df["validation"] = df["name"].apply(lambda name: name.startswith(tuple_prefixes))
481-
else:
482-
df["validation"] = False
427+
if len(tuple_prefixes) > 0:
428+
df["validation"] = df["name"].apply(lambda name: name.startswith(tuple_prefixes))
429+
else:
430+
df["validation"] = False
483431

484-
df["wav_path"] = df["name"].apply(lambda name: str(wav_folder / (str(name) + ".wav")))
432+
df["wav_path"] = df["name"].apply(lambda name: str(wav_folder / (str(name) + ".wav")))
485433

486434
df["label_type"] = label_type
487435
df["ph_seq"] = df["ph_seq"].apply(

configs/binarize_config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ hubert_config:
5151
melspec_config:
5252
n_mels: 128
5353
sample_rate: 44100
54-
win_length: 1024
55-
hop_length: 512
54+
window_size: 1024
55+
hop_size: 512
5656
n_fft: 2048
57-
fmin: 40
58-
fmax: 16000
57+
f_min: 40
58+
f_max: 16000
5959
clamp: 0.00001
6060

6161
# 不建议开启

configs/datasets_config.yaml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,4 @@ datasets:
1717
label_type: evaluate
1818
language: yue
1919
test_prefixes:
20-
- xxx
21-
# blank 为无标注wav,不确定是否有效
22-
- raw_data_dir: path/to/spk_1/raw
23-
label_type: blank
20+
- xxx

configs/train_config.yaml

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model_name: 0727_hfa_cvnt
1+
model_name: 0918_hfa_power
22

33
# settings
44
float32_matmul_precision: high
@@ -23,20 +23,22 @@ model:
2323
hidden_dims: 192
2424
down_sampling_factor: 2
2525
down_sampling_times: 3
26-
channels_scaleup_factor: 1.3
26+
channels_scaleup_factor: 1.5
27+
dropout: 0.1
28+
29+
curves_attention_dropout: 0.1
2730

2831
cvnt_arg:
29-
mask_ratio: 0.3
32+
mask_ratio: 0.2
3033
encoder_conform_attention_drop: 0.05
31-
32-
num_layers: 3
33-
encoder_conform_dim: 96
34+
num_layers: 4
35+
encoder_conform_dim: 128
3436
encoder_conform_ffn_latent_drop: 0.05
3537
encoder_conform_ffn_out_drop: 0.05
36-
encoder_conform_kernel_size: 31
38+
encoder_conform_kernel_size: 23
3739

3840
optimizer_config:
39-
lr: 0.0005
41+
lr: 0.0003
4042
gamma: 0.9999
4143
total_steps: 20000
4244
muon_args:
@@ -46,13 +48,12 @@ optimizer_config:
4648

4749
loss_config:
4850
losses:
49-
weights: [ 8.0, 0.1, 0.01, 0.1, 2.0, 6.0 ]
50-
enable_RampUpScheduler: [ False,False,False,False,True,False ]
51+
weights: [ 8.0, 0.1, 1.0, 6.0, 10.0 ]
52+
enable_RampUpScheduler: [ False,False,False,True,False ]
5153
function:
5254
num_bins: 10
5355
alpha: 0.999
5456
label_smoothing: 0.08
55-
pseudo_label_ratio: 0.3
5657

5758
# trainer
5859
accelerator: auto

0 commit comments

Comments
 (0)