# whisper-hallucination-detection / custom_objects.py
# Author: avcton
# Initial upload: model + custom objects (commit a29db06)
import tensorflow as tf
from tensorflow.keras.layers import Layer
# This is necessary to add a feature dimension to the scores input before TimeDistributed(Dense)
class ExpandDimsLayer(Layer):
    """Functional-API-friendly wrapper around ``tf.expand_dims``.

    Inserts a new size-1 dimension into the input tensor at the configured
    ``axis``. Needed because raw TF ops cannot be used directly as layers
    in the Keras Functional API (e.g. to add a feature dimension to the
    scores input before ``TimeDistributed(Dense)``).
    """

    def __init__(self, axis=-1, **kwargs):
        # axis: position at which the new size-1 dimension is inserted
        # (default -1 appends a trailing feature axis).
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        # Delegate straight to tf.expand_dims with the stored axis.
        return tf.expand_dims(inputs, self.axis)

    def get_config(self):
        # Report `axis` alongside the base config so the layer can be
        # serialized and re-instantiated (e.g. when loading a saved model).
        base = super().get_config()
        return {**base, "axis": self.axis}
# Custom Loss Function to Handle Padded Labels
# This function masks out the loss contribution from padded labels (-1).
def masked_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy that ignores padded positions in ``y_true``.

    Positions where ``y_true == -1`` are treated as padding and contribute
    nothing to the loss; the result is the mean BCE over the remaining
    (valid) elements only.

    Args:
        y_true: labels, shape (batch_size, max_segment_len); padded
            entries are marked with -1.
        y_pred: predicted probabilities, shape
            (batch_size, max_segment_len, 1) — or (batch_size,
            max_segment_len) if the feature axis was squeezed.

    Returns:
        Scalar tensor: masked mean binary cross-entropy.
    """
    eps = tf.keras.backend.epsilon()

    # Restore the trailing feature axis on predictions if it was squeezed,
    # so predictions and labels line up as (batch, seq_len, 1).
    if len(y_pred.shape) == 2:
        y_pred = tf.expand_dims(y_pred, -1)
    labels = tf.expand_dims(tf.cast(y_true, tf.float32), -1)

    # Clip probabilities away from 0 and 1 so the logs below stay finite.
    preds = tf.clip_by_value(y_pred, eps, 1.0 - eps)

    # Element-wise binary cross-entropy, computed manually.
    per_element = -(labels * tf.math.log(preds)
                    + (1.0 - labels) * tf.math.log(1.0 - preds))

    # Validity mask from the ORIGINAL (un-expanded) labels: 1.0 for real
    # labels, 0.0 for the -1 padding sentinel; expanded to match per_element.
    valid = tf.expand_dims(tf.cast(tf.not_equal(y_true, -1), tf.float32), -1)

    # Average only over valid positions; the added eps guards against a
    # division by zero when a batch happens to be entirely padding.
    total_loss = tf.reduce_sum(per_element * valid)
    return total_loss / (tf.reduce_sum(valid) + eps)