Visualizing Training of Neural Networks

When I started working on neural network and deep learning projects, I was always curious about what’s going on under the hood. TensorFlow’s console output was not enough to tell me how well my model was really training, and it gave little insight into how the model was progressing toward the target I wanted it to learn.

After struggling with it a bit, I implemented a solution that fits my needs: a callback for the training phase that draws the model’s predictions in the same plot as the target it needs to learn. The plot is updated after each training batch, so it shows exactly what the model has learned so far: the model’s predictions as an orange line and the true target as a blue line.

The result looks something like this:

This is how I create the plot callback:

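    import matplotlib.pyplot as plt
    from tensorflow import keras

    def create_plot_callback(model, batch_size, train_X, train_y):
        def cb(x, y_true):
            def _(batch, logs):
                # Rows of the training data that correspond to this batch index
                s, e = batch * batch_size, (batch + 1) * batch_size
                y_pred = model.predict(x[s:e], verbose=0)
                plt.clf()  # clear the curves drawn for the previous batch
                for i in range(y_pred.shape[1]):
                    # One subplot per output column
                    plt.subplot(y_pred.shape[1], 1, i + 1)
                    plt.plot(y_true[s:e, i], label='true' + str(i))  # blue line
                    plt.plot(y_pred[:, i], label='pred' + str(i))    # orange line
                    plt.legend()
                plt.pause(0.001)  # redraw the figure without blocking training
            return _
        return keras.callbacks.LambdaCallback(on_batch_end=cb(train_X, train_y))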


And this is how I use it when I train my Keras model:

    model.fit(x, y, ..., batch_size=batch_size,
              callbacks=[create_plot_callback(model, batch_size, x, y)])

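Note that the x and y I pass to training are the same x and y I pass to the callback, so it can slice out the rows that match the current batch index while training. Keep in mind that `fit` shuffles the training data by default, so unless you pass `shuffle=False`, the plotted rows are the slice at that batch position rather than the exact samples the model just trained on.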
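The callback relies on a simple mapping from the batch index Keras reports (which counts from 0 within each epoch) to a row range of x and y. Here is a minimal sketch of that arithmetic in plain Python, assuming unshuffled data; `batch_bounds` and `all_batches` are hypothetical helper names for illustration, not part of Keras.

```python
def batch_bounds(batch, batch_size, n_samples):
    """Return (start, end) row indices for a batch index.

    Mirrors the slicing in the plot callback, assuming unshuffled data.
    The final batch may be smaller than batch_size.
    """
    start = batch * batch_size
    end = min((batch + 1) * batch_size, n_samples)
    return start, end


def all_batches(batch_size, n_samples):
    """List the (start, end) bounds of every batch in one epoch."""
    n_batches = -(-n_samples // batch_size)  # ceiling division
    return [batch_bounds(b, batch_size, n_samples) for b in range(n_batches)]


if __name__ == "__main__":
    # 10 samples with batch_size=4 -> batches of 4, 4 and 2 rows
    print(all_batches(4, 10))  # [(0, 4), (4, 8), (8, 10)]
```

The callback itself skips the `min` because Python slicing already clamps an end index that runs past the array, so `x[8:12]` on 10 rows simply returns the last 2 rows.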

CAUTION, though: this can make your training much slower, since the callback runs an extra `model.predict` and redraws the plot after every batch. For me it’s a price I am willing to pay just to always know exactly how my model’s training is going.
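If the slowdown is too much, one option (not part of my original setup) is to redraw only every few batches. A minimal sketch in plain Python: `every_n_batches` is a hypothetical wrapper that takes any `(batch, logs)` function, such as the inner `_` built by `cb`, and calls it only when the batch index is a multiple of `n`. Inside `create_plot_callback` you would then pass `on_batch_end=every_n_batches(cb(train_X, train_y), 10)` instead of the unwrapped function.

```python
def every_n_batches(fn, n):
    """Wrap a (batch, logs) callback so it only runs on every n-th batch."""
    if n < 1:
        raise ValueError("n must be >= 1")

    def wrapped(batch, logs=None):
        # Keras restarts the batch index at 0 every epoch, so the first
        # batch of each epoch is always drawn.
        if batch % n == 0:
            fn(batch, logs)

    return wrapped


if __name__ == "__main__":
    drawn = []
    plot_every_3 = every_n_batches(lambda batch, logs: drawn.append(batch), 3)
    for b in range(8):
        plot_every_3(b, {})
    print(drawn)  # [0, 3, 6]
```

Larger values of `n` trade plot resolution for training speed; the prediction and drawing cost is paid only on the batches that are actually plotted.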

Did my post help you? Share it with me here, and share it with your friends and colleagues everywhere 🙂
