import keras
import numpy as np
import tensorflow as tf
from keras import ops
from keras.layers import StringLookup
from matplotlib import pyplot as plt

from src.config import config
from src.load import download_dataset, load_datasets
from src.model import build_model
from src.tools import split_data, decode_batch_predictions, preprocess_dataset


def main():
    print("Downloading dataset...")
    download_dataset()
    print("Dataset downloaded, preprocessing it...")
    preprocess_dataset()
    print("Dataset preprocessed")

    images, labels, characters = load_datasets()

    # Forward and inverse lookups between characters and integer ids,
    # as required for CTC training and decoding.
    char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
    num_to_char = StringLookup(
        vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
    )

    def encode_single_sample(img_path, label):
        # Read the image as grayscale, scale it to [0, 1], resize, and
        # transpose so the width axis becomes the time dimension of the RNN.
        img = tf.io.read_file(img_path)
        img = tf.io.decode_png(img, channels=1)
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = ops.image.resize(img, [config.IMAGE_HEIGHT, config.IMAGE_WIDTH])
        img = ops.transpose(img, axes=[1, 0, 2])
        label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
        return {"image": img, "label": label}

    max_length = max(len(label) for label in labels)

    x_train, y_train, x_valid, y_valid = split_data(np.array(images), np.array(labels))

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = (
        train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(config.BATCH_SIZE)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    validation_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
    validation_dataset = (
        validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(config.BATCH_SIZE)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    # Show one training batch to sanity-check the preprocessing.
    _, ax = plt.subplots(4, 4, figsize=(10, 5))
    for batch in train_dataset.take(1):
        batch_images = batch["image"]
        batch_labels = batch["label"]
        for i in range(16):
            img = (batch_images[i] * 255).numpy().astype("uint8")
            label = (
                tf.strings.reduce_join(num_to_char(batch_labels[i]))
                .numpy()
                .decode("utf-8")
            )
            ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap="gray")
            ax[i // 4, i % 4].set_title(label)
            ax[i // 4, i % 4].axis("off")
    plt.show()

    model = build_model(char_to_num)

    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )

    history = model.fit(
        train_dataset,
        validation_data=validation_dataset,
        epochs=config.EPOCHS,
        callbacks=[early_stopping],
    )

    # Optionally plot the training curves:
    # plt.plot(history.history["loss"])
    # plt.plot(history.history["val_loss"])
    # plt.title("model loss")
    # plt.ylabel("loss")
    # plt.xlabel("epoch")
    # plt.legend(["train", "val"], loc="upper left")
    # plt.show()

    # The training model ends in a CTC loss layer; for inference we only need
    # the softmax output of the last dense layer.
    prediction_model = keras.models.Model(
        model.input[0], model.get_layer(name="dense2").output
    )
    prediction_model.summary()

    # Decode one validation batch and visualize the predictions.
    for batch in validation_dataset.take(1):
        batch_images = batch["image"]
        batch_labels = batch["label"]

        preds = prediction_model.predict(batch_images)
        pred_texts = decode_batch_predictions(preds, max_length, num_to_char)

        orig_texts = []
        for label in batch_labels:
            label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
            orig_texts.append(label)

        _, ax = plt.subplots(4, 4, figsize=(15, 5))
        for i in range(len(pred_texts)):
            img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
            img = img.T
            title = f"Prediction: {pred_texts[i]}"
            ax[i // 4, i % 4].imshow(img, cmap="gray")
            ax[i // 4, i % 4].set_title(title)
            ax[i // 4, i % 4].axis("off")
    plt.show()

    model.save("captcha_solver.keras")
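
# Illustrative sketch only: the real `decode_batch_predictions` used above is
# imported from `src.tools`. This hypothetical stand-in shows the usual greedy
# CTC decoding step, assuming Keras 3's `keras.ops.nn.ctc_decode`; the name
# `_decode_batch_predictions_sketch` is not part of the project.
def _decode_batch_predictions_sketch(pred, max_length, num_to_char):
    # One sequence length per sample: the full time dimension of the output.
    input_len = np.full(pred.shape[0], pred.shape[1], dtype="int32")
    # Greedy search; strategy="beam_search" can help on harder captchas.
    decoded, _ = ops.nn.ctc_decode(pred, sequence_lengths=input_len, strategy="greedy")
    # `decoded` has shape (top_paths, batch, time); keep the best path and at
    # most `max_length` characters per sample.
    results = ops.convert_to_numpy(decoded[0][:, :max_length])
    output_text = []
    for res in results:
        res = res[res != -1]  # strip the -1 padding emitted by ctc_decode
        output_text.append(
            tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        )
    return output_text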
model.save("captcha_solver.keras") if __name__ == "__main__": main()