Files
captcha_solver/src/ctc.py
2025-06-23 23:45:01 +03:00

107 lines
3.6 KiB
Python

import keras
import tensorflow as tf
from keras.src import ops, layers
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """Return the CTC loss of every sample in a batch.

    Keras-ops wrapper around ``tf.compat.v1.nn.ctc_loss``.

    Args:
        y_true: Dense label tensor. It is converted to a ``tf.SparseTensor``
            by ``ctc_label_dense_to_sparse`` before being handed to TensorFlow.
        y_pred: Rank-3 tensor of per-step class scores. It is transposed with
            ``axes=[1, 0, 2]`` into the time-major layout ``ctc_loss`` expects,
            so it is presumably (batch, time, classes) probabilities, e.g. a
            softmax output. TODO: confirm against the model.
        input_length: Per-sample number of valid time steps. Its trailing axis
            is squeezed away, so a (batch, 1) layout is expected.
        label_length: Per-sample number of valid labels, same layout as
            ``input_length``.

    Returns:
        The per-sample CTC losses with an extra trailing axis added by
        ``expand_dims(..., 1)``, i.e. shape (batch, 1).
    """
    # The TF op wants 1-D int32 length vectors.
    seq_lens = ops.cast(ops.squeeze(input_length, axis=-1), dtype="int32")
    target_lens = ops.cast(ops.squeeze(label_length, axis=-1), dtype="int32")

    targets = ctc_label_dense_to_sparse(y_true, target_lens)
    targets = ops.cast(targets, dtype="int32")

    # The epsilon keeps log() finite where a predicted probability is exactly 0.
    time_major = ops.transpose(y_pred, axes=[1, 0, 2])
    log_probs = ops.log(time_major + keras.backend.epsilon())

    per_sample = tf.compat.v1.nn.ctc_loss(
        labels=targets,
        inputs=log_probs,
        sequence_length=seq_lens,
    )
    # Restore a trailing axis so callers receive (batch, 1).
    return ops.expand_dims(per_sample, 1)
