]> git.vouivredigital.com Git - vouivre.git/commitdiff
Implement an automatic differentiation engine
authoradmin <admin@vouivredigital.com>
Wed, 1 Nov 2023 10:54:19 +0000 (19:54 +0900)
committeradmin <admin@vouivredigital.com>
Wed, 1 Nov 2023 10:54:19 +0000 (19:54 +0900)
grad-tests.scm [new file with mode: 0644]
grad.scm [new file with mode: 0644]

diff --git a/grad-tests.scm b/grad-tests.scm
new file mode 100644 (file)
index 0000000..4fb7a24
--- /dev/null
@@ -0,0 +1,325 @@
+(define-module (vouivre grad tests)
+  #:use-module ((vouivre grad) #:prefix v:)
+  #:use-module (ice-9 receive)
+  #:use-module (srfi srfi-1)
+  #:use-module (srfi srfi-64)
+  #:use-module (vouivre misc)
+  #:export
+  (apply-grad
+   apply-grad-amap2
+   a~
+   differentiable-func-generator
+   lambda-const-call
+   ngrad
+   n~
+   random-array
+   random-array-shape
+   random-func1
+   random-func2
+   random-input
+   random-list-element
+   random-non-empty-array
+   random-shape
+   random-shared
+   random-shared-array-rank&dims>0
+   with-generators
+   ~))
+
+(define f1s (list v:exp v:identity))
+(define f2s (list v:* v:- v:max))
+
+(define-syntax-rule (with-generators (g1 g2 ...) equal expected given)
+  (let ((times 100)
+       (fx expected)
+       (fy given)
+       (generators (list g1 g2 ...)))
+    (call/cc
+     (lambda (break)
+       (do ((i 0 (1+ i)))
+          ((= i times) #t)
+        (let ((zs (map-in-order (lambda (g) (g)) generators)))
+          (let ((r1 (apply fx zs))
+                (r2 (apply fy zs)))
+            (unless (equal r1 r2)
+              (break #f zs r1 r2)))))))))
+
+(define (lambda-const-call f . consts)
+  (lambda _
+    (apply f consts)))
+
+(define* (random-array-shape
+         #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5))
+  (list-tabulate (+ min-rank (random (- max-rank min-rank)))
+                (lambda _ (+ min-dim (random (- max-dim min-dim))))))
+
+(define (random-shape)
+  (if (= 0 (random 2))
+      0
+      (random-array-shape)))
+
+(define* (random-array #:optional shape)
+  (let ((a (apply make-array 0 (or shape (random-array-shape)))))
+    (array-map! a random:uniform)
+    a))
+
+(define (random-non-empty-array)
+  "Random array of at least one element."
+  (random-array (random-array-shape 0 5 1 5)))
+
+(define* (random-input #:optional shape)
+  (let ((shape (or shape (random-shape))))
+    (if (eq? 0 shape)
+       (random:uniform)
+       (random-array shape))))
+
+(define (random-shared)
+  (let ((shape (random-shape)))
+    (values
+     (lambda ()
+       (random-input shape))
+     (lambda ()
+       (let ((x (random-input
+                (random-list-element
+                 (list 0 (if (list? shape)
+                             shape
+                             (random-shape)))))))
+        (set! shape (random-shape))
+        x)))))
+
+(define (random-shared-array-rank&dims>0)
+  (let ((shape (random-array-shape 1 5 1 5)))
+    (values
+     (lambda ()
+       (random-array shape))
+     (lambda ()
+       (let ((x (random-array shape)))
+        (set! shape (random-array-shape 1 5 1 5))
+        x)))))
+
+(define (random-list-element lst)
+  (list-ref lst (random (length lst))))
+
+(define (differentiable-func-generator lst . input-generators)
+  (lambda ()
+    (random-list-element
+     (cons
+      (apply
+       lambda-const-call
+       (random-list-element lst)
+       (map (lambda (g) (g))
+           input-generators))
+      lst))))
+
+(define random-func1
+  (differentiable-func-generator f1s random-input))
+(define random-func2
+  (receive (gx gy) (random-shared)
+    (differentiable-func-generator f2s gx gy)))
+
+(define* (n~ x y #:optional (error 1e-4))
+  (and
+   (>= y (- x error))
+   (<= y (+ x error))))
+
+(define* (a~ x y #:optional (error 1e-4))
+ (and
+  (equal? (array-dimensions x)
+         (array-dimensions y))
+  (call/cc
+   (lambda (break)
+     (array-for-each
+      (lambda (x y)
+       (unless (~ x y error)
+         (break #f)))
+      x y)
+     #t))))
+
+(define* (~ x y #:optional (error 1e-4))
+  (cond
+   ((and (number? x) (number? y))
+    (n~ x y error))
+   ((and (array? x) (array? y))
+    (a~ x y error))
+   (else #f)))
+
+(define* (ngrad f #:optional (axis 0) (step 1e-6))
+  "Gradient using a numerical centered difference approximation."
+  (define (axis-add xs dh . indices)
+    "Add `dh' to the number or array at the given `axis' of `xs',
+and, when it's an array, at the given index."
+    (map-indexed
+     (lambda (x i)
+       (ifn (= i axis)
+           x
+           (if (number? x)
+               (+ x dh)
+               (array-map-indexed
+                (lambda (x . indices_)
+                  (ifn (equal? indices indices_)
+                       x
+                       (+ x dh)))
+                x))))
+     xs))
+  (lambda xs
+    ;; We need the output shape and the input shape along the
+    ;; differentiated axis.
+    (let ((fxs (apply f xs))
+         (x (list-ref xs axis)))
+      (cond
+       ((and (number? fxs)
+            (number? x))
+       (/ (- (apply f (axis-add xs step))
+             (apply f (axis-add xs (- step))))
+          (* 2 step)))
+       ((and (number? fxs)
+            (array? x))
+       (apply
+        produce-array
+        (lambda indices
+          (/ (- (apply f (apply axis-add xs step indices))
+                (apply f (apply axis-add xs (- step) indices)))
+             (* 2 step)))
+        (array-dimensions x)))
+       ((and (array? fxs)
+            (number? x))
+       ((v:extend /)
+        ((v:extend -)
+         (apply f (axis-add xs step))
+         (apply f (axis-add xs (- step))))
+        (* 2 step)))
+       ((and (array? fxs)
+            (array? x))
+       (let ((a (apply
+                 make-array
+                 0
+                 (append (array-dimensions fxs)
+                         (array-dimensions x)))))
+         (for-indices-in-range
+          (lambda indices-in
+            (let ((dfxs ((v:extend /)
+                         ((v:extend -)
+                          (apply f (apply axis-add xs step indices-in))
+                          (apply f (apply axis-add xs (- step) indices-in)))
+                         (* 2 step))))
+              (for-indices-in-range
+               (lambda indices-out
+                 (apply
+                  array-set!
+                  a
+                  (apply array-ref dfxs indices-out)
+                  (append indices-out indices-in)))
+               (list-zeros (array-rank fxs))
+               (array-dimensions fxs))))
+          (list-zeros (array-rank x))
+          (array-dimensions x))
+         a))))))
+
+(test-begin "grad")
+
+;; extended operations
+(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity))
+(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp))
+(test-assert
+    (receive (gx gy) (random-shared)
+      (with-generators (gx gy) ~ (v:extend *) v:*)))
+
+(define* (apply-grad grad-func #:optional (axis 0))
+  (lambda (f . args)
+    (apply (grad-func f axis) args)))
+
+(test-assert
+    (with-generators
+     (random-func1 random-input)
+     ~ (apply-grad ngrad) (apply-grad v:grad)))
+
+;; `v:mean' only takes non-empty arrays so we treat it separately
+(test-assert
+    (with-generators
+     ((differentiable-func-generator (list v:mean) random-non-empty-array)
+      random-non-empty-array)
+     ~ (apply-grad ngrad) (apply-grad v:grad)))
+
+(test-assert
+    (receive (gx gy) (random-shared)
+      (with-generators
+       (random-func2 gx gy)
+       ~ (apply-grad ngrad 0) (apply-grad v:grad 0))))
+(test-assert
+    (receive (gx gy) (random-shared)
+      (with-generators
+       (random-func2 gx gy)
+       ~ (apply-grad ngrad 1) (apply-grad v:grad 1))))
+
+;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it
+;; separately
+(define* (apply-grad-amap2 grad-func #:optional (axis 0))
+  (lambda (f x y)
+    ((grad-func
+      (lambda (x y)
+       (v:amap2 f x y))
+      axis)
+     x y)))
+(define random-func2-rank&dims>0
+  (receive (gx gy) (random-shared-array-rank&dims>0)
+    (differentiable-func-generator f2s gx gy)))
+(test-assert
+    (receive (gx gy) (random-shared-array-rank&dims>0)
+      (with-generators
+       (random-func2-rank&dims>0 gx gy)
+       ~ (apply-grad-amap2 ngrad 0) (apply-grad-amap2 v:grad 0))))
+(test-assert
+    (receive (gx gy) (random-shared-array-rank&dims>0)
+      (with-generators
+       (random-func2-rank&dims>0 gx gy)
+       ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1))))
+
+;; `v:adot'
+(define (random-shared-contractible)
+  "Returns three generators: the first two generate arrays that are contractible
+according to the number generated by the third one."
+  (let* ((n (random 5))
+        (sa (random-array-shape n))
+        (sb (append (reverse (take (reverse sa)
+                                   n))
+                    (random-array-shape 0 (- 5 n)))))
+    (values
+     (lambda ()
+       (random-array sa))
+     (lambda ()
+       (random-array sb))
+     (lambda ()
+       (let ((tmp n))
+        (set! n (random 5))
+        (set! sa (random-array-shape n))
+        (set! sb (append (reverse (take (reverse sa)
+                                         n))
+                         (random-array-shape 0 (- 5 n))))
+        tmp)))))
+(test-assert
+    (receive (gx gy gz) (random-shared-contractible)
+      (with-generators
+       (gx gy gz)
+       ~
+       (lambda (a b n) ((ngrad v:adot 0) a b n))
+       (lambda (a b n) ((v:grad v:adot 0) a b n)))))
+
+;; chain rule
+(test-assert
+    (with-generators
+     (random-func1 random-func1 random-input)
+     ~
+     (lambda (f g x)
+       (let* ((gx (g x))
+             (r
+              (v:dot ((v:grad f) gx)
+                     ((v:grad g) x)
+                     (v:rank-of gx))))
+        (ifn (and (number? (f (g x)))
+                  (number? x)
+                  (array? r))
+             r
+             (array-ref r))))
+     (lambda (f g x)
+       ((v:grad (compose f g)) x))))
+
+(test-end "grad")
diff --git a/grad.scm b/grad.scm
new file mode 100644 (file)
index 0000000..694aa55
--- /dev/null
+++ b/grad.scm
@@ -0,0 +1,413 @@
+(define-module (vouivre grad)
+  #:use-module (srfi srfi-1)
+  #:use-module (srfi srfi-9)
+  #:use-module (vouivre misc)
+  #:export
+  (adot
+   amap2
+   differentiable-wrapper
+   dot
+   extend
+   grad
+   internal-jacobian
+   mean
+   mirror
+   one
+   rank-of)
+  #:replace
+  ((i:* . *)
+   (i:- . -)
+   (i:exp . exp)
+   (i:fold . fold)
+   (i:identity . identity)
+   (i:max . max)))
+
+;;;; misc utilities
+
+(define (one . args)
+  1)
+
+;;;; array utilities
+
+(define (contracted-dims a b n)
+  (let ((dims-a (array-dimensions a))
+        (dims-b (array-dimensions b)))
+    (if (or (> n (array-rank a))
+            (> n (array-rank b)))
+      (error "can't contract arrays with size lower than" n)
+      (append (reverse (drop (reverse dims-a)
+                             n))
+              (drop dims-b n)))))
+
+(define (contract-arrays a b n)
+  (let* ((dims (contracted-dims a b n))
+         (dims-b (array-dimensions b))
+         (nb-dims-a (array-rank a))
+         (nb-dims-b (array-rank b))
+         (nb-fix-dims-a (- nb-dims-a n))
+         (r (apply make-array 0 dims)))
+    (for-indices-in-range
+      (lambda r-indices
+        (apply
+          array-set!
+          r
+          (let ((s 0))
+            (for-indices-in-range
+              (lambda free-indices
+                (set! s (+ s (* (apply
+                                  array-ref
+                                  a
+                                  (append (take r-indices nb-fix-dims-a)
+                                          free-indices))
+                                (apply
+                                  array-ref
+                                  b
+                                  (append free-indices
+                                          (drop r-indices nb-fix-dims-a)))))))
+              (list-zeros n)
+              (take dims-b n))
+            s)
+          r-indices))
+      (list-zeros (length dims))
+      dims)
+    r))
+
+;;;; utilities that work on both numbers and arrays
+
+(define (extend f)
+  "Extend a function of one or more scalars to apply to numbers/arrays
+element-wise. All arrays must have the same dimension."
+  (define (apply-elemwise f indices args)
+    (apply f (map (lambda (x)
+                   (if (number? x)
+                       x
+                       (apply array-ref x indices)))
+                 args)))
+  (lambda xs
+    (if-let (x (find array? xs))
+           (apply
+            produce-array
+            (lambda is
+              (apply-elemwise f is xs))
+            (array-dimensions x))
+           (apply f xs))))
+
+(define (dot x y n)
+  (cond
+   ((and (number? x) (number? y))
+    (* x y))
+   ((and (array? x) (array? y))
+    (contract-arrays x y n))
+   ((and (array? x) (number? y))
+    ((extend *) x y))
+   ((and (number? x) (array? y))
+    ((extend *) x y))
+   (else (error "can't dot because of invalid types or ranks" x y n))))
+
+(define (mirror f x . args)
+  "Create a scalar/array of shape [x]:[x] where element with index `is:is' has
+the value of `f' evaluated at the `is' elements of `args' and all other elements
+being zero."
+  (cond
+   ((number? x)
+    (apply f args))
+   ((array? x)
+    (let ((n (array-rank x))
+         (dims (array-dimensions x)))
+      (apply
+       produce-array
+       (lambda indices
+        (if (equal? (take indices n)
+                    (drop indices n))
+            (apply f (map (lambda (arg)
+                            (apply array-ref arg (drop indices n)))
+                          args))
+            0))
+       (append dims dims))))
+   (else (error "expected array or number, got " x))))
+
+(define (rank-of x)
+  (if (number? x)
+      0
+      (array-rank x)))
+
+(define (zeros-out:in x y)
+  "Zeros in the shape of [x]:[y]."
+  (cond
+   ((and (number? x)
+        (number? y))
+    0)
+   ((and (array? x)
+        (array? y))
+    (apply make-array 0 (append (array-dimensions x)
+                               (array-dimensions y))))
+   ((and (number? x)
+        (array? y))
+    (apply make-array 0 (array-dimensions y)))
+   ((and (array? x)
+        (number? y))
+    (apply make-array 0 (array-dimensions x)))
+   (else (error "undefined." x y))))
+
+;;;; differentiation
+
+(define-record-type internal
+  (make-internal jacobian forward)
+  internal?
+  (jacobian internal-jacobian set-internal-jacobian!)
+  (forward internal-forward))
+
+(define *grad* (make-parameter #f))
+
+(define* (grad f #:optional (axis 0))
+  (define (wrap x i)
+    (if (= i axis)
+       (make-internal (mirror one x) x)
+       x))
+  (lambda xs
+    ((if (*grad*) identity internal-jacobian)
+     (let ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity))))
+       (parameterize (((@@ (vouivre grad) *grad*) (list-ref wrapped-xs axis)))
+        (apply f wrapped-xs))))))
+
+(define (grad-input)
+  (internal-forward (*grad*)))
+
+(define (differentiable-wrapper jacobian-generator function input . more)
+  ;; NOTE: Both the jacobian generator and the function act on naked inputs
+  ;;       (numbers or arrays not inside an internal object). The generator
+  ;;       returns a list of jacobians, each one expressing the change of the
+  ;;       function's output when changing the corresponding input. In cases
+  ;;       where an argument isn't meant to be differentiable its corresponding
+  ;;       element in the list should be `#f'.
+  (let* ((inputs (cons input more))
+        (naked-inputs
+         (map
+          (lambda (x)
+            (if (internal? x)
+                (internal-forward x)
+                x))
+          inputs))
+        (output (apply function naked-inputs)))
+    (if (*grad*)
+       (make-internal
+        (or
+         (fold
+          (lambda (Jx x prev)
+            (if (not (internal? x))
+                prev
+                ((if (not prev)
+                     identity
+                     (lambda (x) ((extend +) prev x)))
+                 (dot Jx
+                      (internal-jacobian x)
+                      (rank-of (internal-forward x))))))
+          #f
+          (apply jacobian-generator naked-inputs)
+          inputs)
+         (zeros-out:in output (grad-input)))
+        output)
+       output)))
+
+(define (i:identity x)
+  "Differentiable identity."
+  (differentiable-wrapper
+   (lambda (x) (list (mirror one x)))
+   identity
+   x))
+
+(define (i:exp x)
+  "Differentiable exponential."
+  (differentiable-wrapper
+   (lambda (x) (list (mirror exp x x)))
+   (extend exp)
+   x))
+
+(define (i:* x y)
+  "Differentiable element-wise multiplication."
+  (differentiable-wrapper
+   (lambda (x y)
+     (list
+      (cond
+       ((and (number? y) (number? x))
+       y)
+       ((and (number? y) (array? x))
+       (mirror (lambda _ y) x))
+       ((and (array? y) (number? x))
+       y)
+       (else
+       (mirror identity x y)))
+      (cond
+       ((and (number? x) (number? y))
+       x)
+       ((and (number? x) (array? y))
+       (mirror (lambda _ x) y))
+       ((and (array? x) (number? y))
+       x)
+       (else
+       (mirror identity y x)))))
+   (extend *)
+   x y))
+
+(define (i:- x y)
+  "Differentiable element-wise subtraction."
+  (differentiable-wrapper
+   (lambda (x y)
+     (list
+      (cond
+       ((and (number? y) (number? x))
+       +1)
+       ((and (number? y) (array? x))
+       (mirror one x))
+       ((and (array? y) (number? x))
+       (apply make-array 1 (array-dimensions y)))
+       (else
+       (mirror one x)))
+      (cond
+       ((and (number? x) (number? y))
+       -1)
+       ((and (number? x) (array? y))
+       (mirror (lambda _ -1) y))
+       ((and (array? x) (number? y))
+       (apply make-array -1 (array-dimensions x)))
+       (else
+       (mirror (lambda _ -1) y)))))
+   (extend -)
+   x y))
+
+(define (i:max x y)
+  "Differentiable element-wise maximum."
+  (define (dmax x y)
+    (cond
+     ((> x y)
+      1)
+     ((= x y)
+      1/2)
+     (else
+      0)))
+  (differentiable-wrapper
+   (lambda (x y)
+     (list
+      (cond
+       ((and (number? y) (number? x))
+       (dmax x y))
+       ((and (number? y) (array? x))
+       (mirror (lambda (x) (dmax x y)) x x))
+       ((and (array? y) (number? x))
+       (array-map (lambda (y) (dmax x y)) y))
+       (else
+       (mirror (lambda (x y) (dmax x y)) x x y)))
+      (cond
+       ((and (number? x) (number? y))
+       (dmax y x))
+       ((and (number? x) (array? y))
+       (mirror (lambda (y) (dmax y x)) y y))
+       ((and (array? x) (number? y))
+       (array-map (lambda (x) (dmax y x)) x))
+       (else
+       (mirror (lambda (y x) (dmax y x)) y y x)))))
+   (extend max)
+   x y))
+
+(define (mean x)
+  "Differentiable mean on arrays."
+  (differentiable-wrapper
+   (lambda (x)
+     (let ((dims (array-dimensions x)))
+       (list
+       (apply
+        make-array
+        (/ 1 (apply * dims))
+        dims))))
+   (lambda (x)
+     (let ((sum 0)
+          (count 0))
+       (array-for-each
+       (lambda (x)
+         (set! sum (+ sum x))
+         (set! count (1+ count)))
+       x)
+       (/ sum count)))
+   x))
+
+(define (amap2 f x y)
+  (define (unbox-with proc x)
+    (ifn (internal? x)
+        x
+        (proc x)))
+  (define (boxed-ref x i)
+    (ifn (internal? x)
+        (array-cell-ref x i)
+        (make-internal (array-cell-ref (internal-jacobian x)
+                                       i)
+                       (array-cell-ref (internal-forward x)
+                                       i))))
+  (define (boxed-fi f i x y)
+    (f (boxed-ref x i)
+       (boxed-ref y i)))
+  (define (dims-of x)
+    (if (number? x)
+       '()
+       (array-dimensions x)))
+  (if (internal? f)
+      (error "unsupported application of `amap2' to internal object" f)
+      (let ((bs (first (array-dimensions (unbox-with internal-forward x))))
+           (f0 (boxed-fi f 0 x y)))
+       (ifn (internal? f0)
+            (let ((fwd (apply make-array 0 bs (dims-of f0))))
+              (for-indices-in-range
+               (lambda (i)
+                 (array-cell-set! fwd (boxed-fi f i x y) i))
+               '(0) (list bs))
+              fwd)
+            (let ((jac (apply make-array 0 bs (dims-of (internal-jacobian f0))))
+                  (fwd (apply make-array 0 bs (dims-of (internal-forward f0)))))
+              (for-indices-in-range
+               (lambda (i)
+                 (let ((fi (boxed-fi f i x y)))
+                   (array-cell-set! jac (internal-jacobian fi) i)
+                   (array-cell-set! fwd (internal-forward fi) i)))
+               '(0) (list bs))
+              (make-internal jac fwd))))))
+
+(define (adot x y n)
+  (differentiable-wrapper
+   (lambda (x y n)
+     (let ((bound-dims (contracted-dims x y n)))
+       (list
+       (let ((r (apply make-array 0 (append bound-dims (array-dimensions x)))))
+         (for-indices-in-range
+          (lambda free-indices
+            (let ((free-indices-x (take free-indices (- (array-rank x) n)))
+                  (free-indices-y (drop free-indices (- (array-rank x) n))))
+              (for-indices-in-range
+               (lambda bound-indices
+                 (apply
+                  array-set!
+                  r
+                  (apply array-ref y (append bound-indices free-indices-y))
+                  (append free-indices free-indices-x bound-indices)))
+               (list-zeros n)
+               (take (array-dimensions y) n))))
+          (list-zeros (length bound-dims))
+          bound-dims)
+         r)
+       (let ((r (apply make-array 0 (append bound-dims (array-dimensions y)))))
+         (for-indices-in-range
+          (lambda free-indices
+            (let ((free-indices-x (take free-indices (- (array-rank x) n)))
+                  (free-indices-y (drop free-indices (- (array-rank x) n))))
+              (for-indices-in-range
+               (lambda bound-indices
+                 (apply
+                  array-set!
+                  r
+                  (apply array-ref x (append free-indices-x bound-indices))
+                  (append free-indices bound-indices free-indices-y)))
+               (list-zeros n)
+               (take (array-dimensions y) n))))
+          (list-zeros (length bound-dims))
+          bound-dims)
+         r)
+       #f)))
+   contract-arrays x y n))