import tensorflow as tf
from tensorflow.keras.layers import Layer


# This is necessary to add a feature dimension to the scores input before TimeDistributed(Dense).
class ExpandDimsLayer(Layer):
    """
    A custom Keras layer to wrap tf.expand_dims for use in the Functional API.
    Expands the dimensions of the input tensor at a specified axis.
    """
    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, self.axis)

    def get_config(self):
        config = super().get_config()
        config.update({"axis": self.axis})
        return config


# Custom loss function to handle padded labels.
# This function masks out the loss contribution from padded labels (-1).
def masked_binary_crossentropy(y_true, y_pred):
    # y_true comes in as (batch_size, max_segment_len)
    # y_pred comes in as (batch_size, max_segment_len, 1) from the model's output layer

    # Ensure y_pred has the correct shape (batch_size, max_segment_len, 1).
    # If it was squeezed (e.g. by Keras's internal shape handling), expand it back.
    if len(y_pred.shape) == 2:
        y_pred = tf.expand_dims(y_pred, -1)

    # Ensure y_true also has the correct shape (batch_size, max_segment_len, 1).
    y_true_reshaped = tf.expand_dims(tf.cast(y_true, tf.float32), -1)

    # Clip y_pred to avoid log(0) or log(1), which would make the loss NaN or infinite.
    epsilon = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

    # Manually calculate element-wise binary crossentropy.
    # Both y_true_reshaped and y_pred are now (batch_size, max_segment_len, 1).
    bce_per_element = -(y_true_reshaped * tf.math.log(y_pred)
                        + (1.0 - y_true_reshaped) * tf.math.log(1.0 - y_pred))

    # Create the mask, expanded to match the shape of bce_per_element.
    mask = tf.cast(tf.not_equal(y_true, -1), tf.float32)  # (batch_size, max_segment_len)
    mask_expanded = tf.expand_dims(mask, -1)               # (batch_size, max_segment_len, 1)

    # Apply the mask: set the loss to 0 for padded elements.
    # Both bce_per_element and mask_expanded are (batch_size, max_segment_len, 1).
    masked_bce = bce_per_element * mask_expanded

    # Sum the masked loss and divide by the sum of the mask, which averages the loss
    # only over the non-padded elements. Add epsilon to prevent division by zero
    # in case a batch contains only padded elements.
    sum_mask = tf.reduce_sum(mask_expanded)
    return tf.reduce_sum(masked_bce) / (sum_mask + epsilon)
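

# --- Usage sketch (illustrative only, not part of the original code) ---
# A minimal example of how ExpandDimsLayer and masked_binary_crossentropy
# might be wired together, following the comment above about expanding the
# scores input before TimeDistributed(Dense). MAX_SEGMENT_LEN, the batch
# size, and the dummy data below are assumptions made for illustration.
if __name__ == "__main__":
    import numpy as np
    from tensorflow.keras.layers import Input, Dense, TimeDistributed
    from tensorflow.keras.models import Model

    MAX_SEGMENT_LEN = 10  # assumed maximum number of items per segment

    # Per-item scalar scores: (batch_size, MAX_SEGMENT_LEN)
    scores_in = Input(shape=(MAX_SEGMENT_LEN,), name="scores")

    # Add a feature dimension so TimeDistributed(Dense) sees (batch, time, features).
    x = ExpandDimsLayer(axis=-1)(scores_in)
    out = TimeDistributed(Dense(1, activation="sigmoid"))(x)  # (batch, MAX_SEGMENT_LEN, 1)

    model = Model(scores_in, out)
    model.compile(optimizer="adam", loss=masked_binary_crossentropy)

    # Dummy batch: labels padded with -1 beyond each segment's true length.
    X = np.random.rand(4, MAX_SEGMENT_LEN).astype("float32")
    y = np.random.randint(0, 2, size=(4, MAX_SEGMENT_LEN)).astype("float32")
    y[:, 7:] = -1.0  # padded positions are ignored by the masked loss
    model.fit(X, y, epochs=1, verbose=0)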