107 lines
3.6 KiB
Python
107 lines
3.6 KiB
Python
import keras
|
|
import tensorflow as tf
|
|
from keras.src import ops, layers
|
|
|
|
|
|
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """Compute the CTC loss for every sample of a batch.

    Args:
        y_true: Dense label tensor; turned into a sparse tensor with
            `ctc_label_dense_to_sparse` using `label_length`.
        y_pred: Prediction tensor. Its first two axes are swapped and the
            log is taken before it is handed to the loss (presumably
            `(batch, time, classes)` probabilities -- confirm with callers).
        input_length: Per-sample prediction lengths with a trailing axis
            that is squeezed away here.
        label_length: Per-sample label lengths with a trailing axis that is
            squeezed away here.

    Returns:
        The result of `tf.compat.v1.nn.ctc_loss` with an extra axis
        inserted at position 1.
    """
    # Remove the trailing axis so both length tensors become int32 vectors.
    target_lengths = ops.cast(ops.squeeze(label_length, axis=-1), dtype="int32")
    logit_lengths = ops.cast(ops.squeeze(input_length, axis=-1), dtype="int32")

    sparse_targets = ctc_label_dense_to_sparse(y_true, target_lengths)
    sparse_targets = ops.cast(sparse_targets, dtype="int32")

    # Swap the first two axes, then move to log space; epsilon keeps the
    # logarithm finite where the prediction is exactly zero.
    swapped = ops.transpose(y_pred, axes=[1, 0, 2])
    log_probs = ops.log(swapped + keras.backend.epsilon())

    per_sample_loss = tf.compat.v1.nn.ctc_loss(
        inputs=log_probs,
        labels=sparse_targets,
        sequence_length=logit_lengths,
    )
    return ops.expand_dims(per_sample_loss, 1)
|
|
|
|
|
|
def ctc_label_dense_to_sparse(labels, label_lengths):
    """Convert a dense 2-D label batch into a `tf.SparseTensor`.

    For row `i`, only the first `label_lengths[i]` entries are kept.

    Args:
        labels: Dense label tensor indexed as `[row, column]`.
        label_lengths: Tensor scanned element by element; each element is
            the number of leading entries to keep in the matching row.

    Returns:
        A `tf.SparseTensor` whose indices are int64 `[row, column]` pairs,
        whose values are gathered from `labels`, and whose dense shape is
        the shape of `labels`.
    """
    dense_shape = ops.shape(labels)
    row_count = dense_shape[0]
    col_count = dense_shape[1]
    row_multiples = ops.stack([row_count])
    col_multiples = ops.stack([col_count])

    def _row_mask(previous_mask, length):
        # `previous_mask` is only consulted for its width; the new mask is
        # True at every column position strictly below `length`.
        positions = ops.expand_dims(
            ops.arange(ops.shape(previous_mask)[1]), 0
        )
        return positions < tf.fill(col_multiples, length)

    # All-False seed of shape (1, columns) for the scan accumulator.
    seed_mask = ops.cast(tf.fill([1, col_count], 0), dtype="bool")
    keep_mask = tf.compat.v1.scan(
        _row_mask,
        label_lengths,
        initializer=seed_mask,
        parallel_iterations=1,
    )
    # The scan stacks (1, columns) masks; drop the singleton middle axis.
    keep_mask = keep_mask[:, 0, :]

    # Column index of every cell, then only those inside the valid prefix.
    col_grid = ops.reshape(
        ops.tile(ops.arange(0, col_count), row_multiples), dense_shape
    )
    col_ind = tf.compat.v1.boolean_mask(col_grid, keep_mask)

    # Row index of every cell, built column-major and transposed back.
    row_tiled = ops.tile(ops.arange(0, row_count), col_multiples)
    row_grid = ops.transpose(
        ops.reshape(row_tiled, tf.reverse(dense_shape, [0]))
    )
    row_ind = tf.compat.v1.boolean_mask(row_grid, keep_mask)

    # Pair rows with columns as an (N, 2) index matrix.
    stacked = ops.concatenate([row_ind, col_ind], axis=0)
    indices = ops.transpose(ops.reshape(stacked, [2, -1]))

    values = tf.compat.v1.gather_nd(labels, indices)

    return tf.SparseTensor(
        ops.cast(indices, dtype="int64"),
        values,
        ops.cast(dense_shape, dtype="int64"),
    )
