2021-10-12

Watching a Model Train

Last week, I did a quick hack that quite delighted me: I added a way to visually watch the progress of training my MGL-based neural networks inside Emacs. And then people on Twitter asked me to show the code. So, here it is, but first I wanted to rant a bit about one of my pet peeves.

Low-Tech

In the age of Jupyter and TensorBoard, adding a way to see an image that records the value of a loss function blinking on the screen — "huh, big deal," you would say. But I believe this example showcases the difference between low-tech and high-tech approaches. Just recently, I chatted with one of my friends who is entering software engineering at a rather late age (30+), and we talked about how frontend development has become even more complicated than backend development (while, arguably, the complexity of the tasks solved on the frontend is significantly lower). That discussion just confirmed to me that the tendency to overcomplicate things is always there, with our pop-culture-driven industry surely following it. But I have always tried to stay on the simple side, on the side of low-tech solutions. And that's, by the way, one of the reasons I chose to stick with Lisp: with it, you would hardly be forced into some nonsense framework hell, or into playing catch-up with the constant changes of your environment, or into following crazy "best practices". Lisp is low-tech, just like the Unix command line or vanilla Python or JS. In contrast to high-tech Rust, Haskell, or Java. Everything text-based is also low-tech: text-based data formats, text-based visualization, text-based interfaces.

So, what is low-tech, after all? I saw the term popularized by Kris De Decker of Low-Tech Magazine, which focuses on using simple (perhaps, by some standards, outdated) technologies to solve serious engineering problems. Most people, and the software industry is no exception, are after high-tech, right? Progress of technology enables solving more and more complex tasks. And, indeed, that happens. Sometimes, not always. Sometimes, the whole thing crumbles, but that's a different story. Yet, even when it happens, there's a catch, a negative side effect: the barrier to entry rises. If 5 or 10 years ago it was enough to know HTML, CSS, and JavaScript to be a competent frontend developer, now you have to learn a dozen more things: convoluted frameworks, complicated deploy toolchains, etc., etc. Surely, sometimes that's inevitable, but it really delights me when you can avoid all the bloat and use simple tools to achieve the same result. OK, maybe not completely the same, maybe not a perfect one. But good enough. The venerable 80% solution that requires 20% of the effort.

Low-tech is not low-quality; it's low barrier to entry.

And I would argue that, in the long run, our field will make better progress if we strive to lower the bar and let more people in than if we continue raising it (ensuring our "job security" this way). Which doesn't mean that the technologies should be primitive (like BASIC). On the contrary, the most ingenious solutions are also the simplest ones. I'm going to continue this argument in future posts I'd like to write about interactive programming. And now, back to our hacks.

Getting to Terms with MGL

