]> git.vouivredigital.com Git - vouivre.git/commitdiff
Train a model on MNIST
authoradmin <admin@vouivredigital.com>
Mon, 27 Nov 2023 12:03:44 +0000 (21:03 +0900)
committeradmin <admin@vouivredigital.com>
Mon, 27 Nov 2023 12:03:44 +0000 (21:03 +0900)
examples/base.scm
examples/example.scm [new file with mode: 0644]

index 90a88a897568788743093eba99445da6239254a6..dc637cf212c4d0f2cfce429d5d454a145c4a82c3 100644 (file)
@@ -20,7 +20,6 @@
   #:use-module ((rnrs base) #:prefix rnrs:)
   #:use-module ((srfi srfi-1) #:prefix srfi-1:)
   #:use-module (vouivre curry)
-  #:use-module ((vouivre autodiff) #:prefix v:)
   #:export
   (∘
    ⊙
 (cudefine (log x) (rnrs:log x))
 
 (∷ sin (0 . 0))
-(cudefine (sin x) (v:sin x))
+(cudefine (sin x) (rnrs:sin x))
 
 (∷ cos (0 . 0))
 (cudefine (cos x) (rnrs:cos x))
 (definec (∘ g f) (λc x (g (f x))))
 (definec (⊙ f g) (∘ g f))
 (definec (flip f) (λc y (λc x (f x y))))
-
-(define fdiff v:fdiff)
-(define rdiff v:rdiff)
diff --git a/examples/example.scm b/examples/example.scm
new file mode 100644 (file)
index 0000000..bf694c4
--- /dev/null
@@ -0,0 +1,176 @@
+;;;; Copyright (C) 2023 Vouivre Digital Corporation
+;;;;
+;;;; This file is part of Vouivre.
+;;;;
+;;;; Vouivre is free software: you can redistribute it and/or
+;;;; modify it under the terms of the GNU General Public
+;;;; License as published by the Free Software Foundation, either
+;;;; version 3 of the License, or (at your option) any later version.
+;;;;
+;;;; Vouivre is distributed in the hope that it will be useful,
+;;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
+;;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+;;;; General Public License for more details.
+;;;;
+;;;; You should have received a copy of the GNU General Public
+;;;; License along with Vouivre. If not, see <https://www.gnu.org/licenses/>.
+
+(use-modules
+ ((vouivre autodiff) #:prefix v:)
+ ((vouivre misc) #:prefix v:)
+ ((vouivre mnist) #:prefix v:)
+ (ice-9 receive)
+ (srfi srfi-1))
+
+;;;; utilities
+
+(define (iter f a n)
+  (if (zero? n)
+      a
+      (iter f (f a) (1- n))))
+
+(define (make-randn-array . dims)
+  (apply v:produce-typed-array
+         (lambda _ (random:normal))
+        v:*atype*
+         dims))
+
+(define (relu x)
+  (v:max 0 x))
+
+(define (logsumexp x)
+  (let ((c (v:maximum x)))
+    (v:+ c (v:log (v:sum (v:exp (v:- x c)))))))
+
+(define (random-batch batch-size array . more)
+  "Return random batches of the given size, one for each array, as multiple
+values. The indices of elements of a batch, in their array of origin,
+correspond batch-wise.
+
+Example: (random-batch 2 #2((a b) (c d) (e f)) #(x y z)) might return
+
+         (values #2((a b) (a b)) #(x x))
+                        or
+         (values #2((c d) (a b)) #(y x))
+etc."
+  (let* ((arrays (cons array more))
+        (n (car (array-dimensions array)))
+        (rs (map (lambda (src)
+                   (apply make-typed-array (array-type src)
+                          *unspecified* batch-size
+                          (cdr (array-dimensions src))))
+                 arrays)))
+    (let lp ((batch 0))
+      (if (= batch batch-size)
+         (apply values rs)
+         (let ((i (random n)))
+           (for-each
+            (lambda (dst src)
+              (array-cell-set!
+               dst
+               (array-cell-ref src i)
+               batch))
+            rs arrays)
+           (lp (1+ batch)))))))
+
+(define (argmax x)
+  (let ((i 0)
+       (mi #f)
+       (m (- (inf))))
+    (array-for-each
+     (lambda (x)
+       (when (> x m)
+        (set! m x)
+        (set! mi i))
+       (set! i (1+ i)))
+     x)
+    mi))
+
+;;;; parameters
+
+(define bs 2)                          ; batch size
+(define es 10)                         ; encoding size
+(define hl 16)                         ; hidden layer size
+(define lr 5e-3)                       ; learning rate
+(define w 'TBD)                                ; image width
+(define h 'TBD)                                ; image height
+
+;;;; data
+
+(define data (v:load-mnist 10 #t))
+(define (cast-array x)
+  (apply
+   v:produce-typed-array
+   (lambda indices
+     (apply array-ref x indices))
+   v:*atype*
+   (array-dimensions x)))
+(define (normalize x)
+  (v:- ((v:extend /) x 255)
+       0.5))
+
+(set! h (list-ref (array-dimensions (car data))
+                 1))
+(set! w (list-ref (array-dimensions (car data))
+                 2))
+
+(define x (normalize (cast-array (car data)))) ; training samples
+(define y (cdr data))                         ; training labels
+(define a (list (make-randn-array hl h w)      ; the model parameters
+               (make-randn-array hl hl)
+                (make-randn-array es hl)))
+
+;;;; training
+
+(define (f x a)
+  "The model which, given a sample and some parameters, produces a distribution.
+The log-likelihood of each label, here."
+  (fold (lambda (a prev)
+         (v:adot a (relu prev) 1))
+       (v:adot (car a) x 2)
+       (cdr a)))
+
+(define (L y p)
+  "The loss as a function of the ground truth label and the model's output."
+  (v:- (logsumexp p)
+       (v:array-ref p y)))
+
+(define (<L> x y)
+  "The average loss over multiple samples and labels as a function of the model
+parameters."
+  (lambda a
+    (v:mean
+     (v:amap2
+      (lambda (x y)
+       (L y (f x a)))
+      x y))))
+
+(define (∇L a)
+  "Mini-batch gradient of the loss over model parameters."
+  (receive (x y) (random-batch bs x y)
+    (list-tabulate (length a)
+                  (lambda (i)
+                    (apply (v:rdiff (<L> x y) i)
+                           a)))))
+
+(define (update a)
+  "Update model parameters, moving a small step opposite to the mini-batch
+gradient (without overwriting the original parameters)."
+  (map (lambda (a da)
+        (v:- a (v:* lr da)))
+       a (∇L a)))
+
+(define (report-loss a)
+  "Display the average loss over the entire dataset."
+  (display (apply (<L> x y) a))
+  (newline))
+
+(define (categorize x)
+  (argmax (f x a)))
+
+(define (run nb-epochs steps-per-epochs)
+  (report-loss a)
+  (do ((i 0 (1+ i)))
+      ((= i nb-epochs))
+    (set! a (iter update a steps-per-epochs))
+    (report-loss a)))