71 lines
2.0 KiB
Python
71 lines
2.0 KiB
Python
import math
|
|
import keras
|
|
|
|
from keras.src import layers
|
|
|
|
from src.config import config
|
|
from src.ctc import CTCLayer
|
|
|
|
|
|
def build_model(char_to_num, *, learning_rate=0.001):
    """Build and compile the CRNN + CTC captcha-solving model.

    Architecture: a stack of Conv2D + 2x2 MaxPooling2D stages, a reshape that
    turns the feature map into a sequence along the first spatial axis
    (``config.IMAGE_WIDTH``), a dense projection with dropout, two
    bidirectional LSTMs, and a per-step softmax classifier whose output is
    passed to ``CTCLayer`` together with the label input.

    Args:
        char_to_num: Lookup object whose ``get_vocabulary()`` defines the
            character set. The classifier has ``len(vocabulary) + 1`` units;
            the extra unit is presumably the CTC blank — confirm against
            ``CTCLayer``.
        learning_rate: Learning rate for the Adam optimizer. Defaults to
            ``0.001``, the value previously hard-coded.

    Returns:
        A compiled ``keras.Model`` named ``"captcha_solver"`` with inputs
        ``[image, label]`` and the ``CTCLayer`` output as its output.
        ``compile`` is called without a ``loss`` argument; the loss is
        presumably registered inside ``CTCLayer`` — verify in ``src.ctc``.
    """
    # Filter counts of the convolutional stages, in order. Each stage ends in
    # a 2x2 "valid" max-pool, which floors each spatial dim by half, so after
    # k stages a dim D becomes D // 2**k. The reshape below is derived from
    # this tuple instead of hard-coding the last filter count, so adding a
    # stage (e.g. 128 filters) no longer requires hand-editing magic numbers.
    conv_stages = (32, 64)
    downsample = 2 ** len(conv_stages)

    input_img = layers.Input(
        shape=(config.IMAGE_WIDTH, config.IMAGE_HEIGHT, 1), name="image", dtype="float32"
    )
    labels = layers.Input(name="label", shape=(None,), dtype="float32")

    # Conv front end; layer names stay Conv1/Pool1, Conv2/Pool2, ... so that
    # saved weights keyed by layer name remain compatible.
    x = input_img
    for stage, filters in enumerate(conv_stages, start=1):
        x = layers.Conv2D(
            filters,
            (3, 3),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
            name=f"Conv{stage}",
        )(x)
        x = layers.MaxPooling2D((2, 2), name=f"Pool{stage}")(x)

    # Collapse (height, channels) into one feature axis so that the
    # downsampled width axis becomes the sequence (time) axis for the RNNs.
    new_shape = (
        config.IMAGE_WIDTH // downsample,
        (config.IMAGE_HEIGHT // downsample) * conv_stages[-1],
    )
    x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = layers.Dense(64, activation="relu", name="dense1")(x)
    x = layers.Dropout(0.2)(x)

    # RNNs
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)

    # Output layer: one softmax distribution per time step.
    x = layers.Dense(
        len(char_to_num.get_vocabulary()) + 1, activation="softmax", name="dense2"
    )(x)

    # Add CTC layer for calculating CTC loss at each step
    output = CTCLayer(name="ctc_loss")(labels, x)

    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="captcha_solver"
    )
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate))
    return model
|