File size: 2,526 Bytes
a29db06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import tensorflow as tf
from tensorflow.keras.layers import Layer


# This is necessary to add a feature dimension to the scores input before TimeDistributed(Dense)
class ExpandDimsLayer(Layer):
  """Keras layer that inserts a length-1 dimension into its input.

  Thin wrapper around ``tf.expand_dims`` so the operation can be used as a
  regular layer when building a model with the Functional API.

  Args:
    axis: Position at which the new dimension is inserted. Defaults to -1
      (append as the last dimension).
    **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
  """

  def __init__(self, axis=-1, **kwargs):
    super().__init__(**kwargs)
    # Stored as a plain attribute so get_config() can report it.
    self.axis = axis

  def call(self, inputs):
    """Return ``inputs`` with a new length-1 dimension at ``self.axis``."""
    expanded = tf.expand_dims(inputs, axis=self.axis)
    return expanded

  def get_config(self):
    """Return the base layer config extended with this layer's ``axis``."""
    layer_config = super().get_config()
    layer_config["axis"] = self.axis
    return layer_config


# Custom Loss Function to Handle Padded Labels
# This function masks out the loss contribution from padded labels (-1).
def masked_binary_crossentropy(y_true, y_pred):
  """Binary cross-entropy averaged over the non-padded positions only.

  Positions whose label equals -1 are treated as padding: they contribute
  nothing to the summed loss and are not counted in the denominator.

  Args:
    y_true: Labels, with -1 marking padded positions. Usually shaped
      (batch_size, max_segment_len); labels that already carry the same
      rank as ``y_pred`` (e.g. a trailing singleton axis) are accepted too.
    y_pred: Predicted probabilities. Usually shaped
      (batch_size, max_segment_len, 1) from the model's output layer, or
      (batch_size, max_segment_len) if the last axis was squeezed away
      before the loss is called.

  Returns:
    A scalar tensor with the dtype of ``y_pred``: the summed cross-entropy
    of the non-padded positions divided by their count, or 0 when every
    position is padding.
  """
  y_pred = tf.convert_to_tensor(y_pred)
  # Compute in the prediction dtype rather than a hard-coded float32 so that
  # float64 / float16 model outputs do not raise a dtype-mismatch error.
  labels = tf.cast(y_true, y_pred.dtype)

  # Bring labels and predictions to the same rank. Only the side that is
  # exactly one axis short gets a trailing singleton axis; unconditionally
  # expanding the labels would turn (batch, len, 1) labels into rank 4 and
  # broadcast incorrectly against rank-3 predictions.
  label_rank = labels.shape.rank
  pred_rank = y_pred.shape.rank
  if label_rank is not None and pred_rank is not None:
    if pred_rank - label_rank == 1:
      labels = tf.expand_dims(labels, -1)
    elif label_rank - pred_rank == 1:
      y_pred = tf.expand_dims(y_pred, -1)

  # Clip probabilities away from 0 and 1 so neither log() below can produce
  # an infinite value.
  epsilon = tf.keras.backend.epsilon()
  probs = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

  # Element-wise binary cross-entropy. The value computed for padded
  # positions (label -1) is meaningless but finite, and is zeroed by the mask.
  bce_per_element = -(labels * tf.math.log(probs) +
                      (1.0 - labels) * tf.math.log(1.0 - probs))

  # 1 for real labels, 0 for padding; same shape as `labels`.
  mask = tf.cast(tf.not_equal(labels, -1), probs.dtype)

  masked_total = tf.reduce_sum(bce_per_element * mask)
  valid_count = tf.reduce_sum(mask)

  # divide_no_nan yields the exact mean over valid positions and returns 0
  # (instead of NaN) for a batch that consists solely of padding.
  return tf.math.divide_no_nan(masked_total, valid_count)