# captcha_solver/src/__main__.py
#
# 120 lines
# 4.1 KiB
# Python

import keras
from keras.src.layers import StringLookup
import tensorflow as tf
from keras.src import ops
import numpy as np
from matplotlib import pyplot as plt
from src.load import download_dataset, load_datasets
from src.config import config
from src.model import build_model
from src.tools import split_data, decode_batch_predictions, preprocess_dataset
def main():
    """Train the captcha recognition model end to end and save it.

    Steps, in order:

    1. Download and preprocess the dataset.
    2. Build the character <-> integer lookup layers from the dataset's
       character set.
    3. Assemble batched training and validation ``tf.data`` pipelines.
    4. Show a preview grid of one training batch.
    5. Fit the model with early stopping on ``val_loss``.
    6. Show a preview grid of predictions for one validation batch.
    7. Save the trained model to ``captcha_solver.keras``.

    Takes no arguments and returns ``None``; all configuration is read from
    ``config``.
    """
    print("Downloading dataset...")
    download_dataset()
    print("Dataset downloaded, preprocessing it...")
    preprocess_dataset()
    print("Dataset preprocessed")

    images, labels, characters = load_datasets()

    # The inverse lookup reuses the forward lookup's vocabulary so both
    # layers agree on the character <-> index mapping.
    char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
    num_to_char = StringLookup(
        vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
    )

    def encode_single_sample(img_path, label):
        """Read one PNG file and encode its label as integer indices.

        Returns a dict with keys ``"image"`` (float32 grayscale image,
        resized to the configured height/width and then transposed) and
        ``"label"`` (per-character indices produced by ``char_to_num``).
        """
        imag = tf.io.read_file(img_path)
        imag = tf.io.decode_png(imag, channels=1)
        imag = tf.image.convert_image_dtype(imag, tf.float32)
        imag = ops.image.resize(imag, [config.IMAGE_HEIGHT, config.IMAGE_WIDTH])
        # Swap the first two axes (height <-> width); presumably the model
        # reads the image along its width -- confirm against build_model.
        imag = ops.transpose(imag, axes=[1, 0, 2])
        label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
        return {"image": imag, "label": label}

    def make_dataset(paths, texts):
        """Build the encoded, batched, prefetched pipeline for one split."""
        return (
            tf.data.Dataset.from_tensor_slices((paths, texts))
            .map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(config.BATCH_SIZE)
            .prefetch(buffer_size=tf.data.AUTOTUNE)
        )

    # Longest label in the whole dataset; passed to the decoder further down.
    max_length = max(len(label) for label in labels)
    x_train, y_train, x_valid, y_valid = split_data(np.array(images), np.array(labels))
    train_dataset = make_dataset(x_train, y_train)
    validation_dataset = make_dataset(x_valid, y_valid)

    # Both preview figures use a fixed 4x4 grid, so never draw more tiles
    # than the grid holds nor more than the batch actually contains.
    grid_rows, grid_cols = 4, 4
    max_tiles = grid_rows * grid_cols

    # Preview a handful of training samples with their decoded labels.
    _, ax = plt.subplots(grid_rows, grid_cols, figsize=(10, 5))
    for batch in train_dataset.take(1):
        sample_images = batch["image"]
        sample_labels = batch["label"]
        for i in range(min(max_tiles, int(sample_images.shape[0]))):
            img = (sample_images[i] * 255).numpy().astype("uint8")
            label = (
                tf.strings.reduce_join(num_to_char(sample_labels[i]))
                .numpy()
                .decode("utf-8")
            )
            row, col = divmod(i, grid_cols)
            # Undo the transpose applied in encode_single_sample for display.
            ax[row, col].imshow(img[:, :, 0].T, cmap="gray")
            ax[row, col].set_title(label)
            ax[row, col].axis("off")
    plt.show()

    model = build_model(char_to_num)
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # NOTE: fit() returns a History object; capture it here if a
    # loss / val_loss curve should be plotted.
    model.fit(
        train_dataset,
        validation_data=validation_dataset,
        epochs=config.EPOCHS,
        callbacks=[early_stopping],
    )

    # Inference model: first model input -> output of the "dense2" layer.
    # NOTE(review): assumes model.input[0] is the image input and "dense2"
    # is the per-timestep character distribution -- confirm in build_model.
    prediction_model = keras.models.Model(
        model.input[0], model.get_layer(name="dense2").output
    )
    prediction_model.summary()

    # Preview decoded predictions for one validation batch.
    for batch in validation_dataset.take(1):
        batch_images = batch["image"]
        preds = prediction_model.predict(batch_images)
        pred_texts = decode_batch_predictions(preds, max_length, num_to_char)

        _, ax = plt.subplots(grid_rows, grid_cols, figsize=(15, 5))
        for i in range(min(max_tiles, len(pred_texts))):
            img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
            img = img.T
            title = f"Prediction: {pred_texts[i]}"
            row, col = divmod(i, grid_cols)
            ax[row, col].imshow(img, cmap="gray")
            ax[row, col].set_title(title)
            ax[row, col].axis("off")
    plt.show()

    model.save("captcha_solver.keras")
# Run the training pipeline only when this file is executed as a script or
# via ``python -m``, not when it is imported.
if __name__ == "__main__":
    main()