|
|
|
|
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    """Decode predictions with greedy or beam-search CTC decoding.

    Args:
        y_pred: Prediction tensor. Its first two axes are swapped and the
            log is taken before decoding (presumably `(batch, time,
            classes)` probabilities -- confirm with callers).
        input_length: Per-sample sequence lengths; cast to int32 and passed
            to the decoder as `sequence_length`.
        greedy: When truthy use `tf.nn.ctc_greedy_decoder`, otherwise
            `tf.compat.v1.nn.ctc_beam_search_decoder`.
        beam_width: Forwarded to the beam-search decoder only.
        top_paths: Forwarded to the beam-search decoder only.

    Returns:
        A `(decoded_dense, log_prob)` tuple. `decoded_dense` is a list with
        one dense tensor per decoded path, given the dense shape
        `(num_samples, num_steps)` and filled with -1 where the decoder
        produced no value; `log_prob` is returned by the decoder unchanged.
    """
    pred_shape = ops.shape(y_pred)
    batch_size = pred_shape[0]
    max_steps = pred_shape[1]

    # Swap the first two axes and move to log space; epsilon avoids log(0).
    swapped = ops.transpose(y_pred, axes=[1, 0, 2])
    log_probs = ops.log(swapped + keras.backend.epsilon())
    seq_lengths = ops.cast(input_length, dtype="int32")

    if greedy:
        sparse_paths, log_prob = tf.nn.ctc_greedy_decoder(
            inputs=log_probs, sequence_length=seq_lengths
        )
    else:
        sparse_paths, log_prob = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=log_probs,
            sequence_length=seq_lengths,
            beam_width=beam_width,
            top_paths=top_paths,
        )

    # Rebuild each sparse result with a fixed (batch, steps) dense shape so
    # every path densifies to the same size.
    dense_paths = [
        tf.sparse.to_dense(
            sp_input=tf.SparseTensor(
                path.indices, path.values, (batch_size, max_steps)
            ),
            default_value=-1,
        )
        for path in sparse_paths
    ]
    return dense_paths, log_prob
|
|
|
|
|
|
@keras.saving.register_keras_serializable()
class CTCLayer(layers.Layer):
    """Layer that registers a CTC loss and passes predictions through.

    The loss is computed with `ctc_batch_cost` and attached via `add_loss`;
    the layer output is `y_pred` unchanged.
    """

    def __init__(self, name=None, **kwargs):
        """Create the layer.

        Args:
            name: Optional layer name.
            **kwargs: Standard `Layer` options (e.g. `trainable`, `dtype`)
                forwarded to the base class, so an instance can be rebuilt
                from the dictionary returned by `get_config`.
        """
        super().__init__(name=name, **kwargs)
        self.loss_fn = ctc_batch_cost

    def get_config(self):
        """Return the base layer config.

        The layer has no constructor arguments of its own, so the base
        config (which includes the name) fully describes it.
        """
        return super().get_config()

    def call(self, y_true, y_pred):
        """Attach the CTC loss for this batch and return `y_pred`.

        Args:
            y_true: Dense label tensor; axis 0 is used as the batch size
                and axis 1 as the label length of every sample.
            y_pred: Prediction tensor; axis 1 is used as the input length
                of every sample.

        Returns:
            `y_pred`, unchanged.
        """
        batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
        time_steps = ops.cast(ops.shape(y_pred)[1], dtype="int64")
        label_width = ops.cast(ops.shape(y_true)[1], dtype="int64")

        # Every sample gets the same lengths: the full axis-1 size of the
        # predictions and of the labels, as (batch, 1) int64 columns.
        input_length = time_steps * ops.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_width * ops.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        return y_pred