In my recent experiments I returned to MGL — an advanced, although pretty opinionated, machine learning library by the prolific Gabor Melis — for playing around with neural networks. Last time, a few years ago, I stumbled when I tried to use it to reproduce a very advanced (by that time's standards) recurrent neural network, and failed. Yet, before that, I was very happy using it (or, rather, its underlying MGL-MAT library) for running in Lisp (in production) some of the neural networks that were developed by my colleagues. I know it's usually the other way around: Lisp for prototyping, some high-tech monstrosity for production, but we managed to turn the tides for some time :D

So, this time, I decided to approach MGL step by step, starting from simple building blocks. First, I took on training a simple feed-forward net with a number of word inputs converted to vectors using a word2vec-like approach.

This is the network I created. Jumping slightly ahead, I've experimented with several variations of the architecture, starting from a single-hidden-layer MLP, and this one has worked best so far. As you can see, it has 2 hidden layers (l1/l1-l and l2/l2-l) and performs 2-class classification. It also applies dropout after each of the hidden layers as a standard means of regularization during training.

(defun make-nlp-mlp (&key (n-hidden 100))
  (mgl:build-fnn (:class 'nlp-mlp)
    (in (->input :size *input-len*))
    ;; first hidden layer: linear activation + ReLU + dropout
    (l1-l (->activation in :size n-hidden))
    (l1 (->relu l1-l))
    (d1 (->dropout l1 :dropout 0.5))
    ;; second hidden layer, half the size of the first
    (l2-l (->activation d1 :size (floor n-hidden 2)))
    (l2 (->relu l2-l))
    (d2 (->dropout l2 :dropout 0.5))
    ;; 2-class output with softmax cross-entropy loss
    (out-l (->activation d2 :size 2))
    (out (->softmax-xe-loss out-l))))

MGL model definition is somewhat different from the approach one might be used to from Keras or TF: you don't imperatively add layers to the network; instead, you define all the layers at once, in a declarative fashion. Typical Lisp style. Yet, what remains not totally clear to me is the best way to assemble layers when the architecture is not a straightforward one-directional or recurrent one but combines several parts in nonstandard ways. That's where I stumbled previously. I plan to get to that over time, but if someone has good examples already, I'd be glad to take a look at them. Unfortunately, despite the proven high quality of MGL, there's very little open-source code that uses it.

Now, to make a model train (and watch it), we have to pass it to mgl:minimize alongside a learner:

(defvar *agg-loss* ()
  "Accumulated (iteration loss) pairs used to draw the graph.")

(defun train-nlp-fnn (&key data (batch-size 100) (epochs 1000) (n-hidden 100)
                       (random-state *random-state*))
  (let ((*random-state* random-state)
        (*agg-loss* ())
        (opt (make 'mgl:segmented-gd-optimizer
                   :termination (* epochs batch-size)
                   :segmenter (constantly
                                (make 'mgl:adam-optimizer
                                      :n-instances-in-batch batch-size))))
        (fnn (make-nlp-mlp :n-hidden n-hidden)))
    (mgl:map-segments (lambda (layer)
                        (mgl:gaussian-random!
                         (mgl:nodes layer)
                         :stddev (/ 2 (reduce '+ (mgl:mat-dimensions (mgl:nodes layer))))))
                      fnn)
    (mgl:monitor-optimization-periodically
     opt
     `((:fn mgl:reset-optimization-monitors :period ,batch-size :last-eval 0)
       (:fn draw-test-error :period ,batch-size)))
    (mgl:minimize opt (make 'mgl:bp-learner
                            :bpn fnn
                            :monitors (mgl:make-cost-monitors
                                       fnn :attributes `(:event "train")))
                  :dataset (sample-data data (* epochs batch-size)))
    fnn))

This code is rather complex, so let me try to explain each part.

  • We use (let ((*random-state* random-state)) ...) to ensure that we can reproduce the training in exactly the same way if needed.
  • mgl:segmented-gd-optimizer is a class that allows us to specify a different optimization algorithm for each segment (layer) of the network. Here, we use the same standard mgl:adam-optimizer with vanilla parameters for every segment (hence, constantly).
  • The following mgl:map-segments call performs the Xavier initialization of the layers' weights. It is crucial to properly initialize the weights of the network before training or, at least, to ensure that they are not all set to zeros.
  • The next part is, finally, responsible for WATCHING THE MODEL TRAIN. mgl:monitor-optimization-periodically is a hook that makes MGL invoke some callbacks that help you peek into the optimization process (and, perhaps, do other needful things). That's where we insert our draw-test-error function, which will run after every batch. There's also an out-of-the-box cost monitor attached directly to the mgl:bp-learner; it collects the data for us and also prints it on the screen. I guess we could build the draw-test-error monitor in a similar way, but I opted for my favorite Lisp magic wand — a special variable: *agg-loss*.
  • And last but not least, we need to provide the dataset to the model: (sample-data data (* epochs batch-size)). The simple approach I use here is to pre-sample the necessary number of examples beforehand. However, streaming sampling may also be possible with a different dataset-generating function. (A minimal sketch of such a sampling function follows this list.)

Now, let's take a look at the function that is drawing the graph:

(defun draw-test-error (opt learner)
  ;; here, we print out the architecture and parameters of
  ;; our model and learning algorithm
  (when (zerop (mgl:n-instances opt))
    (describe opt)
    (describe (mgl:bpn learner)))
  ;; here, we rely on the fact that there's
  ;; just a single cost monitor defined
  (let ((mon (first (mgl:monitors learner))))
    ;; using some of RUTILS syntax sugar here to make the code terser
    (push (pair (+ (? mon 'counter 'denominator)
                   (if-it (first *agg-loss*)
                          (lt it)
                          0))
                (? mon 'counter 'numerator))
          *agg-loss*)
    (redraw-loss-graph)))

(defun redraw-loss-graph (&key (file "/tmp/loss.png") (smoothing 10))
  (adw-charting:with-chart (:line 800 600)
    (adw-charting:add-series "Loss" *agg-loss*)
    ;; the second series is a moving average of the raw loss:
    ;; each point is the mean of a window of SMOOTHING consecutive entries,
    ;; plotted at the x coordinate of the window's middle entry
    (adw-charting:add-series
     (fmt "Smoothed^~a Loss" smoothing)
     (loop :for i :from 0
           :for off := (* smoothing (1+ i))
           :while (< off (length *agg-loss*))
           :collect (pair (? *agg-loss* (- off (floor smoothing 2)) 0)
                          (/ (reduce ^(+ % (rt %%))
                                     (subseq *agg-loss* (- off smoothing) off)
                                     :initial-value 0)
                             smoothing))))
    (adw-charting:set-axis :y "Loss" :draw-gridlines-p t)
    (adw-charting:set-axis :x "Iteration #")
    (adw-charting:save-file file)))

Using this approach, I could also draw the validation loss on the same graph, and I'll do that in the next version.

ADW-CHARTING is my go-to library when I need to draw a quick-and-dirty chart. As you can see, it is very straightforward to use and doesn't require much explanation. I've looked into a couple of other charting libraries and liked their demo screenshots (probably, more than the style of ADW-CHARTING), but there were some blockers that prevented me from switching to them. Maybe, next time, I'll have more inclination.

To complete the picture, we now need to display our learning progress not just as text running in the console (produced by the standard cost monitor) but also by updating the graph. This is where Emacs' nature as a Swiss-Army knife for any interactive workflow comes into play. Surely, there already was an existing auto-revert-mode that updates the contents of an Emacs buffer on any change or periodically. For my purposes, I've added these lines to my Emacs config:

(setq auto-revert-use-notify nil)
(setq auto-revert-interval 6)  ; refresh every 6 seconds

Obviously, this can be abstracted away into a function which could be invoked by pressing some key or upon other conditions occurring.
