This repo contains code for teaching a neural-network-based reinforcement learning agent how to write characters (see Figure 1 below).
Fig 1. Example of an ACKTR reinforcement learning agent trained on 5x5 patterns for about 44 million steps.
For training reinforcement learning agents I use stable-baselines, and for the environment I use my own custom Gym environment. The environment provides three sets of patterns (mostly letters and digits) that an agent can be trained on:
- A set of simple 3x3 patterns
- A set of 5x5 patterns
- Letters and digits from the EMNIST (Extended MNIST) dataset. This is essentially MNIST extended to include handwritten letters as well as digits (see Figure 2 below).
Fig 2. Sample of images from the EMNIST dataset.
The agent's goal is to fill in squares on a grid so that it reproduces the pattern it has been shown as accurately as possible.
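To make the task concrete, here is a toy, self-contained sketch of a grid-drawing environment. It is not this repo's environment: the action set (move up/down/left/right, fill the current cell), the reward (+1 for filling a cell that belongs to the pattern, -1 for one that does not) and the termination rule are all assumptions chosen for illustration, and ToyDrawingEnv is a hypothetical name. It follows the classic Gym reset()/step() convention (step returns observation, reward, done, info) that stable-baselines works with, but deliberately does not depend on gym.

```python
from typing import Dict, List, Tuple

Grid = List[List[int]]
Position = Tuple[int, int]
Observation = Tuple[Grid, Grid, Position]

# Hypothetical discrete actions; the real environment's action space may differ.
UP, DOWN, LEFT, RIGHT, FILL = range(5)


class ToyDrawingEnv:
    """Toy grid-drawing task using the classic Gym reset()/step() convention.

    The agent moves a "pen" around a square grid and fills cells, trying to
    reproduce `target`. Actions, rewards and termination are illustrative.
    """

    def __init__(self, target: Grid, max_steps: int = 50) -> None:
        self.target = target
        self.size = len(target)
        self.max_steps = max_steps
        self.reset()

    def reset(self) -> Observation:
        self.canvas = [[0] * self.size for _ in range(self.size)]
        self.pos = (0, 0)  # (row, col); the pen starts in the top-left corner
        self.steps = 0
        return self._observation()

    def _observation(self) -> Observation:
        # The agent sees the target, a copy of its canvas and its position.
        return self.target, [row[:] for row in self.canvas], self.pos

    def step(self, action: int) -> Tuple[Observation, float, bool, Dict]:
        self.steps += 1
        row, col = self.pos
        reward = 0.0
        if action == UP:
            row = max(row - 1, 0)
        elif action == DOWN:
            row = min(row + 1, self.size - 1)
        elif action == LEFT:
            col = max(col - 1, 0)
        elif action == RIGHT:
            col = min(col + 1, self.size - 1)
        elif action == FILL and self.canvas[row][col] == 0:
            self.canvas[row][col] = 1
            # +1 for filling a cell that is part of the pattern, -1 otherwise.
            reward = 1.0 if self.target[row][col] == 1 else -1.0
        self.pos = (row, col)
        # The episode ends when the drawing matches the target or time runs out.
        done = self.canvas == self.target or self.steps >= self.max_steps
        return self._observation(), reward, done, {}
```

For example, on a 3x3 "L" pattern ([[1, 0, 0], [1, 0, 0], [1, 1, 1]]) the action sequence FILL, DOWN, FILL, DOWN, FILL, RIGHT, FILL, RIGHT, FILL collects a total reward of 5 and ends the episode. Because filled cells cannot be erased in this sketch, a wrong fill means the episode can only end by reaching max_steps.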
There is a convenience script, setup.sh, which assumes a UNIX-based system and automates most of the setup process. If you use this script, restart your terminal after it runs successfully (so that conda is set up correctly) and skip to step 3.
1. Install the required system packages (the command below is for Debian/Ubuntu):
       sudo apt-get update && sudo apt-get install cmake libopenmpi-dev python3-dev zlib1g-dev unzip xvfb python-opengl
   See the prerequisites section of stable-baselines.readthedocs.io for instructions for other operating systems.
2. Set up the Python environment using conda:
       conda env create -f environment.yml
   If you are not using conda, make sure you have a Python environment with all of the packages listed in environment.yml installed.
3. Activate the conda environment:
       conda activate learning2write
4. If you want to train an agent on images from the EMNIST dataset, you will need a local copy of it. You can download it by running the following:
       mkdir emnist_data
       cd emnist_data
       wget http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip
       unzip gzip.zip
       mv gzip/* .
       rmdir gzip
       cd ..
   The download is about 550MB.
5. Train a model:
       python train.py -steps 1000000
   Training the agent for more steps usually gives better results :) As a rough guide, an ACKTR agent needs about 5 to 10 million steps for the 3x3 patterns and 10 to 20 million steps for the 5x5 patterns. I'm not sure about the EMNIST dataset, but it probably needs a lot more :|
6. Test a previously trained model:
       python test.py models/acktr_mlp_5x5.pkl acktr
   This opens a window that displays the environment: the reference (target) pattern on the left, and next to it the agent's drawing and its current location (the big red dot). There are a couple of pretrained models in the models/ directory.
7. To see the help text for these scripts, add the flag -h or --help.
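If you want to inspect the EMNIST files from the download step above yourself, here is a minimal, stdlib-only sketch of a reader for the IDX format used by MNIST-style datasets. It assumes the files in emnist_data/ are gzip-compressed IDX files holding unsigned bytes (as the archive name gzip.zip suggests); read_idx and get_image are hypothetical helpers, not part of this repo, and the repo's own loading code may work differently.

```python
import gzip
import math
import struct
from typing import List, Tuple


def read_idx(path: str) -> Tuple[Tuple[int, ...], bytes]:
    """Read a gzip-compressed IDX file whose values are unsigned bytes.

    Returns the shape stored in the header, e.g. (n_images, rows, cols) for an
    image file or (n_labels,) for a label file, plus the raw value bytes.
    """
    with gzip.open(path, "rb") as f:
        data = f.read()
    # Header: two zero bytes, a type code (0x08 = unsigned byte), the number of
    # dimensions, then one big-endian 32-bit size per dimension.
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", data[:4])
    if zero1 != 0 or zero2 != 0 or type_code != 0x08:
        raise ValueError(f"{path} is not an unsigned-byte IDX file")
    header_end = 4 + 4 * ndim
    shape = struct.unpack(">" + "I" * ndim, data[4:header_end])
    payload = data[header_end:]
    expected = math.prod(shape)
    if len(payload) != expected:
        raise ValueError(f"{path}: expected {expected} values, got {len(payload)}")
    return shape, payload


def get_image(shape: Tuple[int, ...], payload: bytes, index: int) -> List[List[int]]:
    """Return image `index` of an image file as a list of pixel rows (0-255)."""
    _, rows, cols = shape
    start = index * rows * cols
    flat = payload[start:start + rows * cols]
    return [list(flat[r * cols:(r + 1) * cols]) for r in range(rows)]
```

Pass read_idx the path of an image file from emnist_data/ and then call get_image for individual samples. EMNIST images are widely reported to be stored transposed relative to MNIST, so if the characters look mirrored and rotated, try swapping rows and columns before displaying them.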