From 15068b31bba5d0985e5e881a5683e9e853f7a6f2 Mon Sep 17 00:00:00 2001
From: admin
Date: Wed, 1 Nov 2023 19:54:19 +0900
Subject: [PATCH] Implement an automatic differentiation engine

---
 grad-tests.scm | 350 ++++++++++++++++++++++++++++++++++++++
 grad.scm       | 433 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 783 insertions(+)
 create mode 100644 grad-tests.scm
 create mode 100644 grad.scm

diff --git a/grad-tests.scm b/grad-tests.scm
new file mode 100644
index 0000000..4fb7a24
--- /dev/null
+++ b/grad-tests.scm
@@ -0,0 +1,350 @@
+(define-module (vouivre grad tests)
+  #:use-module ((vouivre grad) #:prefix v:)
+  #:use-module (ice-9 receive)
+  #:use-module (srfi srfi-1)
+  #:use-module (srfi srfi-64)
+  #:use-module (vouivre misc)
+  #:export
+  (apply-grad
+   apply-grad-amap2
+   a~
+   differentiable-func-generator
+   lambda-const-call
+   ngrad
+   n~
+   random-array
+   random-array-shape
+   random-func1
+   random-func2
+   random-input
+   random-list-element
+   random-non-empty-array
+   random-shape
+   random-shared
+   random-shared-array-rank&dims>0
+   with-generators
+   ~))
+
+(define f1s (list v:exp v:identity))
+(define f2s (list v:* v:- v:max))
+
+;; Property-test helper: draws fresh inputs from the generators 100 times and
+;; yields #t when `expected' and `given' agree under `equal' on every draw;
+;; otherwise it escapes early with #f (plus the inputs and both results).
+(define-syntax-rule (with-generators (g1 g2 ...) equal expected given)
+  (let ((times 100)
+        (fx expected)
+        (fy given)
+        (generators (list g1 g2 ...)))
+    (call/cc
+     (lambda (break)
+       (do ((i 0 (1+ i)))
+           ((= i times) #t)
+         (let ((zs (map-in-order (lambda (g) (g)) generators)))
+           (let ((r1 (apply fx zs))
+                 (r2 (apply fy zs)))
+             (unless (equal r1 r2)
+               (break #f zs r1 r2)))))))))
+
+(define (lambda-const-call f . consts)
+  "Return a procedure that ignores its arguments and applies `f' to `consts'."
+  (lambda _
+    (apply f consts)))
+
+(define* (random-array-shape
+          #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5))
+  "Random list of dimensions: rank in [min-rank, max-rank), each dimension in
+[min-dim, max-dim)."
+  (list-tabulate (+ min-rank (random (- max-rank min-rank)))
+                 (lambda _ (+ min-dim (random (- max-dim min-dim))))))
+
+(define (random-shape)
+  "Either 0 (standing for a scalar) or a random array shape."
+  (if (= 0 (random 2))
+      0
+      (random-array-shape)))
+
+(define* (random-array #:optional shape)
+  "Array of the given (or a random) shape filled via `random:uniform'."
+  (let ((a (apply make-array 0 (or shape (random-array-shape)))))
+    (array-map! a random:uniform)
+    a))
+
+(define (random-non-empty-array)
+  "Random array of at least one element."
+  (random-array (random-array-shape 0 5 1 5)))
+
+(define* (random-input #:optional shape)
+  "Random number when the shape is 0, otherwise a random array of that shape."
+  (let ((shape (or shape (random-shape))))
+    (if (eq? 0 shape)
+        (random:uniform)
+        (random-array shape))))
+
+(define (random-shared)
+  "Return two generators meant to be called in order: the first yields an input
+of the current shared shape; the second yields a scalar or an input compatible
+with that shape, then draws a new shared shape."
+  (let ((shape (random-shape)))
+    (values
+     (lambda ()
+       (random-input shape))
+     (lambda ()
+       (let ((x (random-input
+                 (random-list-element
+                  (list 0 (if (list? shape)
+                              shape
+                              (random-shape)))))))
+         (set! shape (random-shape))
+         x)))))
+
+(define (random-shared-array-rank&dims>0)
+  "Like `random-shared', but both generators yield arrays of the same shape with
+rank >= 1 and every dimension >= 1."
+  (let ((shape (random-array-shape 1 5 1 5)))
+    (values
+     (lambda ()
+       (random-array shape))
+     (lambda ()
+       (let ((x (random-array shape)))
+         (set! shape (random-array-shape 1 5 1 5))
+         x)))))
+
+(define (random-list-element lst)
+  "Uniformly pick one element of the non-empty list `lst'."
+  (list-ref lst (random (length lst))))
+
+(define (differentiable-func-generator lst . input-generators)
+  "Return a generator that picks either a function from `lst' or a constant
+function: one from `lst' applied to inputs freshly drawn from
+`input-generators'."
+  (lambda ()
+    (random-list-element
+     (cons
+      (apply
+       lambda-const-call
+       (random-list-element lst)
+       ;; Generators may share state (see `random-shared'), so they must
+       ;; run left to right; plain `map' leaves the order unspecified.
+       (map-in-order (lambda (g) (g))
+                     input-generators))
+      lst))))
+
+(define random-func1
+  (differentiable-func-generator f1s random-input))
+(define random-func2
+  (receive (gx gy) (random-shared)
+    (differentiable-func-generator f2s gx gy)))
+
+(define* (n~ x y #:optional (error 1e-4))
+  "Whether the number `y' lies within `error' of the number `x'."
+  (and
+   (>= y (- x error))
+   (<= y (+ x error))))
+
+(define* (a~ x y #:optional (error 1e-4))
+  "Whether arrays `x' and `y' share dimensions and are element-wise `~'."
+  (and
+   (equal? (array-dimensions x)
+           (array-dimensions y))
+   (call/cc
+    (lambda (break)
+      (array-for-each
+       (lambda (x y)
+         (unless (~ x y error)
+           (break #f)))
+       x y)
+      #t))))
+
+(define* (~ x y #:optional (error 1e-4))
+  "Approximate equality of two numbers or two arrays; #f for mixed types."
+  (cond
+   ((and (number? x) (number? y))
+    (n~ x y error))
+   ((and (array? x) (array? y))
+    (a~ x y error))
+   (else #f)))
+
+(define* (ngrad f #:optional (axis 0) (step 1e-6))
+  "Gradient using a numerical centered difference approximation."
+  (define (axis-add xs dh . indices)
+    "Add `dh' to the number or array at the given `axis' of `xs',
+and, when it's an array, at the given index."
+    (map-indexed
+     (lambda (x i)
+       (ifn (= i axis)
+            x
+            (if (number? x)
+                (+ x dh)
+                (array-map-indexed
+                 (lambda (x . indices_)
+                   (ifn (equal? indices indices_)
+                        x
+                        (+ x dh)))
+                 x))))
+     xs))
+  (lambda xs
+    ;; We need the output shape and the input shape along the
+    ;; differentiated axis.
+    (let ((fxs (apply f xs))
+          (x (list-ref xs axis)))
+      (cond
+       ((and (number? fxs)
+             (number? x))
+        (/ (- (apply f (axis-add xs step))
+              (apply f (axis-add xs (- step))))
+           (* 2 step)))
+       ((and (number? fxs)
+             (array? x))
+        (apply
+         produce-array
+         (lambda indices
+           (/ (- (apply f (apply axis-add xs step indices))
+                 (apply f (apply axis-add xs (- step) indices)))
+              (* 2 step)))
+         (array-dimensions x)))
+       ((and (array? fxs)
+             (number? x))
+        ((v:extend /)
+         ((v:extend -)
+          (apply f (axis-add xs step))
+          (apply f (axis-add xs (- step))))
+         (* 2 step)))
+       ((and (array? fxs)
+             (array? x))
+        (let ((a (apply
+                  make-array
+                  0
+                  (append (array-dimensions fxs)
+                          (array-dimensions x)))))
+          (for-indices-in-range
+           (lambda indices-in
+             (let ((dfxs ((v:extend /)
+                          ((v:extend -)
+                           (apply f (apply axis-add xs step indices-in))
+                           (apply f (apply axis-add xs (- step) indices-in)))
+                          (* 2 step))))
+               (for-indices-in-range
+                (lambda indices-out
+                  (apply
+                   array-set!
+                   a
+                   (apply array-ref dfxs indices-out)
+                   (append indices-out indices-in)))
+                (list-zeros (array-rank fxs))
+                (array-dimensions fxs))))
+           (list-zeros (array-rank x))
+           (array-dimensions x))
+          a))))))
+
+(test-begin "grad")
+
+;; extended operations
+(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity))
+(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp))
+(test-assert
+ (receive (gx gy) (random-shared)
+   (with-generators (gx gy) ~ (v:extend *) v:*)))
+
+(define* (apply-grad grad-func #:optional (axis 0))
+  "Adapt `grad-func' for `with-generators': return a procedure taking a
+function and its arguments, and returning its gradient along `axis'."
+  (lambda (f . args)
+    (apply (grad-func f axis) args)))
+
+(test-assert
+ (with-generators
+  (random-func1 random-input)
+  ~ (apply-grad ngrad) (apply-grad v:grad)))
+
+;; `v:mean' only takes non-empty arrays so we treat it separately
+(test-assert
+ (with-generators
+  ((differentiable-func-generator (list v:mean) random-non-empty-array)
+   random-non-empty-array)
+  ~ (apply-grad ngrad) (apply-grad v:grad)))
+
+(test-assert
+ (receive (gx gy) (random-shared)
+   (with-generators
+    (random-func2 gx gy)
+    ~ (apply-grad ngrad 0) (apply-grad v:grad 0))))
+(test-assert
+ (receive (gx gy) (random-shared)
+   (with-generators
+    (random-func2 gx gy)
+    ~ (apply-grad ngrad 1) (apply-grad v:grad 1))))
+
+;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it
+;; separately
+(define* (apply-grad-amap2 grad-func #:optional (axis 0))
+  (lambda (f x y)
+    ((grad-func
+      (lambda (x y)
+        (v:amap2 f x y))
+      axis)
+     x y)))
+(define random-func2-rank&dims>0
+  (receive (gx gy) (random-shared-array-rank&dims>0)
+    (differentiable-func-generator f2s gx gy)))
+(test-assert
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+   (with-generators
+    (random-func2-rank&dims>0 gx gy)
+    ~ (apply-grad-amap2 ngrad 0) (apply-grad-amap2 v:grad 0))))
+(test-assert
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+   (with-generators
+    (random-func2-rank&dims>0 gx gy)
+    ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1))))
+
+;; `v:adot'
+(define (random-shared-contractible)
+  "Returns three generators: the first two generate arrays that are contractible
+according to the number generated by the third one."
+  (let* ((n (random 5))
+         (sa (random-array-shape n))
+         (sb (append (reverse (take (reverse sa)
+                                    n))
+                     (random-array-shape 0 (- 5 n)))))
+    (values
+     (lambda ()
+       (random-array sa))
+     (lambda ()
+       (random-array sb))
+     (lambda ()
+       (let ((tmp n))
+         (set! n (random 5))
+         (set! sa (random-array-shape n))
+         (set! sb (append (reverse (take (reverse sa)
+                                         n))
+                          (random-array-shape 0 (- 5 n))))
+         tmp)))))
+(test-assert
+ (receive (gx gy gz) (random-shared-contractible)
+   (with-generators
+    (gx gy gz)
+    ~
+    (lambda (a b n) ((ngrad v:adot 0) a b n))
+    (lambda (a b n) ((v:grad v:adot 0) a b n)))))
+
+;; chain rule
+(test-assert
+ (with-generators
+  (random-func1 random-func1 random-input)
+  ~
+  (lambda (f g x)
+    (let* ((gx (g x))
+           (r
+            (v:dot ((v:grad f) gx)
+                   ((v:grad g) x)
+                   (v:rank-of gx))))
+      (ifn (and (number? (f (g x)))
+                (number? x)
+                (array? r))
+           r
+           (array-ref r))))
+  (lambda (f g x)
+    ((v:grad (compose f g)) x))))
+
+(test-end "grad")
diff --git a/grad.scm b/grad.scm
new file mode 100644
index 0000000..694aa55
--- /dev/null
+++ b/grad.scm
@@ -0,0 +1,433 @@
+(define-module (vouivre grad)
+  #:use-module (srfi srfi-1)
+  #:use-module (srfi srfi-9)
+  #:use-module (vouivre misc)
+  #:export
+  (adot
+   amap2
+   differentiable-wrapper
+   dot
+   extend
+   grad
+   internal-jacobian
+   mean
+   mirror
+   one
+   rank-of)
+  #:replace
+  ((i:* . *)
+   (i:- . -)
+   (i:exp . exp)
+   (i:identity . identity)
+   (i:max . max)))
+
+;;;; misc utilities
+
+(define (one . args)
+  "Ignore all arguments and return 1."
+  1)
+
+;;;; array utilities
+
+(define (contracted-dims a b n)
+  "Dimensions of the contraction of the last `n' axes of `a' with the first `n'
+axes of `b'. Signals an error when either rank is lower than `n'."
+  (let ((dims-a (array-dimensions a))
+        (dims-b (array-dimensions b)))
+    (if (or (> n (array-rank a))
+            (> n (array-rank b)))
+        (error "can't contract arrays with rank lower than" n)
+        (append (reverse (drop (reverse dims-a)
+                               n))
+                (drop dims-b n)))))
+
+(define (contract-arrays a b n)
+  "Tensor contraction: sum over the last `n' axes of `a' paired with the first
+`n' axes of `b'."
+  (let* ((dims (contracted-dims a b n))
+         (dims-b (array-dimensions b))
+         (nb-dims-a (array-rank a))
+         (nb-fix-dims-a (- nb-dims-a n))
+         (r (apply make-array 0 dims)))
+    (for-indices-in-range
+     (lambda r-indices
+       (apply
+        array-set!
+        r
+        (let ((s 0))
+          (for-indices-in-range
+           (lambda free-indices
+             (set! s (+ s (* (apply
+                              array-ref
+                              a
+                              (append (take r-indices nb-fix-dims-a)
+                                      free-indices))
+                             (apply
+                              array-ref
+                              b
+                              (append free-indices
+                                      (drop r-indices nb-fix-dims-a)))))))
+           (list-zeros n)
+           (take dims-b n))
+          s)
+        r-indices))
+     (list-zeros (length dims))
+     dims)
+    r))
+
+;;;; utilities that work on both numbers and arrays
+
+(define (extend f)
+  "Extend a function of one or more scalars to apply to numbers/arrays
+element-wise. All arrays must have the same dimension."
+  (define (apply-elemwise f indices args)
+    (apply f (map (lambda (x)
+                    (if (number? x)
+                        x
+                        (apply array-ref x indices)))
+                  args)))
+  (lambda xs
+    (if-let (x (find array? xs))
+            (apply
+             produce-array
+             (lambda is
+               (apply-elemwise f is xs))
+             (array-dimensions x))
+            (apply f xs))))
+
+(define (dot x y n)
+  "Contract arrays `x' and `y' over `n' axes; when either is a number, multiply
+element-wise instead (`n' is then ignored)."
+  (cond
+   ((and (number? x) (number? y))
+    (* x y))
+   ((and (array? x) (array? y))
+    (contract-arrays x y n))
+   ((and (array? x) (number? y))
+    ((extend *) x y))
+   ((and (number? x) (array? y))
+    ((extend *) x y))
+   (else (error "can't dot because of invalid types or ranks" x y n))))
+
+(define (mirror f x . args)
+  "Create a scalar/array of shape [x]:[x] where element with index `is:is' has
+the value of `f' evaluated at the `is' elements of `args' and all other elements
+being zero."
+  (cond
+   ((number? x)
+    (apply f args))
+   ((array? x)
+    (let ((n (array-rank x))
+          (dims (array-dimensions x)))
+      (apply
+       produce-array
+       (lambda indices
+         (if (equal? (take indices n)
+                     (drop indices n))
+             (apply f (map (lambda (arg)
+                             (apply array-ref arg (drop indices n)))
+                           args))
+             0))
+       (append dims dims))))
+   (else (error "expected array or number, got " x))))
+
+(define (rank-of x)
+  "Rank of `x', with numbers having rank 0."
+  (if (number? x)
+      0
+      (array-rank x)))
+
+(define (zeros-out:in x y)
+  "Zeros in the shape of [x]:[y]."
+  (cond
+   ((and (number? x)
+         (number? y))
+    0)
+   ((and (array? x)
+         (array? y))
+    (apply make-array 0 (append (array-dimensions x)
+                                (array-dimensions y))))
+   ((and (number? x)
+         (array? y))
+    (apply make-array 0 (array-dimensions y)))
+   ((and (array? x)
+         (number? y))
+    (apply make-array 0 (array-dimensions x)))
+   (else (error "undefined." x y))))
+
+;;;; differentiation
+
+;; A value flowing through a differentiated computation: `forward' is the plain
+;; value and `jacobian' its derivative with respect to the `grad' input.
+(define-record-type internal
+  (make-internal jacobian forward)
+  internal?
+  (jacobian internal-jacobian set-internal-jacobian!)
+  (forward internal-forward))
+
+;; Bound by `grad' to the wrapped input being differentiated against; #f
+;; outside of any differentiation.
+(define *grad* (make-parameter #f))
+
+(define* (grad f #:optional (axis 0))
+  "Return a procedure computing the jacobian of `f' with respect to its
+argument number `axis'. When called inside another `grad', the result is left
+wrapped in an internal object."
+  (define (wrap x i)
+    (if (= i axis)
+        (make-internal (mirror one x) x)
+        x))
+  (lambda xs
+    ((if (*grad*) identity internal-jacobian)
+     (let ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity))))
+       (parameterize (((@@ (vouivre grad) *grad*) (list-ref wrapped-xs axis)))
+         (apply f wrapped-xs))))))
+
+(define (grad-input)
+  "Naked value of the input currently being differentiated against."
+  (internal-forward (*grad*)))
+
+(define (differentiable-wrapper jacobian-generator function input . more)
+  ;; NOTE: Both the jacobian generator and the function act on naked inputs
+  ;; (numbers or arrays not inside an internal object). The generator
+  ;; returns a list of jacobians, each one expressing the change of the
+  ;; function's output when changing the corresponding input. In cases
+  ;; where an argument isn't meant to be differentiable its corresponding
+  ;; element in the list should be `#f'.
+  (let* ((inputs (cons input more))
+         (naked-inputs
+          (map
+           (lambda (x)
+             (if (internal? x)
+                 (internal-forward x)
+                 x))
+           inputs))
+         (output (apply function naked-inputs)))
+    (if (*grad*)
+        (make-internal
+         (or
+          (fold
+           (lambda (Jx x prev)
+             (if (not (internal? x))
+                 prev
+                 ((if (not prev)
+                      identity
+                      (lambda (x) ((extend +) prev x)))
+                  (dot Jx
+                       (internal-jacobian x)
+                       (rank-of (internal-forward x))))))
+           #f
+           (apply jacobian-generator naked-inputs)
+           inputs)
+          (zeros-out:in output (grad-input)))
+         output)
+        output)))
+
+(define (i:identity x)
+  "Differentiable identity."
+  (differentiable-wrapper
+   (lambda (x) (list (mirror one x)))
+   identity
+   x))
+
+(define (i:exp x)
+  "Differentiable exponential."
+  (differentiable-wrapper
+   (lambda (x) (list (mirror exp x x)))
+   (extend exp)
+   x))
+
+(define (i:* x y)
+  "Differentiable element-wise multiplication."
+  (differentiable-wrapper
+   (lambda (x y)
+     (list
+      (cond
+       ((and (number? y) (number? x))
+        y)
+       ((and (number? y) (array? x))
+        (mirror (lambda _ y) x))
+       ((and (array? y) (number? x))
+        y)
+       (else
+        (mirror identity x y)))
+      (cond
+       ((and (number? x) (number? y))
+        x)
+       ((and (number? x) (array? y))
+        (mirror (lambda _ x) y))
+       ((and (array? x) (number? y))
+        x)
+       (else
+        (mirror identity y x)))))
+   (extend *)
+   x y))
+
+(define (i:- x y)
+  "Differentiable element-wise subtraction."
+  (differentiable-wrapper
+   (lambda (x y)
+     (list
+      (cond
+       ((and (number? y) (number? x))
+        +1)
+       ((and (number? y) (array? x))
+        (mirror one x))
+       ((and (array? y) (number? x))
+        (apply make-array 1 (array-dimensions y)))
+       (else
+        (mirror one x)))
+      (cond
+       ((and (number? x) (number? y))
+        -1)
+       ((and (number? x) (array? y))
+        (mirror (lambda _ -1) y))
+       ((and (array? x) (number? y))
+        (apply make-array -1 (array-dimensions x)))
+       (else
+        (mirror (lambda _ -1) y)))))
+   (extend -)
+   x y))
+
+(define (i:max x y)
+  "Differentiable element-wise maximum."
+  (define (dmax x y)
+    "Derivative of (max x y) with respect to `x'; ties count for 1/2."
+    (cond
+     ((> x y)
+      1)
+     ((= x y)
+      1/2)
+     (else
+      0)))
+  (differentiable-wrapper
+   (lambda (x y)
+     (list
+      (cond
+       ((and (number? y) (number? x))
+        (dmax x y))
+       ((and (number? y) (array? x))
+        (mirror (lambda (x) (dmax x y)) x x))
+       ((and (array? y) (number? x))
+        (array-map (lambda (y) (dmax x y)) y))
+       (else
+        (mirror (lambda (x y) (dmax x y)) x x y)))
+      (cond
+       ((and (number? x) (number? y))
+        (dmax y x))
+       ((and (number? x) (array? y))
+        (mirror (lambda (y) (dmax y x)) y y))
+       ((and (array? x) (number? y))
+        (array-map (lambda (x) (dmax y x)) x))
+       (else
+        (mirror (lambda (y x) (dmax y x)) y y x)))))
+   (extend max)
+   x y))
+
+(define (mean x)
+  "Differentiable mean on arrays."
+  (differentiable-wrapper
+   (lambda (x)
+     (let ((dims (array-dimensions x)))
+       (list
+        (apply
+         make-array
+         (/ 1 (apply * dims))
+         dims))))
+   (lambda (x)
+     (let ((sum 0)
+           (count 0))
+       (array-for-each
+        (lambda (x)
+          (set! sum (+ sum x))
+          (set! count (1+ count)))
+        x)
+       (/ sum count)))
+   x))
+
+(define (amap2 f x y)
+  "Map the two-argument `f' over the cells along the first axis of `x' and `y'
+and stack the results; works on naked arrays and on internal objects. The
+first axis must be non-empty."
+  (define (unbox-with proc x)
+    (ifn (internal? x)
+         x
+         (proc x)))
+  (define (boxed-ref x i)
+    (ifn (internal? x)
+         (array-cell-ref x i)
+         (make-internal (array-cell-ref (internal-jacobian x)
+                                        i)
+                        (array-cell-ref (internal-forward x)
+                                        i))))
+  (define (boxed-fi f i x y)
+    (f (boxed-ref x i)
+       (boxed-ref y i)))
+  (define (dims-of x)
+    (if (number? x)
+        '()
+        (array-dimensions x)))
+  (if (internal? f)
+      (error "unsupported application of `amap2' to internal object" f)
+      (let ((bs (first (array-dimensions (unbox-with internal-forward x))))
+            (f0 (boxed-fi f 0 x y)))
+        (ifn (internal? f0)
+             (let ((fwd (apply make-array 0 bs (dims-of f0))))
+               (for-indices-in-range
+                (lambda (i)
+                  (array-cell-set! fwd (boxed-fi f i x y) i))
+                '(0) (list bs))
+               fwd)
+             (let ((jac (apply make-array 0 bs (dims-of (internal-jacobian f0))))
+                   (fwd (apply make-array 0 bs (dims-of (internal-forward f0)))))
+               (for-indices-in-range
+                (lambda (i)
+                  (let ((fi (boxed-fi f i x y)))
+                    (array-cell-set! jac (internal-jacobian fi) i)
+                    (array-cell-set! fwd (internal-forward fi) i)))
+                '(0) (list bs))
+               (make-internal jac fwd))))))
+
+(define (adot x y n)
+  "Differentiable array contraction (see `contract-arrays'); `n', the number
+of contracted axes, is not differentiable."
+  (differentiable-wrapper
+   (lambda (x y n)
+     (let ((bound-dims (contracted-dims x y n)))
+       (list
+        (let ((r (apply make-array 0 (append bound-dims (array-dimensions x)))))
+          (for-indices-in-range
+           (lambda free-indices
+             (let ((free-indices-x (take free-indices (- (array-rank x) n)))
+                   (free-indices-y (drop free-indices (- (array-rank x) n))))
+               (for-indices-in-range
+                (lambda bound-indices
+                  (apply
+                   array-set!
+                   r
+                   (apply array-ref y (append bound-indices free-indices-y))
+                   (append free-indices free-indices-x bound-indices)))
+                (list-zeros n)
+                (take (array-dimensions y) n))))
+           (list-zeros (length bound-dims))
+           bound-dims)
+          r)
+        (let ((r (apply make-array 0 (append bound-dims (array-dimensions y)))))
+          (for-indices-in-range
+           (lambda free-indices
+             (let ((free-indices-x (take free-indices (- (array-rank x) n)))
+                   (free-indices-y (drop free-indices (- (array-rank x) n))))
+               (for-indices-in-range
+                (lambda bound-indices
+                  (apply
+                   array-set!
+                   r
+                   (apply array-ref x (append free-indices-x bound-indices))
+                   (append free-indices bound-indices free-indices-y)))
+                (list-zeros n)
+                (take (array-dimensions y) n))))
+           (list-zeros (length bound-dims))
+           bound-dims)
+          r)
+        #f)))
+   contract-arrays x y n))
-- 
2.39.5