| import tensorflow as tf | |
| from tensorflow.keras.layers import Layer | |
# Adds a feature dimension to the scores input before TimeDistributed(Dense).
class ExpandDimsLayer(Layer):
    """Functional-API-friendly wrapper around ``tf.expand_dims``.

    Inserts a length-1 dimension into the input tensor at ``axis``
    (default: last), e.g. to give per-timestep scores a feature
    dimension before a TimeDistributed(Dense) layer.
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        # Insert the new size-1 axis at the configured position.
        return tf.expand_dims(inputs, self.axis)

    def get_config(self):
        # Include `axis` so the layer round-trips through
        # save/load serialization.
        base = super().get_config()
        return {**base, "axis": self.axis}
# Custom loss function for padded labels: masks out the loss
# contribution of positions whose label is the padding value (-1).
def masked_binary_crossentropy(y_true, y_pred):
    """Binary crossentropy that ignores padded positions.

    Padded positions are marked with a label of -1 in ``y_true`` and
    contribute nothing to the loss; the result is the mean element-wise
    BCE over non-padded positions only.

    Args:
        y_true: Labels of shape (batch, max_segment_len) or
            (batch, max_segment_len, 1); padded entries equal -1.
        y_pred: Sigmoid probabilities of shape (batch, max_segment_len, 1)
            (a rank-2 (batch, max_segment_len) input is expanded here).

    Returns:
        Scalar tensor: sum of masked BCE divided by the count of
        non-padded elements (epsilon-guarded for all-padded batches).
    """
    y_true = tf.cast(y_true, tf.float32)
    # Normalize both tensors to rank 3, (batch, max_segment_len, 1).
    # Previously a rank-3 y_true was expanded to rank 4 and silently
    # broadcast against the rank-3 predictions/mask, corrupting the loss.
    if y_true.shape.rank == 3:
        y_true = tf.squeeze(y_true, axis=-1)
    if y_pred.shape.rank == 2:
        y_pred = tf.expand_dims(y_pred, -1)
    y_true_reshaped = tf.expand_dims(y_true, -1)

    # Clip predictions away from 0 and 1 so the logs below stay finite.
    epsilon = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

    # Element-wise BCE. Padded labels (-1) yield finite junk values here;
    # the mask below zeroes them out before reduction.
    bce_per_element = -(y_true_reshaped * tf.math.log(y_pred) +
                        (1.0 - y_true_reshaped) * tf.math.log(1.0 - y_pred))

    # 1.0 for real labels, 0.0 for padding; expanded to match bce shape.
    mask = tf.cast(tf.not_equal(y_true, -1.0), tf.float32)
    mask_expanded = tf.expand_dims(mask, -1)
    masked_bce = bce_per_element * mask_expanded

    # Average only over non-padded elements; epsilon prevents division
    # by zero when a batch is entirely padding.
    sum_mask = tf.reduce_sum(mask_expanded)
    return tf.reduce_sum(masked_bce) / (sum_mask + epsilon)