Caching Deep Learning Models with TensorFlow, Keras and PTMLib

Save Machine Learning models, time, and energy (and maybe even the planet)

I’ve noticed when working through lengthy ML tutorials for Keras and TensorFlow that I often rerun the same models in Jupyter Notebook, even though they haven’t changed. This happens naturally in a notebook that contains multiple models. I also find that typing the code into a new notebook, rather than just pressing Shift + Enter, aids my understanding and retention of what I have learned.

These polar bears want you to save your trained models.
Photo by Hans-Jurgen Mager on Unsplash

There can be many stops and starts as I work through the details and research topics I want to understand better before proceeding. And there are always those security patches: shut down everything and reboot your machine. Then go back to your Jupyter Notebook and hit Kernel > Restart & Run All.

To be fair, you can’t have enough security these days.

How long will all those models take to run again?

Most of the time this is not a big deal, but sometimes, even in a lesson notebook, you have models and data that require hours of training. Even simple models can take several minutes. This means wasted time, and wasted energy in the form of compute resources. It may not seem like a lot, but after a while it all adds up. We have a planet to think of here!

How can you avoid all this waste without much effort? Save your models so you can reload them when the notebook runs again.

Saving your TensorFlow model

In this example using TensorFlow and Keras, I trained the model first and saved it once training completed, so I could reload it on later runs instead of retraining. Switching between training and reloading requires commenting and uncommenting the appropriate code. Simple, but crude.
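The original code for this step is not reproduced here. A minimal sketch of the comment/uncomment pattern might look like the following. It assumes a small placeholder classifier trained on random data and a hypothetical file name, `my_model.keras`; none of these come from the original notebook.

```python
import numpy as np
import tensorflow as tf

MODEL_PATH = "my_model.keras"  # hypothetical file name

# Placeholder data (assumption): 100 samples, 4 features, 3 classes.
x_train = np.random.rand(100, 4).astype("float32")
y_train = np.random.randint(0, 3, size=(100,))

# First run: build, train, and save the model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train, y_train, epochs=5, verbose=0)
model.save(MODEL_PATH)

# Later runs: comment out the block above and uncomment this line instead.
# model = tf.keras.models.load_model(MODEL_PATH)
```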

I found myself using this save-and-reload pattern repeatedly, adding enhancements such as checking whether the saved file exists so I no longer had to comment out code. Then I added more code to check whether I had already saved images of the accuracy/loss charts. Things were getting more complicated, and more repetitive.
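As a rough sketch of the file-exists enhancement (covering only the model check, not the charts), the manual toggle can be replaced with an `os.path.exists()` test. The file name, placeholder data, and `build_model()` helper below are hypothetical.

```python
import os

import numpy as np
import tensorflow as tf

MODEL_PATH = "my_model.keras"  # hypothetical file name

# Placeholder data (assumption): 100 samples, 4 features, 3 classes.
x_train = np.random.rand(100, 4).astype("float32")
y_train = np.random.randint(0, 3, size=(100,))


def build_model():
    """Hypothetical helper: build and compile a small placeholder classifier."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


if os.path.exists(MODEL_PATH):
    # A saved model exists: skip training and reload it.
    model = tf.keras.models.load_model(MODEL_PATH)
else:
    # No saved model yet: train once, then save it for future runs.
    model = build_model()
    model.fit(x_train, y_train, epochs=5, verbose=0)
    model.save(MODEL_PATH)
```

Deleting the saved file is then the only step needed to force retraining.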

Repetition is a good sign that it’s time for some reusable code.

Saving your model using PTMLib

In the following example, the new load_or_fit_model() function in PTMLib takes care of training the model and saving it once training completes, or loading it if the model was previously saved. It also saves and reloads the accuracy and loss charts.
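PTMLib’s actual signature for load_or_fit_model() is not shown here, so the following is only an illustration of the same idea, not PTMLib’s API. The hypothetical helper `cached_fit()` caches both the model and its training history. As a simplification, it stores the history as JSON instead of saving chart images; the accuracy and loss charts can be redrawn from that data on a cache hit.

```python
import json
import os

import numpy as np
import tensorflow as tf


def cached_fit(model, path, history_path, **fit_kwargs):
    """Hypothetical helper: reload a saved model and history if both exist,
    otherwise train the given model, save both, and return them."""
    if os.path.exists(path) and os.path.exists(history_path):
        # Cache hit: reload the model and its per-epoch metrics.
        model = tf.keras.models.load_model(path)
        with open(history_path) as f:
            history = json.load(f)
    else:
        # Cache miss: train, then persist the model and per-epoch metrics.
        history = model.fit(**fit_kwargs).history
        # Convert any NumPy scalars to plain floats so json.dump accepts them.
        history = {k: [float(v) for v in vals] for k, vals in history.items()}
        model.save(path)
        with open(history_path, "w") as f:
            json.dump(history, f)
    return model, history


# Usage with placeholder data (assumption): 100 samples, 4 features, 3 classes.
x_train = np.random.rand(100, 4).astype("float32")
y_train = np.random.randint(0, 3, size=(100,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model, history = cached_fit(model, "cached_model.keras", "cached_history.json",
                            x=x_train, y=y_train, epochs=3, verbose=0)
```

Requiring both files before treating it as a cache hit means a partially written cache simply triggers retraining, which keeps the helper easy to reason about.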

I include a callback function to stop training early once the model reaches the desired accuracy. Caching our models doesn’t mean giving up any of the power and flexibility of TensorFlow and Keras.
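The callback used in the original notebook is not shown, but a minimal sketch of an accuracy-threshold callback using the standard Keras `Callback` API might look like this. The class name `StopAtAccuracy` and the `target` parameter are hypothetical. Unlike the built-in `EarlyStopping`, which stops when a monitored metric stops improving, this callback stops as soon as a fixed accuracy target is reached.

```python
import numpy as np
import tensorflow as tf


class StopAtAccuracy(tf.keras.callbacks.Callback):
    """Hypothetical callback: stop training once training accuracy
    reaches a fixed target."""

    def __init__(self, target=0.95):
        super().__init__()
        self.target = target

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # "accuracy" is present in logs when the model is compiled with
        # metrics=["accuracy"].
        accuracy = logs.get("accuracy")
        if accuracy is not None and accuracy >= self.target:
            print(f"\nReached {self.target:.0%} accuracy; stopping training.")
            self.model.stop_training = True


# Toy demo with placeholder data (assumption). target=0.0 guarantees the
# threshold is met after the first epoch, so training stops immediately.
x_train = np.random.rand(100, 4).astype("float32")
y_train = np.random.randint(0, 3, size=(100,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
history = model.fit(x_train, y_train, epochs=10, verbose=0,
                    callbacks=[StopAtAccuracy(target=0.0)])
```

In practice you would set `target` to something like 0.95 and pass the callback through `callbacks=[...]` as usual.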

Example output using load_or_fit_model()

A complete notebook including this example is available here: Computer Vision with Model Caching.

I started working on PTMLib when I found myself using the same blocks of helper code in Jupyter Notebook repeatedly. I have found that it saves me time and compute energy, which adds up after a while. I hope you find it useful too.