120 lines
4.1 KiB
Python
120 lines
4.1 KiB
Python
import keras
|
|
from keras.src.layers import StringLookup
|
|
import tensorflow as tf
|
|
from keras.src import ops
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
from src.load import download_dataset, load_datasets
|
|
from src.config import config
|
|
from src.model import build_model
|
|
from src.tools import split_data, decode_batch_predictions, preprocess_dataset
|
|
|
|
|
|
def main():
    """Train the captcha OCR model end to end and save it to disk.

    Steps, in order: download and preprocess the dataset, build the
    character lookup layers, assemble the training and validation
    ``tf.data`` pipelines, preview one training batch, fit the model with
    early stopping, preview the predictions for one validation batch, and
    save the trained model to ``captcha_solver.keras``.

    Both preview steps call ``plt.show()``.
    """
    # Both preview figures are laid out on the same fixed grid; at most
    # ``grid_cells`` samples of a batch are drawn, whatever the batch size.
    grid_rows, grid_cols = 4, 4
    grid_cells = grid_rows * grid_cols

    print("Downloading dataset...")
    download_dataset()
    print("Dataset downloaded, preprocessing it...")
    preprocess_dataset()
    print("Dataset preprocessed")

    images, labels, characters = load_datasets()

    # Forward (character -> index) and inverse (index -> character) lookups
    # that share one vocabulary.
    char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
    num_to_char = StringLookup(
        vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
    )

    def encode_single_sample(img_path, label):
        """Load one PNG and encode its label for the model.

        Args:
            img_path: Path of the PNG file to read.
            label: Text of the captcha shown in the image.

        Returns:
            A dict with the preprocessed image under ``"image"`` and the
            label encoded through ``char_to_num`` under ``"label"``.
        """
        imag = tf.io.read_file(img_path)
        imag = tf.io.decode_png(imag, channels=1)  # single grayscale channel
        imag = tf.image.convert_image_dtype(imag, tf.float32)
        imag = ops.image.resize(imag, [config.IMAGE_HEIGHT, config.IMAGE_WIDTH])
        # Swap the first two axes; the previews below transpose back before
        # drawing. NOTE(review): presumably the model expects the image
        # width as its leading (time) axis -- confirm against build_model.
        imag = ops.transpose(imag, axes=[1, 0, 2])
        label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
        return {"image": imag, "label": label}

    def make_dataset(paths, texts):
        """Build the batched, prefetched pipeline for one data split."""
        return (
            tf.data.Dataset.from_tensor_slices((paths, texts))
            .map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(config.BATCH_SIZE)
            .prefetch(buffer_size=tf.data.AUTOTUNE)
        )

    # Longest label in the dataset; handed to decode_batch_predictions below.
    max_length = max(len(label) for label in labels)

    x_train, y_train, x_valid, y_valid = split_data(np.array(images), np.array(labels))
    train_dataset = make_dataset(x_train, y_train)
    validation_dataset = make_dataset(x_valid, y_valid)

    # Preview the first training batch with its decoded labels.
    _, ax = plt.subplots(grid_rows, grid_cols, figsize=(10, 5))
    for batch in train_dataset.take(1):
        sample_images = batch["image"]
        sample_labels = batch["label"]
        # Never index past the batch: it may hold fewer samples than the grid.
        shown = min(grid_cells, int(sample_images.shape[0]))
        for i in range(shown):
            img = (sample_images[i] * 255).numpy().astype("uint8")
            text = (
                tf.strings.reduce_join(num_to_char(sample_labels[i]))
                .numpy()
                .decode("utf-8")
            )
            cell = ax[i // grid_cols, i % grid_cols]
            cell.imshow(img[:, :, 0].T, cmap="gray")
            cell.set_title(text)
            cell.axis("off")
    plt.show()

    model = build_model(char_to_num)

    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )

    model.fit(
        train_dataset,
        validation_data=validation_dataset,
        epochs=config.EPOCHS,
        callbacks=[early_stopping],
    )

    # Inference model: from the first model input to the "dense2" layer output.
    prediction_model = keras.models.Model(
        model.input[0], model.get_layer(name="dense2").output
    )
    prediction_model.summary()

    # Preview the predictions for the first validation batch.
    for batch in validation_dataset.take(1):
        batch_images = batch["image"]

        preds = prediction_model.predict(batch_images)
        pred_texts = decode_batch_predictions(preds, max_length, num_to_char)

        _, ax = plt.subplots(grid_rows, grid_cols, figsize=(15, 5))
        # zip with range() caps the loop at the number of grid cells, so a
        # batch larger than the grid cannot index past the axes array.
        for i, pred_text in zip(range(grid_cells), pred_texts):
            img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
            img = img.T
            title = f"Prediction: {pred_text}"
            cell = ax[i // grid_cols, i % grid_cols]
            cell.imshow(img, cmap="gray")
            cell.set_title(title)
            cell.axis("off")
    plt.show()

    model.save("captcha_solver.keras")
|
|
|
|
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()