def ctc_label_dense_to_sparse(labels, label_lengths):
    """Convert a dense, padded label matrix into a ``tf.SparseTensor``.

    For each row ``b``, only the first ``label_lengths[b]`` entries of
    ``labels[b]`` are kept. The remaining entries are dropped as padding.

    Args:
        labels: Rank-2 dense label tensor. Its shape becomes the
            SparseTensor's ``dense_shape``; rows index samples and columns
            index label positions.
        label_lengths: 1-D tensor with one length per row of ``labels``
            (it is scanned element by element). It is compared against
            ``ops.arange`` output, so its integer dtype presumably has to
            match that. ``ctc_batch_cost`` passes int32.

    Returns:
        ``tf.SparseTensor`` with:
            - int64 ``indices``: (row, column) pairs of the kept entries;
            - values gathered from ``labels``;
            - ``dense_shape``: ``labels``' shape cast to int64.
    """
    label_shape = ops.shape(labels)
    # 1-element shape vectors, used below as tile multiples and fill dims.
    num_batches_tns = ops.stack([label_shape[0]])
    max_num_labels_tns = ops.stack([label_shape[1]])

    # Scan step: return a (1, max_len) boolean row that is True at every
    # column index < current_input (the current row's label length).
    # old_input (the previous mask row) is only read for its width.
    def range_less_than(old_input, current_input):
        return ops.expand_dims(ops.arange(ops.shape(old_input)[1]), 0) < tf.fill(
            max_num_labels_tns, current_input
        )

    # All-False (1, max_len) seed for the scan.
    init = ops.cast(tf.fill([1, label_shape[1]], 0), dtype="bool")
    # One mask row per entry of label_lengths, stacked -> (batch, 1, max_len).
    dense_mask = tf.compat.v1.scan(
        range_less_than, label_lengths, initializer=init, parallel_iterations=1
    )
    # Drop the singleton axis -> (batch, max_len).
    dense_mask = dense_mask[:, 0, :]

    # (batch, max_len) matrix whose every row is 0..max_len-1, so each entry
    # holds its own column index. Masking keeps the columns of real labels.
    label_array = ops.reshape(
        ops.tile(ops.arange(0, label_shape[1]), num_batches_tns), label_shape
    )
    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)

    # Tile 0..batch-1 into a (max_len, batch) matrix, then transpose to
    # (batch, max_len), so each entry holds its own row index.
    batch_array = ops.transpose(
        ops.reshape(
            ops.tile(ops.arange(0, label_shape[0]), max_num_labels_tns),
            tf.reverse(label_shape, [0]),
        )
    )
    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)

    # Concatenate [rows..., cols...], view as (2, nnz), transpose -> (nnz, 2)
    # (row, column) index pairs in the same order boolean_mask produced them.
    indices = ops.transpose(
        ops.reshape(ops.concatenate([batch_ind, label_ind], axis=0), [2, -1])
    )
    vals_sparse = tf.compat.v1.gather_nd(labels, indices)
    return tf.SparseTensor(
        ops.cast(indices, dtype="int64"),
        vals_sparse,
        ops.cast(label_shape, dtype="int64")
    )
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    """Decode per-step class probabilities into label sequences with CTC.

    Args:
        y_pred: Rank-3 tensor of per-step class probabilities, presumably
            (batch, time, classes) softmax output. It is transposed with
            ``axes=[1, 0, 2]`` into the time-major layout the TF decoders use.
        input_length: Number of valid time steps per sample. Accepts either a
            1-D ``(batch,)`` tensor or the ``(batch, 1)`` layout used by
            ``ctc_batch_cost`` / ``CTCLayer``. It is flattened to 1-D.
        greedy: If True, use ``tf.nn.ctc_greedy_decoder``; otherwise use
            ``tf.compat.v1.nn.ctc_beam_search_decoder``.
        beam_width: Beam width, used only for beam search.
        top_paths: Number of best paths to return, used only for beam search.

    Returns:
        ``(decoded_dense, log_prob)``:
            - ``decoded_dense``: list with one dense tensor of shape
              ``(batch, time)`` per decoded path, padded with -1;
            - ``log_prob``: the sequence log-probability tensor returned by
              the chosen TF decoder.
    """
    input_shape = ops.shape(y_pred)
    num_samples, num_steps = input_shape[0], input_shape[1]
    # Epsilon avoids log(0) for classes predicted with exactly zero probability.
    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
    # The TF decoders require a 1-D sequence_length vector. Flattening lets the
    # (batch, 1) layout used elsewhere in this module work as well as (batch,).
    input_length = ops.reshape(ops.cast(input_length, dtype="int32"), [-1])
    if greedy:
        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
            inputs=y_pred, sequence_length=input_length
        )
    else:
        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=y_pred,
            sequence_length=input_length,
            beam_width=beam_width,
            top_paths=top_paths,
        )
    decoded_dense = []
    for st in decoded:
        # Re-wrap with a fixed (batch, time) dense shape so every path
        # densifies to the same size; unused slots become -1.
        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
    return decoded_dense, log_prob
@keras.saving.register_keras_serializable()
class CTCLayer(layers.Layer):
    """Pass-through layer that attaches the CTC loss via ``add_loss``.

    ``call`` computes ``ctc_batch_cost`` between the labels and the
    predictions, registers it with ``self.add_loss`` and returns ``y_pred``
    unchanged.
    """

    def __init__(self, name=None, **kwargs):
        # Forward standard Layer kwargs (trainable, dtype, ...). This lets the
        # layer be constructed with them and rebuilt from the full base config
        # that get_config() returns.
        super().__init__(name=name, **kwargs)
        self.loss_fn = ctc_batch_cost

    def get_config(self):
        """Return the base Layer config (name, trainable, dtype, ...).

        The layer has no extra constructor arguments, so the base config is
        sufficient to recreate it.
        """
        return super().get_config()

    def call(self, y_true, y_pred):
        """Register the CTC loss of ``y_pred`` against ``y_true``.

        Args:
            y_true: Dense labels. Axis 0 is read as the batch size and axis 1
                as the label length.
            y_pred: Predictions. Axis 1 is read as the number of time steps.

        Returns:
            ``y_pred`` unchanged.
        """
        batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
        input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
        label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")
        # Every sample is assigned the full time axis and the full label width
        # of y_true. Variable-length (padded) labels are not distinguished here.
        input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)
        return y_pred