From 6f1014b35802ef5a0a6618dca7ba61d5b3f03eb7 Mon Sep 17 00:00:00 2001
From: admin
Date: Mon, 27 Nov 2023 21:03:44 +0900
Subject: [PATCH] Train a model on MNIST

---
 examples/base.scm    |   6 +-
 examples/example.scm | 176 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 177 insertions(+), 5 deletions(-)
 create mode 100644 examples/example.scm

diff --git a/examples/base.scm b/examples/base.scm
index 90a88a8..dc637cf 100644
--- a/examples/base.scm
+++ b/examples/base.scm
@@ -20,7 +20,6 @@
   #:use-module ((rnrs base) #:prefix rnrs:)
   #:use-module ((srfi srfi-1) #:prefix srfi-1:)
   #:use-module (vouivre curry)
-  #:use-module ((vouivre autodiff) #:prefix v:)
   #:export
   (∘
    ⊙
@@ -365,7 +364,7 @@
 (cudefine (log x) (rnrs:log x))
 
 (∷ sin (0 . 0))
-(cudefine (sin x) (v:sin x))
+(cudefine (sin x) (rnrs:sin x))
 
 (∷ cos (0 . 0))
 (cudefine (cos x) (rnrs:cos x))
@@ -606,6 +605,3 @@
 (definec (∘ g f) (λc x (g (f x))))
 (definec (⊙ f g) (∘ g f))
 (definec (flip f) (λc y (λc x (f x y))))
-
-(define fdiff v:fdiff)
-(define rdiff v:rdiff)
diff --git a/examples/example.scm b/examples/example.scm
new file mode 100644
index 0000000..bf694c4
--- /dev/null
+++ b/examples/example.scm
@@ -0,0 +1,176 @@
+;;;; Copyright (C) 2023 Vouivre Digital Corporation
+;;;;
+;;;; This file is part of Vouivre.
+;;;;
+;;;; Vouivre is free software: you can redistribute it and/or
+;;;; modify it under the terms of the GNU General Public
+;;;; License as published by the Free Software Foundation, either
+;;;; version 3 of the License, or (at your option) any later version.
+;;;;
+;;;; Vouivre is distributed in the hope that it will be useful,
+;;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
+;;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+;;;; General Public License for more details.
+;;;;
+;;;; You should have received a copy of the GNU General Public
+;;;; License along with Vouivre. If not, see <https://www.gnu.org/licenses/>.
+
+(use-modules
+ ((vouivre autodiff) #:prefix v:)
+ ((vouivre misc) #:prefix v:)
+ ((vouivre mnist) #:prefix v:)
+ (ice-9 receive)
+ (srfi srfi-1))
+
+;;;; utilities
+
+(define (iter f a n)
+  (if (zero? n)
+      a
+      (iter f (f a) (1- n))))
+
+(define (make-randn-array . dims)
+  (apply v:produce-typed-array
+         (lambda _ (random:normal))
+         v:*atype*
+         dims))
+
+(define (relu x)
+  (v:max 0 x))
+
+(define (logsumexp x)
+  (let ((c (v:maximum x)))
+    (v:+ c (v:log (v:sum (v:exp (v:- x c)))))))
+
+(define (random-batch batch-size array . more)
+  "Return random batches of the given size, one for each array, as multiple
+values.  Elements that share a batch index are drawn from the same row index
+of their arrays of origin.
+
+Example: (random-batch 2 #2((a b) (c d) (e f)) #(x y z)) might return
+
+  (values #2((a b) (a b)) #(x x))
+  or
+  (values #2((c d) (a b)) #(y x))
+etc."
+  (let* ((arrays (cons array more))
+         (n (car (array-dimensions array)))
+         (rs (map (lambda (src)
+                    (apply make-typed-array (array-type src)
+                           *unspecified* batch-size
+                           (cdr (array-dimensions src))))
+                  arrays)))
+    (let lp ((batch 0))
+      (if (= batch batch-size)
+          (apply values rs)
+          (let ((i (random n)))
+            (for-each
+             (lambda (dst src)
+               (array-cell-set!
+                dst
+                (array-cell-ref src i)
+                batch))
+             rs arrays)
+            (lp (1+ batch)))))))
+
+(define (argmax x)
+  (let ((i 0)
+        (mi #f)
+        (m (- (inf))))
+    (array-for-each
+     (lambda (x)
+       (when (> x m)
+         (set! m x)
+         (set! mi i))
+       (set! i (1+ i)))
+     x)
+    mi))
+
+;;;; parameters
+
+(define bs 2)    ; batch size
+(define es 10)   ; encoding size
+(define hl 16)   ; hidden layer size
+(define lr 5e-3) ; learning rate
+(define w 'TBD)  ; image width
+(define h 'TBD)  ; image height
+
+;;;; data
+
+(define data (v:load-mnist 10 #t))
+(define (cast-array x)
+  (apply
+   v:produce-typed-array
+   (lambda indices
+     (apply array-ref x indices))
+   v:*atype*
+   (array-dimensions x)))
+(define (normalize x)
+  (v:- ((v:extend /) x 255)
+       0.5))
+
+(set! h (list-ref (array-dimensions (car data))
+                  1))
+(set! w (list-ref (array-dimensions (car data))
+                  2))
+
+(define x (normalize (cast-array (car data)))) ; training samples
+(define y (cdr data))                          ; training labels
+(define a (list (make-randn-array hl h w)      ; the model parameters
+                (make-randn-array hl hl)
+                (make-randn-array es hl)))
+
+;;;; training
+
+(define (f x a)
+  "The model: given a sample and some parameters, produce an unnormalized
+log-likelihood for each label."
+  (fold (lambda (a prev)
+          (v:adot a (relu prev) 1))
+        (v:adot (car a) x 2)
+        (cdr a)))
+
+(define (L y p)
+  "The loss as a function of the ground-truth label and the model's output."
+  (v:- (logsumexp p)
+       (v:array-ref p y)))
+
+(define (ℒ x y)
+  "The average loss over multiple samples and labels as a function of the model
+parameters."
+  (lambda a
+    (v:mean
+     (v:amap2
+      (lambda (x y)
+        (L y (f x a)))
+      x y))))
+
+(define (∇L a)
+  "Mini-batch gradient of the loss with respect to the model parameters."
+  (receive (x y) (random-batch bs x y)
+    (list-tabulate (length a)
+                   (lambda (i)
+                     (apply (v:rdiff (ℒ x y) i)
+                            a)))))
+
+(define (update a)
+  "Return updated model parameters, moved a small step against the mini-batch
+gradient (the original parameters are left untouched)."
+  (map (lambda (a da)
+         (v:- a (v:* lr da)))
+       a (∇L a)))
+
+(define (report-loss a)
+  "Display the average loss over the entire dataset."
+  (display (apply (ℒ x y) a))
+  (newline))
+
+(define (categorize x)
+  (argmax (f x a)))
+
+(define (run nb-epochs steps-per-epochs)
+  (report-loss a)
+  (do ((i 0 (1+ i)))
+      ((= i nb-epochs))
+    (set! a (iter update a steps-per-epochs))
+    (report-loss a)))
-- 
2.39.5
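A minimal sketch of how the new example might be exercised from the Guile REPL, assuming the session is started from the repository root, the vouivre modules are on the load path, and v:load-mnist can locate the MNIST data; the epoch and step counts below are arbitrary:

(load "examples/example.scm")

;; Ten epochs of one hundred mini-batch gradient steps each; the average
;; loss over the whole dataset is printed before training and after each epoch.
(run 10 100)

;; Predict the label of the first training sample with the trained parameters.
(categorize (array-cell-ref x 0))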