update
Browse files- examples/evaluation/step_1_run_evaluation.py +9 -3
- examples/evaluation/step_2_show_metrics.py +65 -40
- examples/evaluation/step_3_show_vad.py +12 -14
- examples/fsmn_vad_by_webrtcvad/step_4_train_model.py +1 -1
- examples/silero_vad_by_webrtcvad/run.sh +1 -1
- examples/silero_vad_by_webrtcvad/step_4_train_model.py +5 -3
- examples/silero_vad_by_webrtcvad/step_5_export_model.py +1 -1
- log.py +45 -8
- main.py +15 -7
- toolbox/pydub/volume.py +39 -0
- toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py +3 -2
- toolbox/torchaudio/models/vad/native_silero_vad/__init__.py +6 -0
- toolbox/torchaudio/models/vad/native_silero_vad/inference_native_silero_vad_onnx.py +198 -0
- toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py +1 -4
examples/evaluation/step_1_run_evaluation.py
CHANGED
|
@@ -26,7 +26,14 @@ def get_args():
|
|
| 26 |
)
|
| 27 |
parser.add_argument(
|
| 28 |
"--output_file",
|
| 29 |
-
default=r"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
type=str
|
| 31 |
)
|
| 32 |
parser.add_argument("--expected_sample_rate", default=8000, type=int)
|
|
@@ -110,8 +117,7 @@ def main():
|
|
| 110 |
min_silence_length=6,
|
| 111 |
max_speech_length=100000,
|
| 112 |
min_speech_length=15,
|
| 113 |
-
|
| 114 |
-
engine="silero-vad-by-webrtcvad-nx2-dns3",
|
| 115 |
api_name="/when_click_vad_button"
|
| 116 |
)
|
| 117 |
js = json.loads(message)
|
|
|
|
| 26 |
)
|
| 27 |
parser.add_argument(
|
| 28 |
"--output_file",
|
| 29 |
+
default=r"native_silero_vad.jsonl",
|
| 30 |
+
type=str
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--vad_engine",
|
| 34 |
+
# default="fsmn-vad-by-webrtcvad-nx2-dns3",
|
| 35 |
+
# default="silero-vad-by-webrtcvad-nx2-dns3",
|
| 36 |
+
default="native_silero_vad",
|
| 37 |
type=str
|
| 38 |
)
|
| 39 |
parser.add_argument("--expected_sample_rate", default=8000, type=int)
|
|
|
|
| 117 |
min_silence_length=6,
|
| 118 |
max_speech_length=100000,
|
| 119 |
min_speech_length=15,
|
| 120 |
+
engine=args.vad_engine,
|
|
|
|
| 121 |
api_name="/when_click_vad_button"
|
| 122 |
)
|
| 123 |
js = json.loads(message)
|
examples/evaluation/step_2_show_metrics.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import argparse
|
| 4 |
import json
|
| 5 |
import os
|
|
|
|
| 6 |
import sys
|
| 7 |
|
| 8 |
pwd = os.path.abspath(os.path.dirname(__file__))
|
|
@@ -16,53 +17,77 @@ def get_args():
|
|
| 16 |
|
| 17 |
parser.add_argument(
|
| 18 |
"--eval_file",
|
| 19 |
-
default=r"
|
| 20 |
type=str
|
| 21 |
)
|
| 22 |
args = parser.parse_args()
|
| 23 |
return args
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def main():
|
| 27 |
-
args = get_args()
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
return
|
| 67 |
|
| 68 |
|
|
|
|
| 3 |
import argparse
|
| 4 |
import json
|
| 5 |
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
import sys
|
| 8 |
|
| 9 |
pwd = os.path.abspath(os.path.dirname(__file__))
|
|
|
|
| 17 |
|
| 18 |
parser.add_argument(
|
| 19 |
"--eval_file",
|
| 20 |
+
# default=r"native_silero_vad.jsonl",
|
| 21 |
type=str
|
| 22 |
)
|
| 23 |
args = parser.parse_args()
|
| 24 |
return args
|
| 25 |
|
| 26 |
|
| 27 |
+
evaluation_files = [
|
| 28 |
+
"native_silero_vad.jsonl",
|
| 29 |
+
"fsmn-vad.jsonl",
|
| 30 |
+
"silero-vad.jsonl"
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
def main():
|
| 35 |
+
# args = get_args()
|
| 36 |
+
|
| 37 |
+
for eval_file in evaluation_files:
|
| 38 |
+
eval_file = Path(eval_file)
|
| 39 |
+
total = 0
|
| 40 |
+
total_duration = 0
|
| 41 |
+
total_accuracy = 0
|
| 42 |
+
total_precision = 0
|
| 43 |
+
total_recall = 0
|
| 44 |
+
total_f1 = 0
|
| 45 |
+
|
| 46 |
+
average_accuracy = 0
|
| 47 |
+
average_precision = 0
|
| 48 |
+
average_recall = 0
|
| 49 |
+
average_f1 = 0
|
| 50 |
+
|
| 51 |
+
# progress_bar = tqdm(desc=eval_file.name)
|
| 52 |
+
with open(eval_file.as_posix(), "r", encoding="utf-8") as f:
|
| 53 |
+
for row in f:
|
| 54 |
+
row = json.loads(row)
|
| 55 |
+
duration = row["duration"]
|
| 56 |
+
accuracy = row["accuracy"]
|
| 57 |
+
precision = row["precision"]
|
| 58 |
+
recall = row["recall"]
|
| 59 |
+
f1 = row["f1"]
|
| 60 |
+
|
| 61 |
+
total += 1
|
| 62 |
+
total_duration += duration
|
| 63 |
+
total_accuracy += accuracy * duration
|
| 64 |
+
total_precision += precision * duration
|
| 65 |
+
total_recall += recall * duration
|
| 66 |
+
total_f1 += f1 * duration
|
| 67 |
+
|
| 68 |
+
average_accuracy = total_accuracy / total_duration
|
| 69 |
+
average_precision = total_precision / total_duration
|
| 70 |
+
average_recall = total_recall / total_duration
|
| 71 |
+
average_f1 = total_f1 / total_duration
|
| 72 |
+
|
| 73 |
+
# progress_bar.update(1)
|
| 74 |
+
# progress_bar.set_postfix({
|
| 75 |
+
# "total": total,
|
| 76 |
+
# "accuracy": average_accuracy,
|
| 77 |
+
# "precision": average_precision,
|
| 78 |
+
# "recall": average_recall,
|
| 79 |
+
# "f1": average_f1,
|
| 80 |
+
# "total_duration": f"{round(total_duration / 60, 4)}min",
|
| 81 |
+
# })
|
| 82 |
+
summary = (f"{eval_file.name}, "
|
| 83 |
+
f"total: {total}, "
|
| 84 |
+
f"accuracy: {average_accuracy}, "
|
| 85 |
+
f"precision: {average_precision}, "
|
| 86 |
+
f"recall: {average_recall}, "
|
| 87 |
+
f"f1: {average_f1}, "
|
| 88 |
+
f"total_duration: {f"{round(total_duration / 60, 4)}min"}, "
|
| 89 |
+
)
|
| 90 |
+
print(summary)
|
| 91 |
return
|
| 92 |
|
| 93 |
|
examples/evaluation/step_3_show_vad.py
CHANGED
|
@@ -51,10 +51,17 @@ def show_image(signal: np.ndarray,
|
|
| 51 |
plt.show()
|
| 52 |
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def main():
|
| 55 |
-
args = get_args()
|
| 56 |
|
| 57 |
-
with open(
|
| 58 |
for row in f:
|
| 59 |
row = json.loads(row)
|
| 60 |
filename = row["filename"]
|
|
@@ -77,25 +84,16 @@ def main():
|
|
| 77 |
begin = int(begin * sample_rate)
|
| 78 |
end = int(end * sample_rate)
|
| 79 |
ground_truth_probs[begin:end] = 1
|
|
|
|
| 80 |
prediction_probs = np.zeros(shape=(signal_length,), dtype=np.float32)
|
| 81 |
for begin, end in prediction:
|
| 82 |
begin = int(begin * sample_rate)
|
| 83 |
end = int(end * sample_rate)
|
| 84 |
prediction_probs[begin:end] = 1
|
| 85 |
|
| 86 |
-
# p = encoder_num_layers * (encoder_kernel_size - 1) // 2 * hop_size * sample_rate
|
| 87 |
-
p = 3 * (3 - 1) // 2 * 80
|
| 88 |
-
p = int(p)
|
| 89 |
-
print(f"p: {p}")
|
| 90 |
-
prediction_probs = np.concat(
|
| 91 |
-
[
|
| 92 |
-
prediction_probs[p:], prediction_probs[-p:]
|
| 93 |
-
],
|
| 94 |
-
axis=-1
|
| 95 |
-
)
|
| 96 |
-
|
| 97 |
show_image(signal,
|
| 98 |
-
ground_truth_probs,
|
|
|
|
| 99 |
sample_rate=sample_rate,
|
| 100 |
)
|
| 101 |
return
|
|
|
|
| 51 |
plt.show()
|
| 52 |
|
| 53 |
|
| 54 |
+
evaluation_files = [
|
| 55 |
+
# "native_silero_vad.jsonl",
|
| 56 |
+
"fsmn-vad.jsonl",
|
| 57 |
+
"silero-vad.jsonl"
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
def main():
|
| 62 |
+
# args = get_args()
|
| 63 |
|
| 64 |
+
with open(evaluation_files[0], "r", encoding="utf-8") as f:
|
| 65 |
for row in f:
|
| 66 |
row = json.loads(row)
|
| 67 |
filename = row["filename"]
|
|
|
|
| 84 |
begin = int(begin * sample_rate)
|
| 85 |
end = int(end * sample_rate)
|
| 86 |
ground_truth_probs[begin:end] = 1
|
| 87 |
+
|
| 88 |
prediction_probs = np.zeros(shape=(signal_length,), dtype=np.float32)
|
| 89 |
for begin, end in prediction:
|
| 90 |
begin = int(begin * sample_rate)
|
| 91 |
end = int(end * sample_rate)
|
| 92 |
prediction_probs[begin:end] = 1
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
show_image(signal,
|
| 95 |
+
ground_truth_probs,
|
| 96 |
+
prediction_probs,
|
| 97 |
sample_rate=sample_rate,
|
| 98 |
)
|
| 99 |
return
|
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py
CHANGED
|
@@ -127,7 +127,7 @@ def main():
|
|
| 127 |
max_wave_value=32768.0,
|
| 128 |
min_snr_db=config.min_snr_db,
|
| 129 |
max_snr_db=config.max_snr_db,
|
| 130 |
-
do_volume_enhancement=
|
| 131 |
# skip=225000,
|
| 132 |
)
|
| 133 |
valid_dataset = VadPaddingJsonlDataset(
|
|
|
|
| 127 |
max_wave_value=32768.0,
|
| 128 |
min_snr_db=config.min_snr_db,
|
| 129 |
max_snr_db=config.max_snr_db,
|
| 130 |
+
do_volume_enhancement=False,
|
| 131 |
# skip=225000,
|
| 132 |
)
|
| 133 |
valid_dataset = VadPaddingJsonlDataset(
|
examples/silero_vad_by_webrtcvad/run.sh
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
|
| 5 |
bash run.sh --stage 3 --stop_stage 5 --system_version centos \
|
| 6 |
--file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
|
| 7 |
-
--final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
|
| 8 |
--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
|
| 9 |
--speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
|
| 10 |
/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
|
|
|
|
| 4 |
|
| 5 |
bash run.sh --stage 3 --stop_stage 5 --system_version centos \
|
| 6 |
--file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
|
| 7 |
+
--final_model_name silero-vad-by-webrtcvad-nx2-dns3-20250813 \
|
| 8 |
--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
|
| 9 |
--speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
|
| 10 |
/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
|
examples/silero_vad_by_webrtcvad/step_4_train_model.py
CHANGED
|
@@ -127,7 +127,7 @@ def main():
|
|
| 127 |
max_wave_value=32768.0,
|
| 128 |
min_snr_db=config.min_snr_db,
|
| 129 |
max_snr_db=config.max_snr_db,
|
| 130 |
-
do_volume_enhancement=
|
| 131 |
# skip=225000,
|
| 132 |
)
|
| 133 |
valid_dataset = VadPaddingJsonlDataset(
|
|
@@ -271,7 +271,8 @@ def main():
|
|
| 271 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 272 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 273 |
|
| 274 |
-
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
|
|
|
|
| 275 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 276 |
logger.info(f"find nan or inf in loss. continue.")
|
| 277 |
continue
|
|
@@ -352,7 +353,8 @@ def main():
|
|
| 352 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 353 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 354 |
|
| 355 |
-
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
|
|
|
|
| 356 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 357 |
logger.info(f"find nan or inf in loss. continue.")
|
| 358 |
continue
|
|
|
|
| 127 |
max_wave_value=32768.0,
|
| 128 |
min_snr_db=config.min_snr_db,
|
| 129 |
max_snr_db=config.max_snr_db,
|
| 130 |
+
do_volume_enhancement=False,
|
| 131 |
# skip=225000,
|
| 132 |
)
|
| 133 |
valid_dataset = VadPaddingJsonlDataset(
|
|
|
|
| 271 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 272 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 273 |
|
| 274 |
+
# loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
|
| 275 |
+
loss = 1.0 * bce_loss + 1.0 * dice_loss + 1.0 * lsnr_loss
|
| 276 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 277 |
logger.info(f"find nan or inf in loss. continue.")
|
| 278 |
continue
|
|
|
|
| 353 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 354 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 355 |
|
| 356 |
+
# loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
|
| 357 |
+
loss = 1.0 * bce_loss + 1.0 * dice_loss + 1.0 * lsnr_loss
|
| 358 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 359 |
logger.info(f"find nan or inf in loss. continue.")
|
| 360 |
continue
|
examples/silero_vad_by_webrtcvad/step_5_export_model.py
CHANGED
|
@@ -94,7 +94,7 @@ def main():
|
|
| 94 |
"new_lstm_hidden_state": {2: "batch_size"},
|
| 95 |
})
|
| 96 |
|
| 97 |
-
ort_session = ort.InferenceSession("
|
| 98 |
input_feed = {
|
| 99 |
"inputs": inputs.numpy(),
|
| 100 |
"encoder_in_cache": encoder_in_cache.numpy(),
|
|
|
|
| 94 |
"new_lstm_hidden_state": {2: "batch_size"},
|
| 95 |
})
|
| 96 |
|
| 97 |
+
ort_session = ort.InferenceSession("model.onnx")
|
| 98 |
input_feed = {
|
| 99 |
"inputs": inputs.numpy(),
|
| 100 |
"encoder_in_cache": encoder_in_cache.numpy(),
|
log.py
CHANGED
|
@@ -15,8 +15,43 @@ def get_converter(tz_info: str = "Asia/Shanghai"):
|
|
| 15 |
return converter
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
| 19 |
-
fmt = "%(asctime)s
|
| 20 |
|
| 21 |
formatter = logging.Formatter(
|
| 22 |
fmt=fmt,
|
|
@@ -38,11 +73,12 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
|
| 38 |
backupCount=2,
|
| 39 |
)
|
| 40 |
main_info_file_handler.setLevel(logging.INFO)
|
| 41 |
-
main_info_file_handler.setFormatter(
|
| 42 |
main_logger.addHandler(main_info_file_handler)
|
| 43 |
|
| 44 |
# http
|
| 45 |
http_logger = logging.getLogger("http")
|
|
|
|
| 46 |
http_file_handler = RotatingFileHandler(
|
| 47 |
filename=os.path.join(log_directory, "http.log"),
|
| 48 |
maxBytes=100*1024*1024, # 100MB
|
|
@@ -50,11 +86,12 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
|
| 50 |
backupCount=2,
|
| 51 |
)
|
| 52 |
http_file_handler.setLevel(logging.DEBUG)
|
| 53 |
-
http_file_handler.setFormatter(
|
| 54 |
http_logger.addHandler(http_file_handler)
|
| 55 |
|
| 56 |
# api
|
| 57 |
api_logger = logging.getLogger("api")
|
|
|
|
| 58 |
api_file_handler = RotatingFileHandler(
|
| 59 |
filename=os.path.join(log_directory, "api.log"),
|
| 60 |
maxBytes=10*1024*1024, # 10MB
|
|
@@ -62,7 +99,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
|
| 62 |
backupCount=2,
|
| 63 |
)
|
| 64 |
api_file_handler.setLevel(logging.DEBUG)
|
| 65 |
-
api_file_handler.setFormatter(
|
| 66 |
api_logger.addHandler(api_file_handler)
|
| 67 |
|
| 68 |
# alarm
|
|
@@ -74,7 +111,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
|
| 74 |
backupCount=2,
|
| 75 |
)
|
| 76 |
alarm_file_handler.setLevel(logging.DEBUG)
|
| 77 |
-
alarm_file_handler.setFormatter(
|
| 78 |
alarm_logger.addHandler(alarm_file_handler)
|
| 79 |
|
| 80 |
debug_file_handler = RotatingFileHandler(
|
|
@@ -84,7 +121,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
|
| 84 |
backupCount=2,
|
| 85 |
)
|
| 86 |
debug_file_handler.setLevel(logging.DEBUG)
|
| 87 |
-
debug_file_handler.setFormatter(
|
| 88 |
|
| 89 |
info_file_handler = RotatingFileHandler(
|
| 90 |
filename=os.path.join(log_directory, "info.log"),
|
|
@@ -93,7 +130,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
|
| 93 |
backupCount=2,
|
| 94 |
)
|
| 95 |
info_file_handler.setLevel(logging.INFO)
|
| 96 |
-
info_file_handler.setFormatter(
|
| 97 |
|
| 98 |
error_file_handler = RotatingFileHandler(
|
| 99 |
filename=os.path.join(log_directory, "error.log"),
|
|
@@ -102,7 +139,7 @@ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
|
| 102 |
backupCount=2,
|
| 103 |
)
|
| 104 |
error_file_handler.setLevel(logging.ERROR)
|
| 105 |
-
error_file_handler.setFormatter(
|
| 106 |
|
| 107 |
logging.basicConfig(
|
| 108 |
level=logging.DEBUG,
|
|
|
|
| 15 |
return converter
|
| 16 |
|
| 17 |
|
| 18 |
+
def setup_stream(tz_info: str = "Asia/Shanghai"):
|
| 19 |
+
fmt = "%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s"
|
| 20 |
+
|
| 21 |
+
formatter = logging.Formatter(
|
| 22 |
+
fmt=fmt,
|
| 23 |
+
datefmt="%Y-%m-%d %H:%M:%S %z"
|
| 24 |
+
)
|
| 25 |
+
formatter.converter = get_converter(tz_info)
|
| 26 |
+
|
| 27 |
+
stream_handler = logging.StreamHandler()
|
| 28 |
+
stream_handler.setLevel(logging.INFO)
|
| 29 |
+
stream_handler.setFormatter(formatter)
|
| 30 |
+
|
| 31 |
+
# main
|
| 32 |
+
main_logger = logging.getLogger("main")
|
| 33 |
+
main_logger.addHandler(stream_handler)
|
| 34 |
+
|
| 35 |
+
# http
|
| 36 |
+
http_logger = logging.getLogger("http")
|
| 37 |
+
http_logger.addHandler(stream_handler)
|
| 38 |
+
|
| 39 |
+
# api
|
| 40 |
+
api_logger = logging.getLogger("api")
|
| 41 |
+
api_logger.addHandler(stream_handler)
|
| 42 |
+
|
| 43 |
+
logging.basicConfig(
|
| 44 |
+
level=logging.DEBUG,
|
| 45 |
+
datefmt="%a, %d %b %Y %H:%M:%S",
|
| 46 |
+
handlers=[
|
| 47 |
+
|
| 48 |
+
]
|
| 49 |
+
)
|
| 50 |
+
return
|
| 51 |
+
|
| 52 |
+
|
| 53 |
def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
|
| 54 |
+
fmt = "%(asctime)s|%(name)s|%(levelname)s|%(filename)s|%(lineno)d|%(message)s"
|
| 55 |
|
| 56 |
formatter = logging.Formatter(
|
| 57 |
fmt=fmt,
|
|
|
|
| 73 |
backupCount=2,
|
| 74 |
)
|
| 75 |
main_info_file_handler.setLevel(logging.INFO)
|
| 76 |
+
main_info_file_handler.setFormatter(formatter)
|
| 77 |
main_logger.addHandler(main_info_file_handler)
|
| 78 |
|
| 79 |
# http
|
| 80 |
http_logger = logging.getLogger("http")
|
| 81 |
+
http_logger.addHandler(stream_handler)
|
| 82 |
http_file_handler = RotatingFileHandler(
|
| 83 |
filename=os.path.join(log_directory, "http.log"),
|
| 84 |
maxBytes=100*1024*1024, # 100MB
|
|
|
|
| 86 |
backupCount=2,
|
| 87 |
)
|
| 88 |
http_file_handler.setLevel(logging.DEBUG)
|
| 89 |
+
http_file_handler.setFormatter(formatter)
|
| 90 |
http_logger.addHandler(http_file_handler)
|
| 91 |
|
| 92 |
# api
|
| 93 |
api_logger = logging.getLogger("api")
|
| 94 |
+
api_logger.addHandler(stream_handler)
|
| 95 |
api_file_handler = RotatingFileHandler(
|
| 96 |
filename=os.path.join(log_directory, "api.log"),
|
| 97 |
maxBytes=10*1024*1024, # 10MB
|
|
|
|
| 99 |
backupCount=2,
|
| 100 |
)
|
| 101 |
api_file_handler.setLevel(logging.DEBUG)
|
| 102 |
+
api_file_handler.setFormatter(formatter)
|
| 103 |
api_logger.addHandler(api_file_handler)
|
| 104 |
|
| 105 |
# alarm
|
|
|
|
| 111 |
backupCount=2,
|
| 112 |
)
|
| 113 |
alarm_file_handler.setLevel(logging.DEBUG)
|
| 114 |
+
alarm_file_handler.setFormatter(formatter)
|
| 115 |
alarm_logger.addHandler(alarm_file_handler)
|
| 116 |
|
| 117 |
debug_file_handler = RotatingFileHandler(
|
|
|
|
| 121 |
backupCount=2,
|
| 122 |
)
|
| 123 |
debug_file_handler.setLevel(logging.DEBUG)
|
| 124 |
+
debug_file_handler.setFormatter(formatter)
|
| 125 |
|
| 126 |
info_file_handler = RotatingFileHandler(
|
| 127 |
filename=os.path.join(log_directory, "info.log"),
|
|
|
|
| 130 |
backupCount=2,
|
| 131 |
)
|
| 132 |
info_file_handler.setLevel(logging.INFO)
|
| 133 |
+
info_file_handler.setFormatter(formatter)
|
| 134 |
|
| 135 |
error_file_handler = RotatingFileHandler(
|
| 136 |
filename=os.path.join(log_directory, "error.log"),
|
|
|
|
| 139 |
backupCount=2,
|
| 140 |
)
|
| 141 |
error_file_handler.setLevel(logging.ERROR)
|
| 142 |
+
error_file_handler.setFormatter(formatter)
|
| 143 |
|
| 144 |
logging.basicConfig(
|
| 145 |
level=logging.DEBUG,
|
main.py
CHANGED
|
@@ -25,8 +25,10 @@ from project_settings import environment, project_path, log_directory, time_zone
|
|
| 25 |
from toolbox.os.command import Command
|
| 26 |
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
|
| 27 |
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
|
|
|
|
| 28 |
from toolbox.torchaudio.utils.visualization import process_speech_probs
|
| 29 |
from toolbox.vad.utils import PostProcess
|
|
|
|
| 30 |
|
| 31 |
log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
|
| 32 |
|
|
@@ -93,9 +95,11 @@ def shell(cmd: str):
|
|
| 93 |
|
| 94 |
|
| 95 |
def get_infer_cls_by_model_name(model_name: str):
|
| 96 |
-
if model_name.__contains__("
|
|
|
|
|
|
|
| 97 |
infer_cls = InferenceFSMNVadOnnx
|
| 98 |
-
elif model_name.__contains__("silero"):
|
| 99 |
infer_cls = InferenceSileroVad
|
| 100 |
else:
|
| 101 |
raise AssertionError
|
|
@@ -158,8 +162,8 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
|
|
| 158 |
vad_info = infer_engine.infer(audio)
|
| 159 |
time_cost = time.time() - begin
|
| 160 |
|
| 161 |
-
probs = vad_info["probs"]
|
| 162 |
-
lsnr = vad_info["lsnr"]
|
| 163 |
# lsnr = lsnr / np.max(np.abs(lsnr))
|
| 164 |
lsnr = lsnr / 30
|
| 165 |
|
|
@@ -197,13 +201,17 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
|
|
| 197 |
] for v in vad_segments
|
| 198 |
]
|
| 199 |
|
|
|
|
|
|
|
|
|
|
| 200 |
# message
|
| 201 |
rtf = time_cost / audio_duration
|
| 202 |
info = {
|
| 203 |
"vad_segments": vad_segments,
|
| 204 |
"time_cost": round(time_cost, 4),
|
| 205 |
"duration": round(audio_duration, 4),
|
| 206 |
-
"rtf": round(rtf, 4)
|
|
|
|
| 207 |
}
|
| 208 |
message = json.dumps(info, ensure_ascii=False, indent=4)
|
| 209 |
|
|
@@ -239,8 +247,8 @@ def main():
|
|
| 239 |
}
|
| 240 |
for filename in (project_path / "trained_models").glob("*.zip")
|
| 241 |
if filename.name not in (
|
| 242 |
-
"cnn-vad-by-webrtcvad-nx-dns3.zip",
|
| 243 |
-
"fsmn-vad-by-webrtcvad-nx-dns3.zip",
|
| 244 |
"examples.zip",
|
| 245 |
"sound-2-ch32.zip",
|
| 246 |
"sound-3-ch32.zip",
|
|
|
|
| 25 |
from toolbox.os.command import Command
|
| 26 |
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
|
| 27 |
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
|
| 28 |
+
from toolbox.torchaudio.models.vad.native_silero_vad.inference_native_silero_vad_onnx import InferenceNativeSileroVadOnnx
|
| 29 |
from toolbox.torchaudio.utils.visualization import process_speech_probs
|
| 30 |
from toolbox.vad.utils import PostProcess
|
| 31 |
+
from toolbox.pydub.volume import get_volume
|
| 32 |
|
| 33 |
log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
|
| 34 |
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
def get_infer_cls_by_model_name(model_name: str):
|
| 98 |
+
if model_name.__contains__("native_silero_vad"):
|
| 99 |
+
infer_cls = InferenceNativeSileroVadOnnx
|
| 100 |
+
elif model_name.__contains__("fsmn-vad"):
|
| 101 |
infer_cls = InferenceFSMNVadOnnx
|
| 102 |
+
elif model_name.__contains__("silero-vad"):
|
| 103 |
infer_cls = InferenceSileroVad
|
| 104 |
else:
|
| 105 |
raise AssertionError
|
|
|
|
| 162 |
vad_info = infer_engine.infer(audio)
|
| 163 |
time_cost = time.time() - begin
|
| 164 |
|
| 165 |
+
probs: np.ndarray = vad_info["probs"]
|
| 166 |
+
lsnr: np.ndarray = vad_info["lsnr"]
|
| 167 |
# lsnr = lsnr / np.max(np.abs(lsnr))
|
| 168 |
lsnr = lsnr / 30
|
| 169 |
|
|
|
|
| 201 |
] for v in vad_segments
|
| 202 |
]
|
| 203 |
|
| 204 |
+
# volume
|
| 205 |
+
volume_map: dict = get_volume(audio, sample_rate)
|
| 206 |
+
|
| 207 |
# message
|
| 208 |
rtf = time_cost / audio_duration
|
| 209 |
info = {
|
| 210 |
"vad_segments": vad_segments,
|
| 211 |
"time_cost": round(time_cost, 4),
|
| 212 |
"duration": round(audio_duration, 4),
|
| 213 |
+
"rtf": round(rtf, 4),
|
| 214 |
+
**volume_map
|
| 215 |
}
|
| 216 |
message = json.dumps(info, ensure_ascii=False, indent=4)
|
| 217 |
|
|
|
|
| 247 |
}
|
| 248 |
for filename in (project_path / "trained_models").glob("*.zip")
|
| 249 |
if filename.name not in (
|
| 250 |
+
# "cnn-vad-by-webrtcvad-nx-dns3.zip",
|
| 251 |
+
# "fsmn-vad-by-webrtcvad-nx-dns3.zip",
|
| 252 |
"examples.zip",
|
| 253 |
"sound-2-ch32.zip",
|
| 254 |
"sound-3-ch32.zip",
|
toolbox/pydub/volume.py
CHANGED
|
@@ -76,6 +76,45 @@ def set_volume(waveform: np.ndarray, sample_rate: int = 8000, volume: int = 0):
|
|
| 76 |
return samples
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def get_args():
|
| 80 |
parser = argparse.ArgumentParser()
|
| 81 |
parser.add_argument(
|
|
|
|
| 76 |
return samples
|
| 77 |
|
| 78 |
|
| 79 |
+
def get_volume(waveform: np.ndarray, sample_rate: int = 8000):
|
| 80 |
+
if np.min(waveform) < -1 or np.max(waveform) > 1:
|
| 81 |
+
raise AssertionError(f"waveform type: {type(waveform)}, dtype: {waveform.dtype}")
|
| 82 |
+
waveform = np.array(waveform * (1 << 15), dtype=np.int16)
|
| 83 |
+
raw_data = waveform.tobytes()
|
| 84 |
+
|
| 85 |
+
audio_segment = AudioSegment(
|
| 86 |
+
data=raw_data,
|
| 87 |
+
sample_width=2,
|
| 88 |
+
frame_rate=sample_rate,
|
| 89 |
+
channels=1
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
map_list = [
|
| 93 |
+
[0, -150],
|
| 94 |
+
[10, -40],
|
| 95 |
+
[50, -12],
|
| 96 |
+
[75, -6],
|
| 97 |
+
[100, 0],
|
| 98 |
+
]
|
| 99 |
+
scores = [a for a, b in map_list]
|
| 100 |
+
stages = [b for a, b in map_list]
|
| 101 |
+
|
| 102 |
+
audio_dbfs = audio_segment.dBFS
|
| 103 |
+
|
| 104 |
+
# 计算目标 volume
|
| 105 |
+
volume = score_transform(
|
| 106 |
+
x=audio_dbfs,
|
| 107 |
+
stages=list(reversed(stages)),
|
| 108 |
+
scores=list(reversed(scores)),
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
result = {
|
| 112 |
+
"dbfs": audio_dbfs,
|
| 113 |
+
"volume": volume,
|
| 114 |
+
}
|
| 115 |
+
return result
|
| 116 |
+
|
| 117 |
+
|
| 118 |
def get_args():
|
| 119 |
parser = argparse.ArgumentParser()
|
| 120 |
parser.add_argument(
|
toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py
CHANGED
|
@@ -139,8 +139,9 @@ class VadPaddingJsonlDataset(IterableDataset):
|
|
| 139 |
speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
|
| 140 |
|
| 141 |
# volume enhancement
|
| 142 |
-
|
| 143 |
-
|
|
|
|
| 144 |
|
| 145 |
noise_wave_list = list()
|
| 146 |
for noise in noise_list:
|
|
|
|
| 139 |
speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
|
| 140 |
|
| 141 |
# volume enhancement
|
| 142 |
+
if self.do_volume_enhancement:
|
| 143 |
+
volume = random.randint(10, 80)
|
| 144 |
+
speech_wave_np = set_volume(speech_wave_np, sample_rate=self.expected_sample_rate, volume=volume)
|
| 145 |
|
| 146 |
noise_wave_list = list()
|
| 147 |
for noise in noise_list:
|
toolbox/torchaudio/models/vad/native_silero_vad/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
pass
|
toolbox/torchaudio/models/vad/native_silero_vad/inference_native_silero_vad_onnx.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import shutil
|
| 7 |
+
import tempfile
|
| 8 |
+
import zipfile
|
| 9 |
+
|
| 10 |
+
from scipy.io import wavfile
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import onnxruntime as ort
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
|
| 16 |
+
torch.set_num_threads(1)
|
| 17 |
+
|
| 18 |
+
from project_settings import project_path
|
| 19 |
+
from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
|
| 20 |
+
from toolbox.torchaudio.configuration_utils import PretrainedConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("toolbox")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NativeSileroVadConfig(PretrainedConfig):
|
| 27 |
+
def __init__(self,
|
| 28 |
+
sample_rate: int = 8000,
|
| 29 |
+
win_size: int = 256,
|
| 30 |
+
hop_size: int = 256,
|
| 31 |
+
**kwargs
|
| 32 |
+
):
|
| 33 |
+
super(NativeSileroVadConfig, self).__init__(**kwargs)
|
| 34 |
+
# transform
|
| 35 |
+
self.sample_rate = sample_rate
|
| 36 |
+
self.win_size = win_size
|
| 37 |
+
self.hop_size = hop_size
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class InferenceNativeSileroVadOnnx(object):
|
| 41 |
+
"""
|
| 42 |
+
code:
|
| 43 |
+
https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/utils_vad.py
|
| 44 |
+
|
| 45 |
+
model:
|
| 46 |
+
https://github.com/snakers4/silero-vad/tree/master/src/silero_vad/data
|
| 47 |
+
"""
|
| 48 |
+
def __init__(self,
|
| 49 |
+
pretrained_model_path_or_zip_file: str,
|
| 50 |
+
device: str = "cpu"
|
| 51 |
+
):
|
| 52 |
+
self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
|
| 53 |
+
self.device = torch.device(device)
|
| 54 |
+
|
| 55 |
+
logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
|
| 56 |
+
config, ort_session = self.load_models(self.pretrained_model_path_or_zip_file)
|
| 57 |
+
logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
|
| 58 |
+
|
| 59 |
+
self.config = config
|
| 60 |
+
self.ort_session = ort_session
|
| 61 |
+
|
| 62 |
+
def load_models(self, model_path: str):
|
| 63 |
+
model_path = Path(model_path)
|
| 64 |
+
if model_path.name.endswith(".zip"):
|
| 65 |
+
with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
|
| 66 |
+
out_root = Path(tempfile.gettempdir()) / "cc_vad"
|
| 67 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
| 68 |
+
f_zip.extractall(path=out_root)
|
| 69 |
+
model_path = out_root / model_path.stem
|
| 70 |
+
|
| 71 |
+
config = NativeSileroVadConfig.from_pretrained(
|
| 72 |
+
pretrained_model_name_or_path=model_path.as_posix(),
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
opts = ort.SessionOptions()
|
| 76 |
+
opts.inter_op_num_threads = 1
|
| 77 |
+
opts.intra_op_num_threads = 1
|
| 78 |
+
|
| 79 |
+
ort_session = ort.InferenceSession(
|
| 80 |
+
(model_path / "silero_vad.onnx").as_posix(),
|
| 81 |
+
sess_options=opts
|
| 82 |
+
)
|
| 83 |
+
shutil.rmtree(model_path)
|
| 84 |
+
return config, ort_session
|
| 85 |
+
|
| 86 |
+
def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
if signal.dim() == 2:
|
| 88 |
+
signal = torch.unsqueeze(signal, dim=1)
|
| 89 |
+
_, _, n_samples = signal.shape
|
| 90 |
+
remainder = (n_samples - self.config.win_size) % self.config.hop_size
|
| 91 |
+
if remainder > 0:
|
| 92 |
+
n_samples_pad = self.config.hop_size - remainder
|
| 93 |
+
signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
|
| 94 |
+
return signal
|
| 95 |
+
|
| 96 |
+
def forward_chunk(self, chunk: torch.Tensor, context: torch.Tensor, state: torch.Tensor):
    """
    Run one fixed-size audio chunk through the silero ONNX model.

    :param chunk: shape [1, chunk_size]; chunk_size must be 256 (8 kHz) or 512 (16 kHz).
    :param context: tail samples of the previous chunk, shape [b, context_size];
                    pass an empty tensor (e.g. ``torch.zeros(0)``) on the first call.
    :param state: recurrent model state, shape [2, b, 128].
    :return: tuple of (vad_flag [b, 1], context [b, context_size], state [2, b, 128]).
    :raises ValueError: if the chunk length does not match the configured sample rate.
    """
    # chunk shape: [1, chunk_size]
    num_samples = 512 if self.config.sample_rate == 16000 else 256
    if chunk.shape[-1] != num_samples:
        raise ValueError(f"Provided number of samples is {chunk.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

    context_size = 64 if self.config.sample_rate == 16000 else 32

    # Bug fix: on the first call the caller passes an empty 1-D tensor, and
    # `torch.cat` rejects concatenating a 1-D empty tensor with a 2-D chunk
    # (dimension-count mismatch); the model input would also be short by
    # context_size samples. Zero-initialize the context to [b, context_size]
    # as the reference silero-vad OnnxWrapper does.
    if context.numel() == 0:
        context = torch.zeros(size=(chunk.shape[0], context_size), dtype=chunk.dtype)

    chunk = torch.cat(tensors=[context, chunk], dim=1)
    input_feed = {
        "input": chunk.numpy(),
        "state": state.numpy(),
        "sr": np.array(self.config.sample_rate, dtype=np.int64)
    }
    ort_outs = self.ort_session.run(output_names=None, input_feed=input_feed)
    vad_flag, state = ort_outs
    # vad_flag shape: [b, 1]
    # state shape: [2, b, 128]
    vad_flag = torch.from_numpy(vad_flag)
    state = torch.from_numpy(state)
    # Keep the tail of the concatenated audio as context for the next chunk.
    context = chunk[..., -context_size:]
    return vad_flag, context, state
|
| 118 |
+
|
| 119 |
+
def infer(self, signal: np.ndarray) -> dict:
    """
    Run VAD over a whole mono signal, chunk by chunk.

    :param signal: np.ndarray, shape [num_samples,], float values in [-1, 1].
    :return: dict with keys:
             ``"probs"``: np.ndarray, shape [num_chunk,], one value per hop;
             ``"lsnr"``:  np.ndarray, shape [num_chunk,], zeros (kept so the
             result dict matches the other VAD inference engines).
    """
    # Fix: return annotation previously claimed np.ndarray, but a dict is returned.
    inputs = torch.tensor(signal, dtype=torch.float32)
    inputs = torch.unsqueeze(inputs, dim=0)
    # inputs shape: [1, num_samples]

    inputs = self.signal_prepare(inputs)
    # inputs shape: [1, 1, num_samples]
    inputs = torch.squeeze(inputs, dim=1)
    # inputs shape: [1, num_samples]
    _, num_samples = inputs.shape

    vad_flags = list()

    # Empty context tensor: forward_chunk interprets this as "no previous audio".
    context = torch.zeros(0)
    state = torch.zeros(size=(2, 1, 128), dtype=torch.float32)
    for i in range(0, num_samples, self.config.hop_size):
        sub_inputs = inputs[:, i:i+self.config.win_size]
        vad_flag, context, state = self.forward_chunk(sub_inputs, context, state)
        vad_flags.append(vad_flag)

    vad_flags = torch.cat(vad_flags, dim=1).cpu()
    # vad_flags, torch.Tensor, shape: [b, num_chunks]
    vad_flags = vad_flags.numpy()
    # vad_flags, np.ndarray, shape: [b, num_chunks]
    vad_flags = vad_flags[0]
    # vad_flags shape: [num_chunk,]

    result = {
        "probs": vad_flags,
        "lsnr": np.zeros_like(vad_flags),
    }
    return result
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_args():
    """Parse the command-line arguments for this demo script."""
    parser = argparse.ArgumentParser()
    default_wav_file = (project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix()
    parser.add_argument(
        "--wav_file",
        default=default_wav_file,
        type=str,
    )
    return parser.parse_args()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Expected sample rate (Hz) of the input wav; main() raises AssertionError on mismatch.
SAMPLE_RATE = 8000
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def main():
    """Demo: run native silero VAD on a wav file and plot the speech probabilities."""
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if sample_rate != SAMPLE_RATE:
        raise AssertionError
    # int16 PCM -> float in [-1, 1)
    signal = signal / (1 << 15)

    infer = InferenceNativeSileroVadOnnx(
        pretrained_model_path_or_zip_file=(project_path / "trained_models/native_silero_vad.zip").as_posix(),
    )

    vad_info = infer.infer(signal)
    # speech_probs: np.ndarray, shape [num_chunk,]
    speech_probs = vad_info["probs"]

    # Expand per-chunk probabilities to per-sample resolution.
    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=speech_probs,
        frame_step=infer.config.hop_size,
    )

    # plot
    make_visualization(signal, speech_probs, SAMPLE_RATE)
    return
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py
CHANGED
|
@@ -109,9 +109,6 @@ class InferenceSileroVadOnnx(object):
|
|
| 109 |
}
|
| 110 |
return result
|
| 111 |
|
| 112 |
-
def post_process(self, probs: List[float]):
|
| 113 |
-
return
|
| 114 |
-
|
| 115 |
|
| 116 |
def get_args():
|
| 117 |
parser = argparse.ArgumentParser()
|
|
@@ -157,7 +154,7 @@ def main():
|
|
| 157 |
raise AssertionError
|
| 158 |
signal = signal / (1 << 15)
|
| 159 |
|
| 160 |
-
infer =
|
| 161 |
# pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
|
| 162 |
pretrained_model_path_or_zip_file = (project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
|
| 163 |
)
|
|
|
|
| 109 |
}
|
| 110 |
return result
|
| 111 |
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def get_args():
|
| 114 |
parser = argparse.ArgumentParser()
|
|
|
|
| 154 |
raise AssertionError
|
| 155 |
signal = signal / (1 << 15)
|
| 156 |
|
| 157 |
+
infer = InferenceSileroVadOnnx(
|
| 158 |
# pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
|
| 159 |
pretrained_model_path_or_zip_file = (project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
|
| 160 |
)
|