From 9f153eae91a9b8d2af2780275776a7d2e8a83752 Mon Sep 17 00:00:00 2001
From: leca
Date: Sun, 4 May 2025 15:48:53 +0300
Subject: [PATCH] added training

---
 main.py          | 61 ++++++++++++++++++++++++++++++++++++++++++------
 requirements.txt | 32 +++++++++++++++++++++++++
 2 files changed, 86 insertions(+), 7 deletions(-)

diff --git a/main.py b/main.py
index 031c908..1609a12 100644
--- a/main.py
+++ b/main.py
@@ -4,6 +4,10 @@
 from dotenv import load_dotenv
 from base64 import b64decode
 import re
 import requests
+import tf2onnx
+import cv2
+import keras
+import numpy as np
 
 load_dotenv()
@@ -17,14 +21,14 @@ def prepare_dirs():
     makedirs(TRAINING_PATH, exist_ok=True)
 
 def fetch_captcha(id):
-    print(f"Fetching captcha with id {id}")
+    # print(f"Fetching captcha with id {id}")
     captcha = requests.get(f"{environ.get('CAPTCHA_AGGREGATOR_API')}/captcha/{id}").json()["captcha"]
 
     with open(f"{DOWNLOAD_PATH}/{captcha['hash']}_{captcha['solution']}.jpeg", 'wb') as captcha_file:
         captcha_file.write(b64decode(captcha['image']))
 
 def search_saved_captcha(hash, path):
-    print(f"searching captcha with hash {hash} in {path}")
+    # print(f"searching captcha with hash {hash} in {path}")
     regex = re.compile(hash + '_\\w{6}\\.jpeg')
 
     for _, _, files in walk(path):
@@ -34,7 +38,7 @@
     return False
 
 def search_and_download_new(captchas):
-    print(f"Searching and downloading new captchas")
+    # print(f"Searching and downloading new captchas")
     for captcha in captchas:
         id = captcha["id"]
         hash = captcha["hash"]
@@ -45,12 +49,10 @@
             fetch_captcha(id)
 
 def sort_datasets():
-    print(f"Sorting datasets")
+    # print(f"Sorting datasets")
     percent_of_testing = int(environ.get("PERCENT_OF_TESTING"))
     amount_of_new_data = len([file for file in listdir(DOWNLOAD_PATH) if path.isfile(f'{DOWNLOAD_PATH}/{file}')])
-    print(amount_of_new_data)
     amount_to_send_to_test = round(amount_of_new_data * (percent_of_testing / 100))
-    print(amount_to_send_to_test)
     for _, _, files in walk(DOWNLOAD_PATH):
         for index, file in enumerate(files):
             if index < amount_to_send_to_test:
@@ -66,10 +68,55 @@ def download_dataset():
     search_and_download_new(captchas)
     sort_datasets()
 
 
+def load_dataset(dataset_path):
+    images = []
+    solutions = []
+    for filename in listdir(dataset_path):
+        img = cv2.imread(f"{dataset_path}/{filename}")
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        img = img / 255.0
+        images.append(img)
+        solution = path.splitext(filename)[0].split('_')[1]
+        solutions.append(solution)
+    unique_solutions = sorted(set(solutions))
+    solution_to_label = {solution: i for i, solution in enumerate(unique_solutions)}
+    labels = [solution_to_label[solution] for solution in solutions]
+
+    return images, labels, unique_solutions
+
+def load_training_dataset():
+    return load_dataset(TRAINING_PATH)
+
+def load_testing_dataset():
+    return load_dataset(TESTING_PATH)
 
 def train_nn():
-    pass
+    training_images, training_labels, unique_solutions = load_training_dataset()
+    if int(environ.get("PERCENT_OF_TESTING")) > 0:
+        testing_images, testing_labels, _ = load_testing_dataset()
+
+    model = keras.Sequential([
+        keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(70, 200, 1)),
+        keras.layers.MaxPooling2D((2, 2)),
+        keras.layers.Conv2D(64, (3, 3), activation='relu'),
+        keras.layers.MaxPooling2D((2, 2)),
+        keras.layers.Conv2D(64, (3, 3), activation='relu'),
+        keras.layers.Flatten(),
+        keras.layers.Dense(64, activation='relu'),
+        keras.layers.Dense(len(unique_solutions), activation='softmax')
+    ])
+
+    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+    if int(environ.get("PERCENT_OF_TESTING")) > 0:
+        model.fit(np.array(training_images), np.array(training_labels), epochs=10, batch_size=128, validation_data=(np.array(testing_images), np.array(testing_labels)))
+    else:
+        model.fit(np.array(training_images), np.array(training_labels), epochs=10, batch_size=128)
+
+    keras.saving.save_model(model, 'captcha_solver.keras')
+    # model.save('model.h5')
+    # tf2onnx.convert.from_keras(model, opset=13, output_path='model_onnx')
+
 
 if __name__ == "__main__":
     download_dataset()
diff --git a/requirements.txt b/requirements.txt
index 5051e6d..25b3c3e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,39 @@
+absl-py==2.2.2
+astunparse==1.6.3
 certifi==2025.4.26
 charset-normalizer==3.4.2
 dotenv==0.9.9
+flatbuffers==25.2.10
+gast==0.6.0
+google-pasta==0.2.0
+grpcio==1.71.0
+h5py==3.13.0
 idna==3.10
+keras==3.9.2
+libclang==18.1.1
+Markdown==3.8
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+mdurl==0.1.2
+ml_dtypes==0.5.1
+namex==0.0.9
+numpy==2.1.3
+opt_einsum==3.4.0
+optree==0.15.0
+packaging==25.0
+protobuf==5.29.4
+Pygments==2.19.1
 python-dotenv==1.1.0
 requests==2.32.3
+rich==14.0.0
+setuptools==80.3.0
+six==1.17.0
+tensorboard==2.19.0
+tensorboard-data-server==0.7.2
+tensorflow==2.19.0
+termcolor==3.1.0
+typing_extensions==4.13.2
 urllib3==2.4.0
+Werkzeug==3.1.3
+wheel==0.45.1
+wrapt==1.17.2
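
A caveat on the training code above, with a minimal sketch that is not part of the patch: load_dataset() returns grayscale arrays of shape (70, 200), so np.array(training_images) is a rank-3 batch of shape (N, 70, 200), while the first Conv2D layer declares input_shape=(70, 200, 1) and therefore expects rank-4 input; Keras will reject the rank mismatch at model.fit. Assuming the 70x200 captcha size used in the patch, adding the missing channel axis inside train_nn() before fitting should be enough:

    # Sketch only (inside train_nn, before model.fit): give the grayscale
    # batch an explicit channel dimension to match input_shape=(70, 200, 1).
    x_train = np.expand_dims(np.array(training_images), axis=-1)  # (N, 70, 200) -> (N, 70, 200, 1)
    y_train = np.array(training_labels)
    model.fit(x_train, y_train, epochs=10, batch_size=128)

The validation batch needs the same reshaping. Two smaller points: the .keras file written by keras.saving.save_model does not persist unique_solutions, so the index-to-string mapping has to be stored separately for anything reloading the model via keras.saving.load_model('captcha_solver.keras'); and because each whole 6-character solution is treated as a single class, the network can only ever predict solutions that already appear in the training set.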