--- /dev/null
+;;;; Copyright (C) 2023 Vouivre Digital Corporation
+;;;;
+;;;; This file is part of Vouivre.
+;;;;
+;;;; Vouivre is free software: you can redistribute it and/or
+;;;; modify it under the terms of the GNU General Public
+;;;; License as published by the Free Software Foundation, either
+;;;; version 3 of the License, or (at your option) any later version.
+;;;;
+;;;; Vouivre is distributed in the hope that it will be useful,
+;;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
+;;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+;;;; General Public License for more details.
+;;;;
+;;;; You should have received a copy of the GNU General Public
+;;;; License along with Vouivre. If not, see <https://www.gnu.org/licenses/>.
+
+(use-modules
+ ((vouivre autodiff) #:prefix v:)
+ ((vouivre misc) #:prefix v:)
+ ((vouivre mnist) #:prefix v:)
+ (ice-9 receive)
+ (srfi srfi-1))
+
+;;;; utilities
+
+(define (iter f a n)
+  "Apply F to A repeatedly, N times in total, and return the final value.
+When N is zero, A is returned unchanged."
+  (let loop ((acc a)
+             (remaining n))
+    (if (zero? remaining)
+        acc
+        (loop (f acc) (1- remaining)))))
+
+(define (make-randn-array . dims)
+  "Build an array with v:produce-typed-array, passing v:*atype* as the type and
+DIMS as the dimensions. The element procedure ignores its arguments and
+returns a fresh (random:normal) sample on every call."
+  (let ((sample (lambda _ (random:normal))))
+    (apply v:produce-typed-array sample v:*atype* dims)))
+
+(define (relu x)
+  "Rectifier: return (v:max 0 X)."
+  (let ((threshold 0))
+    (v:max threshold x)))
+
+(define (logsumexp x)
+  "Return log(sum(exp(X))), evaluated as c + log(sum(exp(X - c))) where
+c is (v:maximum X), so the exponential is applied to shifted values only."
+  (let* ((c (v:maximum x))
+         (shifted (v:- x c))
+         (total (v:sum (v:exp shifted))))
+    (v:+ c (v:log total))))
+
+(define (random-batch batch-size array . more)
+  "Return random batches of the given size, one for each array, as multiple
+values. The indices of elements of a batch, in their array of origin,
+correspond batch-wise.
+
+Elements are drawn independently, so the same index may appear several times
+in a batch. Indices are drawn below the shortest leading dimension among the
+given arrays, so every drawn index is valid for all of them. An error is
+raised when a non-empty batch is requested and one of the arrays is empty.
+
+Example: (random-batch 2 #2((a b) (c d) (e f)) #(x y z)) might return
+
+ (values #2((a b) (a b)) #(x x))
+ or
+ (values #2((c d) (a b)) #(y x))
+etc."
+  (let* ((arrays (cons array more))
+         ;; Number of indices usable with every array.  Taking the leading
+         ;; dimension of the first array only would make `array-cell-ref'
+         ;; fail at random whenever a later array is shorter.
+         (n (apply min
+                   (map (lambda (src) (car (array-dimensions src)))
+                        arrays)))
+         ;; One destination per source: same element type, the leading
+         ;; dimension replaced by BATCH-SIZE, contents left unspecified
+         ;; until filled below.
+         (rs (map (lambda (src)
+                    (apply make-typed-array (array-type src)
+                           *unspecified* batch-size
+                           (cdr (array-dimensions src))))
+                  arrays)))
+    (when (and (zero? n) (positive? batch-size))
+      (error "random-batch: cannot draw from an empty array"))
+    (do ((batch 0 (1+ batch)))
+        ((= batch batch-size) (apply values rs))
+      ;; The same index is used for all arrays so that the batches stay
+      ;; aligned with one another.
+      (let ((i (random n)))
+        (for-each
+         (lambda (dst src)
+           (array-cell-set! dst (array-cell-ref src i) batch))
+         rs arrays)))))
+
+(define (argmax x)
+  "Return the position of the first largest element of the array X, counting
+elements from zero in the order `array-for-each' visits them. Return #f when
+X has no elements, or when none of them can be ordered (e.g. all NaN)."
+  (let ((position 0)
+        (best-index #f)
+        (best (- (inf))))
+    (array-for-each
+     (lambda (value)
+       ;; A strict `>' alone never selects an element equal to the -inf seed,
+       ;; so an array made only of -inf would yield #f; accept the first such
+       ;; element explicitly.  Later ties are still ignored.
+       (when (or (> value best)
+                 (and (not best-index) (= value best)))
+         (set! best value)
+         (set! best-index position))
+       (set! position (1+ position)))
+     x)
+    best-index))
+
+;;;; parameters
+
+(define bs 2) ; batch size: number of samples drawn for each gradient estimate
+(define es 10) ; encoding size: leading dimension of the last parameter array
+(define hl 16) ; hidden layer size
+(define lr 5e-3) ; learning rate: factor applied to the gradient in each update
+(define w 'TBD) ; image width; placeholder until set from the loaded data
+(define h 'TBD) ; image height; placeholder until set from the loaded data
+
+;;;; data
+
+;; The dataset as returned by v:load-mnist.  Below, its car is used as the
+;; samples and its cdr as the labels.
+;; NOTE(review): the meaning of the arguments 10 and #t is not visible from
+;; this file -- confirm against (vouivre mnist).
+(define data (v:load-mnist 10 #t))
+(define (cast-array x)
+  "Build, with v:produce-typed-array and the type v:*atype*, an array having
+the dimensions of X. The element procedure looks up the element of X found at
+the indices it is called with."
+  (let ((element-at (lambda indices
+                      (apply array-ref x indices))))
+    (apply v:produce-typed-array
+           element-at
+           v:*atype*
+           (array-dimensions x))))
+(define (normalize x)
+  "Divide X by 255 using the extended `/', then subtract 0.5.
+Presumably X holds values between 0 and 255, which would map to the range
+[-0.5, 0.5] -- to be confirmed against the data loader."
+  (let* ((divide (v:extend /))
+         (scaled (divide x 255)))
+    (v:- scaled 0.5)))
+
+;; Image height and width are the second and third dimensions of the samples.
+(let ((dims (array-dimensions (car data))))
+  (set! h (list-ref dims 1))
+  (set! w (list-ref dims 2)))
+
+;; Training samples, cast to the working array type and normalized.
+(define x (normalize (cast-array (car data))))
+;; Training labels.
+(define y (cdr data))
+;; The model parameters: one randomly initialized array per layer.
+(define a
+  (list (make-randn-array hl h w)
+        (make-randn-array hl hl)
+        (make-randn-array es hl)))
+
+;;;; training
+
+(define (f x a)
+  "The model: given a sample X and the list of parameters A, produce a
+distribution over the labels, expressed as one score per label. The scores
+act as unnormalized log-likelihoods; `L' normalizes them with logsumexp.
+
+The first parameter array is contracted with X through `v:adot' over 2
+dimensions. Each remaining one is then contracted, over 1 dimension, with
+the `relu' of the previous result."
+  (let forward ((layers (cdr a))
+                (activation (v:adot (car a) x 2)))
+    (if (null? layers)
+        activation
+        (forward (cdr layers)
+                 (v:adot (car layers) (relu activation) 1)))))
+
+(define (L y p)
+  "The loss for the ground truth label Y given the model's output P:
+logsumexp(P) minus the element of P at index Y."
+  (let ((normalizer (logsumexp p))
+        (score (v:array-ref p y)))
+    (v:- normalizer score)))
+
+(define (<L> x y)
+  "Return a procedure of the model parameters, taken as its arguments, giving
+the mean of the loss over the samples X paired with the labels Y."
+  (lambda a
+    (define (sample-loss sample label)
+      (L label (f sample a)))
+    (v:mean (v:amap2 sample-loss x y))))
+
+(define (∇L a)
+  "Mini-batch gradient of the loss with respect to the model parameters A.
+A single random batch of size `bs' is drawn from the training data, then one
+gradient is computed per parameter with `v:rdiff'. The result is a list in
+the same order as A."
+  (receive (batch-x batch-y) (random-batch bs x y)
+    (let ((batch-loss (<L> batch-x batch-y)))
+      (list-tabulate
+       (length a)
+       (lambda (i)
+         (apply (v:rdiff batch-loss i) a))))))
+
+(define (update a)
+  "Return new model parameters obtained from A by a step of size `lr' against
+the mini-batch gradient. A itself is left untouched."
+  (define (step param grad)
+    (v:- param (v:* lr grad)))
+  (map step a (∇L a)))
+
+(define (report-loss a)
+  "Print, followed by a newline, the average loss of the parameters A over the
+entire dataset."
+  (let ((loss (apply (<L> x y) a)))
+    (display loss)
+    (newline)))
+
+(define (categorize x)
+  "Return the index of the largest output of the model for the sample X,
+using the current global parameters `a'."
+  (let ((scores (f x a)))
+    (argmax scores)))
+
+(define (run nb-epochs steps-per-epochs)
+  "Train the global parameters `a' for NB-EPOCHS epochs of STEPS-PER-EPOCHS
+updates each. The loss is reported once before training and once after every
+epoch."
+  (report-loss a)
+  (let next-epoch ((epoch 0))
+    (unless (= epoch nb-epochs)
+      (set! a (iter update a steps-per-epochs))
+      (report-loss a)
+      (next-epoch (1+ epoch)))))