|
|
import numpy as np |
|
|
from tensorflow.keras.preprocessing.sequence import pad_sequences |
|
|
|
|
|
def predict_and_decode(model, transcriptions, max_segment_len, max_score_len, binary_threshold, label_mapping):
    """
    Run the model on a batch of transcriptions and decode per-segment labels.

    Each transcription is padded to a common (segments, scores) shape, the
    whole batch is sent through ``model.predict`` in one call, the raw outputs
    are thresholded, and the resulting integers are mapped back to labels.

    Args:
        model: Object exposing ``predict(batch)``. Its output is indexed as
            ``[transcription, segment, 0]``, so it must have at least three
            axes with the per-segment score at index 0 of the last one.
        transcriptions: A list of transcriptions, where each transcription is
            a list of score lists (one score list per segment).
            Example: [[[score1_seg1, score2_seg1], [score1_seg2, score2_seg2]], [[score1_segA]]]
        max_segment_len: Number of segments each transcription is zero-padded
            to. If a transcription in the batch has more segments than this,
            the batch is padded to the longest transcription instead, so the
            input array stays rectangular; nothing is truncated.
        max_score_len: Number of scores each segment is padded (at the end) or
            truncated to by ``pad_sequences``.
        binary_threshold: Outputs strictly greater than this value become 1,
            all others 0.
        label_mapping: Dictionary mapping labels to integers, e.g.
            ``{'H': 1, 'NH': 0}``. It is inverted to decode the predictions.

    Returns:
        A list with one inner list per input transcription, holding the
        decoded label for each of that transcription's real segments. Padded
        segments get no prediction. An empty ``transcriptions`` input yields
        an empty list without calling the model.

    Raises:
        KeyError: If a thresholded value (0 or 1) has no entry among the
            values of ``label_mapping``.
    """
    transcription_count = len(transcriptions)
    if transcription_count == 0:
        # Nothing to score; do not hand the model an empty, shapeless batch.
        return []

    # Real (unpadded) segment counts, needed later to strip padding again.
    segment_counts = [len(transcription) for transcription in transcriptions]

    # Normally this is just max_segment_len. Growing it to the longest
    # transcription keeps the batch rectangular when an input is longer than
    # expected, instead of building a ragged array.
    batch_segment_len = max(max_segment_len, max(segment_counts))

    # Zero-filled batch; rows and trailing scores that are never written act
    # as the padding.
    batch = np.zeros((transcription_count, batch_segment_len, max_score_len), dtype='float32')
    for row, (transcription, segment_count) in enumerate(zip(transcriptions, segment_counts)):
        if segment_count == 0:
            # No segments: the row stays all padding.
            continue
        batch[row, :segment_count] = pad_sequences(
            transcription, maxlen=max_score_len, padding='post', dtype='float32'
        )

    raw_predictions = np.asarray(model.predict(batch))
    binary_predictions = (raw_predictions > binary_threshold).astype(int)

    # Invert label -> integer so the 0/1 predictions can be turned into labels.
    int_to_label = {value: label for label, value in label_mapping.items()}

    # Keep only the predictions belonging to real segments of each transcription.
    return [
        [int_to_label[flag] for flag in binary_predictions[row, :segment_count, 0].tolist()]
        for row, segment_count in enumerate(segment_counts)
    ]