""" Title: Audio Classification with the STFTSpectrogram layer Author: [Mostafa M. Amin](https://mostafa-amin.com) Date created: 2024/10/04 Last modified: 2024/10/04 Description: Introducing the `STFTSpectrogram` layer to extract spectrograms for audio classification. Accelerator: GPU """ """ ## Introduction Preprocessing audio as spectrograms is an essential step in the vast majority of audio-based applications. Spectrograms represent the frequency content of a signal over time, are widely used for this purpose. In this tutorial, we'll demonstrate how to use the `STFTSpectrogram` layer in Keras to convert raw audio waveforms into spectrograms **within the model**. We'll then feed these spectrograms into an LSTM network followed by Dense layers to perform audio classification on the Speech Commands dataset. We will: - Load the ESC-10 dataset. - Preprocess the raw audio waveforms and generate spectrograms using `STFTSpectrogram`. - Build two models, one using spectrograms as 1D signals and the other is using as images (2D signals) with a pretrained image model. - Train and evaluate the models. ## Setup ### Importing the necessary libraries """ import os os.environ["KERAS_BACKEND"] = "jax" import keras import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy.io.wavfile from keras import layers from scipy.signal import resample keras.utils.set_random_seed(41) """ ### Define some variables """ BASE_DATA_DIR = "./datasets/esc-50_extracted/ESC-50-master/" BATCH_SIZE = 16 NUM_CLASSES = 10 EPOCHS = 200 SAMPLE_RATE = 16000 """ ## Download and Preprocess the ESC-10 Dataset We'll use the Dataset for Environmental Sound Classification dataset (ESC-10). This dataset consists of five-second .wav files of environmental sounds. ### Download and Extract the dataset """ keras.utils.get_file( "esc-50.zip", "https://github.com/karoldvl/ESC-50/archive/master.zip", cache_dir="./", cache_subdir="datasets", extract=True, ) """ ### Read the CSV file """ pd_data = pd.read_csv(os.path.join(BASE_DATA_DIR, "meta", "esc50.csv")) # filter ESC-50 to ESC-10 and reassign the targets pd_data = pd_data[pd_data["esc10"]] targets = sorted(pd_data["target"].unique().tolist()) assert len(targets) == NUM_CLASSES old_target_to_new_target = {old: new for new, old in enumerate(targets)} pd_data["target"] = pd_data["target"].map(lambda t: old_target_to_new_target[t]) pd_data """ ### Define functions to read and preprocess the WAV files """ def read_wav_file(path, target_sr=SAMPLE_RATE): sr, wav = scipy.io.wavfile.read(os.path.join(BASE_DATA_DIR, "audio", path)) wav = wav.astype(np.float32) / 32768.0 # normalize to [-1, 1] num_samples = int(len(wav) * target_sr / sr) # resample to 16 kHz wav = resample(wav, num_samples) return wav[:, None] # Add a channel dimension (of size 1) """ Create a function that uses the `STFTSpectrogram` to compute a spectrogram, then plots it. """ def plot_single_spectrogram(sample_wav_data): spectrogram = layers.STFTSpectrogram( mode="log", frame_length=SAMPLE_RATE * 20 // 1000, frame_step=SAMPLE_RATE * 5 // 1000, fft_length=1024, trainable=False, )(sample_wav_data[None, ...])[0, ...] # Plot the spectrogram plt.imshow(spectrogram.T, origin="lower") plt.title("Single Channel Spectrogram") plt.xlabel("Time") plt.ylabel("Frequency") plt.show() """ Create a function that uses the `STFTSpectrogram` to compute three spectrograms with multiple bandwidths, then aligns them as an image with different channels, to get a multi-bandwith spectrogram, then plots the spectrogram. """ def plot_multi_bandwidth_spectrogram(sample_wav_data): # All spectrograms must use the same `fft_length`, `frame_step`, and # `padding="same"` in order to produce spectrograms with identical shapes, # hence aligning them together. `expand_dims` ensures that the shapes are # compatible with image models. spectrograms = np.concatenate( [ layers.STFTSpectrogram( mode="log", frame_length=SAMPLE_RATE * x // 1000, frame_step=SAMPLE_RATE * 5 // 1000, fft_length=1024, padding="same", expand_dims=True, )(sample_wav_data[None, ...])[0, ...] for x in [5, 10, 20] ], axis=-1, ).transpose([1, 0, 2]) # normalize each color channel for better viewing mn = spectrograms.min(axis=(0, 1), keepdims=True) mx = spectrograms.max(axis=(0, 1), keepdims=True) spectrograms = (spectrograms - mn) / (mx - mn) plt.imshow(spectrograms, origin="lower") plt.title("Multi-bandwidth Spectrogram") plt.xlabel("Time") plt.ylabel("Frequency") plt.show() """ Demonstrate a sample wav file. """ sample_wav_data = read_wav_file(pd_data["filename"].tolist()[52]) plt.plot(sample_wav_data[:, 0]) plt.show() """ Plot a Spectrogram """ plot_single_spectrogram(sample_wav_data) """ Plot a multi-bandwidth spectrogram """ plot_multi_bandwidth_spectrogram(sample_wav_data) """ ### Define functions to construct a TF Dataset """ def read_dataset(df, folds): msk = df["fold"].isin(folds) filenames = df["filename"][msk] targets = df["target"][msk].values waves = np.array([read_wav_file(fil) for fil in filenames], dtype=np.float32) return waves, targets """ ### Create the datasets """ train_x, train_y = read_dataset(pd_data, [1, 2, 3]) valid_x, valid_y = read_dataset(pd_data, [4]) test_x, test_y = read_dataset(pd_data, [5]) """ ## Training the Models In this tutorial we demonstrate the different usecases of the `STFTSpectrogram` layer. The first model will use a non-trainable `STFTSpectrogram` layer, so it is intended purely for preprocessing. Additionally, the model will use 1D signals, hence it make use of Conv1D layers. The second model will use a trainable `STFTSpectrogram` layer with the `expand_dims` option, which expands the shapes to be compatible with image models. ### Create the 1D model 1. Create a non-trainable spectrograms, extracting a 1D time signal. 2. Apply `Conv1D` layers with `LayerNormalization` simialar to the classic VGG design. 4. Apply global maximum pooling to have fixed set of features. 5. Add `Dense` layers to make the final predictions based on the features. """ model1d = keras.Sequential( [ layers.InputLayer((None, 1)), layers.STFTSpectrogram( mode="log", frame_length=SAMPLE_RATE * 40 // 1000, frame_step=SAMPLE_RATE * 15 // 1000, trainable=False, ), layers.Conv1D(64, 64, activation="relu"), layers.Conv1D(128, 16, activation="relu"), layers.LayerNormalization(), layers.MaxPooling1D(4), layers.Conv1D(128, 8, activation="relu"), layers.Conv1D(256, 8, activation="relu"), layers.Conv1D(512, 4, activation="relu"), layers.LayerNormalization(), layers.Dropout(0.5), layers.GlobalMaxPooling1D(), layers.Dense(256, activation="relu"), layers.Dense(256, activation="relu"), layers.Dropout(0.5), layers.Dense(NUM_CLASSES, activation="softmax"), ], name="model_1d_non_trainble_stft", ) model1d.compile( optimizer=keras.optimizers.Adam(1e-5), loss="sparse_categorical_crossentropy", metrics=["accuracy"], ) model1d.summary() """ Train the model and restore the best weights. """ history_model1d = model1d.fit( train_x, train_y, batch_size=BATCH_SIZE, validation_data=(valid_x, valid_y), epochs=EPOCHS, callbacks=[ keras.callbacks.EarlyStopping( monitor="val_loss", patience=EPOCHS, restore_best_weights=True, ) ], ) """ ### Create the 2D model 1. Create three spectrograms with multiple band-widths from the raw input. 2. Concatenate the three spectrograms to have three channels. 3. Load `MobileNet` and set the weights from the weights trained on `ImageNet`. 4. Apply global maximum pooling to have fixed set of features. 5. Add `Dense` layers to make the final predictions based on the features. """ input = layers.Input((None, 1)) spectrograms = [ layers.STFTSpectrogram( mode="log", frame_length=SAMPLE_RATE * frame_size // 1000, frame_step=SAMPLE_RATE * 15 // 1000, fft_length=2048, padding="same", expand_dims=True, # trainable=True, # trainable by default )(input) for frame_size in [30, 40, 50] # frame size in milliseconds ] multi_spectrograms = layers.Concatenate(axis=-1)(spectrograms) img_model = keras.applications.MobileNet(include_top=False, pooling="max") output = img_model(multi_spectrograms) output = layers.Dropout(0.5)(output) output = layers.Dense(256, activation="relu")(output) output = layers.Dense(256, activation="relu")(output) output = layers.Dense(NUM_CLASSES, activation="softmax")(output) model2d = keras.Model(input, output, name="model_2d_trainble_stft") model2d.compile( optimizer=keras.optimizers.Adam(1e-4), loss="sparse_categorical_crossentropy", metrics=["accuracy"], ) model2d.summary() """ Train the model and restore the best weights. """ history_model2d = model2d.fit( train_x, train_y, batch_size=BATCH_SIZE, validation_data=(valid_x, valid_y), epochs=EPOCHS, callbacks=[ keras.callbacks.EarlyStopping( monitor="val_loss", patience=EPOCHS, restore_best_weights=True, ) ], ) """ ### Plot Training History """ epochs_range = range(EPOCHS) plt.figure(figsize=(14, 5)) plt.subplot(1, 2, 1) plt.plot( epochs_range, history_model1d.history["accuracy"], label="Training Accuracy,1D model with non-trainable STFT", ) plt.plot( epochs_range, history_model1d.history["val_accuracy"], label="Validation Accuracy, 1D model with non-trainable STFT", ) plt.plot( epochs_range, history_model2d.history["accuracy"], label="Training Accuracy, 2D model with trainable STFT", ) plt.plot( epochs_range, history_model2d.history["val_accuracy"], label="Validation Accuracy, 2D model with trainable STFT", ) plt.legend(loc="lower right") plt.title("Training and Validation Accuracy") plt.subplot(1, 2, 2) plt.plot( epochs_range, history_model1d.history["loss"], label="Training Loss,1D model with non-trainable STFT", ) plt.plot( epochs_range, history_model1d.history["val_loss"], label="Validation Loss, 1D model with non-trainable STFT", ) plt.plot( epochs_range, history_model2d.history["loss"], label="Training Loss, 2D model with trainable STFT", ) plt.plot( epochs_range, history_model2d.history["val_loss"], label="Validation Loss, 2D model with trainable STFT", ) plt.legend(loc="upper right") plt.title("Training and Validation Loss") plt.show() """ ### Evaluate on Test Data Running the models on the test set. """ _, test_acc = model1d.evaluate(test_x, test_y) print(f"1D model wit non-trainable STFT -> Test Accuracy: {test_acc * 100:.2f}%") _, test_acc = model2d.evaluate(test_x, test_y) print(f"2D model with trainable STFT -> Test Accuracy: {test_acc * 100:.2f}%")