From ec5567a74aa49b25499cf1fd908d2ce51409aa28 Mon Sep 17 00:00:00 2001 From: admin Date: Thu, 23 Nov 2023 16:17:16 +0900 Subject: [PATCH] Implement reverse mode automatic differentiation --- grad-tests.scm => autodiff-tests.scm | 162 +++++++++--------- grad.scm => autodiff.scm | 244 ++++++++++++++++++++++----- 2 files changed, 285 insertions(+), 121 deletions(-) rename grad-tests.scm => autodiff-tests.scm (69%) rename grad.scm => autodiff.scm (72%) diff --git a/grad-tests.scm b/autodiff-tests.scm similarity index 69% rename from grad-tests.scm rename to autodiff-tests.scm index bdf8fd4..f1fc168 100644 --- a/grad-tests.scm +++ b/autodiff-tests.scm @@ -1,16 +1,16 @@ -(define-module (vouivre grad tests) - #:use-module ((vouivre grad) #:prefix v:) +(define-module (vouivre autodiff tests) + #:use-module ((vouivre autodiff) #:prefix v:) #:use-module (ice-9 receive) #:use-module (srfi srfi-1) #:use-module (srfi srfi-64) #:use-module (vouivre misc) #:export - (apply-grad - apply-grad-amap2 + (apply-diff a~ + const-generator differentiable-func-generator lambda-const-call - ngrad + ndiff n~ random-array random-array-shape @@ -27,14 +27,15 @@ with-generators ~)) -(define f1s (list v:exp v:identity)) +(define f1s (list v:abs v:cos v:exp v:identity v:sin)) (define f2s (list v:+ v:- v:* v:max v:min)) -(define-syntax-rule (with-generators (g1 g2 ...) equal expected given) +(define (with-generators% generators equal proc1 proc2 . more) + "Check that all procedures return the same value according to `equal' when +evaluated on arguments produced by the generators (the number of generators +being the number of arguments to each procedure." (let ((times 100) - (fx expected) - (fy given) - (generators (list g1 g2 ...))) + (procs (cons proc1 (cons proc2 more)))) (call/cc (lambda (break) (do ((i 0 (1+ i))) @@ -44,12 +45,16 @@ (lambda (e) (break #f zs)) (lambda () - (let ((r1 (apply fx zs)) - (r2 (apply fy zs))) - (unless (equal r1 r2) - (break #f zs r1 r2)))) + (let* ((rs (map (lambda (f) (apply f zs)) procs)) + (head (car rs))) + (unless (every (lambda (x) (equal x head)) + (cdr rs)) + (break #f zs rs)))) #:unwind? #t))))))) +(define-syntax-rule (with-generators (g1 g2 ...) equal expected given more ...) + (with-generators% (list g1 g2 ...) equal expected given more ...)) + (define (lambda-const-call f . consts) (lambda _ (apply f consts))) @@ -106,6 +111,10 @@ (define (random-list-element lst) (list-ref lst (random (length lst)))) +(define (const-generator generator) + (lambda () + generator)) + (define (differentiable-func-generator lst . input-generators) (lambda () (random-list-element @@ -129,17 +138,17 @@ (<= y (+ x error)))) (define* (a~ x y #:optional (error 1e-4)) - (and - (equal? (array-dimensions x) - (array-dimensions y)) - (call/cc - (lambda (break) - (array-for-each - (lambda (x y) - (unless (~ x y error) - (break #f))) - x y) - #t)))) + (and + (equal? (array-dimensions x) + (array-dimensions y)) + (call/cc + (lambda (break) + (array-for-each + (lambda (x y) + (unless (~ x y error) + (break #f))) + x y) + #t)))) (define* (~ x y #:optional (error 1e-4)) (cond @@ -149,8 +158,8 @@ (a~ x y error)) (else #f))) -(define* (ngrad f #:optional (axis 0) (step 1e-6)) - "Gradient using a numerical centered difference approximation." +(define* (ndiff f #:optional (axis 0) (step 1e-6)) + "Differentiation using a numerical centered difference approximation." (define (axis-add xs dh . indices) "Add `dh' to the number or array at the given `axis' of `xs', and, when it's an array, at the given index." @@ -221,72 +230,71 @@ and, when it's an array, at the given index." (array-dimensions x)) a)))))) -(test-begin "grad") +(define* (apply-diff differentiator #:optional (axis 0)) + "Apply a differentiator (`ndiff', `fdiff', `rdiff') to a function and its +arguments (this is a convenience function)." + (lambda (f . args) + (apply (differentiator f axis) args))) + +(test-begin "autodiff") -;; extended operations +;; not differentiating (test-assert (with-generators (random-input) ~ (v:extend identity) v:identity)) (test-assert (with-generators (random-input) ~ (v:extend exp) v:exp)) (test-assert (receive (gx gy) (random-shared) (with-generators (gx gy) ~ (v:extend *) v:*))) -(define* (apply-grad grad-func #:optional (axis 0)) - (lambda (f . args) - (apply (grad-func f axis) args))) - +;; differentiation in one variable (test-assert (with-generators (random-func1 random-input) - ~ (apply-grad ngrad) (apply-grad v:grad))) + ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff))) ;; `v:mean' only takes non-empty arrays so we treat it separately (test-assert (with-generators ((differentiable-func-generator (list v:mean) random-non-empty-array) random-non-empty-array) - ~ (apply-grad ngrad) (apply-grad v:grad))) + ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff))) +;; differentiation in two variables (test-assert (receive (gx gy) (random-shared) (with-generators (random-func2 gx gy) - ~ (apply-grad ngrad 0) (apply-grad v:grad 0)))) + ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0)))) (test-assert (receive (gx gy) (random-shared) (with-generators (random-func2 gx gy) - ~ (apply-grad ngrad 1) (apply-grad v:grad 1)))) + ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1)))) ;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it ;; separately -(define* (apply-grad-amap2 grad-func #:optional (axis 0)) - (lambda (f x y) - ((grad-func - (lambda (x y) - (v:amap2 f x y)) - axis) - x y))) (define random-func2-rank&dims>0 (receive (gx gy) (random-shared-array-rank&dims>0) (differentiable-func-generator f2s gx gy))) (test-assert (receive (gx gy) (random-shared-array-rank&dims>0) (with-generators - (random-func2-rank&dims>0 gx gy) - ~ (apply-grad-amap2 ngrad 0) (apply-grad-amap2 v:grad 0)))) + ((const-generator v:amap2) random-func2-rank&dims>0 gx gy) + ;; NOTE: for `v:amap2' the differentiable axes are 1 and 2. + ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1)))) (test-assert (receive (gx gy) (random-shared-array-rank&dims>0) (with-generators - (random-func2-rank&dims>0 gx gy) - ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1)))) -(test-assert - (~ ((ngrad v:amap2 1) v:* #(1 2 3) #(10 20 30)) - ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30)))) -(test-assert - (let ((x #(10 20 30)) - (y #(10 20 30))) - (~ ((ngrad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)) - ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))))) + ((const-generator v:amap2) random-func2-rank&dims>0 gx gy) + ~ (apply-diff ndiff 2) (apply-diff v:fdiff 2) (apply-diff v:rdiff 2)))) +(let* ((z #(1 2 3)) + (f (lambda (a) + (v:amap2 (lambda (x y) + (v:* a a)) + #(10 20 30) + #(40 50 60)))) + (e ((ndiff f) z))) + (test-assert (~ e ((v:fdiff f) z))) + (test-assert (~ e ((v:rdiff f) z)))) ;; `v:adot' (define (random-shared-contractible) @@ -307,41 +315,37 @@ according to the number generated by the third one." (set! n (random 5)) (set! sa (random-array-shape n)) (set! sb (append (reverse (take (reverse sa) - n)) + n)) (random-array-shape 0 (- 5 n)))) tmp))))) (test-assert (receive (gx gy gz) (random-shared-contractible) (with-generators - (gx gy gz) - ~ - (lambda (a b n) ((ngrad v:adot 0) a b n)) - (lambda (a b n) ((v:grad v:adot 0) a b n))))) + ((const-generator v:adot) gx gy gz) + ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0)))) (test-assert (receive (gx gy gz) (random-shared-contractible) (with-generators - (gx gy gz) - ~ - (lambda (a b n) ((ngrad v:adot 1) a b n)) - (lambda (a b n) ((v:grad v:adot 1) a b n))))) + ((const-generator v:adot) gx gy gz) + ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1)))) + +;; let binding re-entry +(test-assert + (with-generators + ((const-generator + (lambda (x) + (let ((c (v:maximum x))) + (v:+ c (v:- x c))))) + random-non-empty-array) + ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff))) ;; chain rule (test-assert (with-generators (random-func1 random-func1 random-input) ~ - (lambda (f g x) - (let* ((gx (g x)) - (r - (v:dot ((v:grad f) gx) - ((v:grad g) x) - (v:rank-of gx)))) - (ifn (and (number? (f (g x))) - (number? x) - (array? r)) - r - (array-ref r)))) - (lambda (f g x) - ((v:grad (compose f g)) x)))) + (lambda (f g x) ((ndiff (compose f g)) x)) + (lambda (f g x) ((v:fdiff (compose f g)) x)) + (lambda (f g x) ((v:rdiff (compose f g)) x)))) -(test-end "grad") +(test-end "autodiff") diff --git a/grad.scm b/autodiff.scm similarity index 72% rename from grad.scm rename to autodiff.scm index 9e4ec76..a6666ad 100644 --- a/grad.scm +++ b/autodiff.scm @@ -1,4 +1,4 @@ -(define-module (vouivre grad) +(define-module (vouivre autodiff) #:use-module (ice-9 receive) #:use-module (srfi srfi-1) #:use-module (srfi srfi-9) @@ -15,7 +15,8 @@ ewise1 ewise2 extend - grad + fdiff + rdiff make-batch make-internal maximum @@ -144,13 +145,20 @@ element-wise. All arrays must have the same dimension." ;;(define *atype* 'f32) (define *atype* #t) -(define *grad* (make-parameter #f)) +(define *differentiation-mode* (make-parameter #f)) +(define *n-y-dims* (make-parameter #f)) (define *j* (make-parameter #f)) (define-syntax-rule (w/j val body ...) (parameterize ((*j* val)) body ...)) +(define (wrap axis) + (lambda (x i) + (if (= i axis) + (make-internal x 'input) + x))) + (define (unwrap-fwd x) (if (internal? x) (internal-forward x) @@ -166,7 +174,18 @@ element-wise. All arrays must have the same dimension." '() (array-dimensions x))) -(define (mov dst-buf n-dst-dims generator naked-inputs data j) +(define (add dst-buf src-buf n-dims) + (do-times + n-dims + (lambda (i) + (array-set! + dst-buf + (+ (array-ref dst-buf i) + (array-ref src-buf i)) + i))) + dst-buf) + +(define (movg dst-buf n-dst-dims generator naked-inputs data j) (do-times n-dst-dims (lambda (i) @@ -176,7 +195,7 @@ element-wise. All arrays must have the same dimension." i))) dst-buf) -(define (add dst-buf n-dst-dims generator naked-inputs data j) +(define (addg dst-buf n-dst-dims generator naked-inputs data j) (do-times n-dst-dims (lambda (i) @@ -221,15 +240,15 @@ adding the result to the destination buffer." (array-set! dst-buf s i)))) dst-buf) -(define* (grad f #:optional (axis 0)) - (define (wrap x i) - (if (= i axis) - (make-internal x 'input) - x)) +(define (transpose-generator generator) + (lambda (xs i j . data) + (apply generator xs j i data))) + +(define* (fdiff f #:optional (axis 0)) (lambda xs - (parameterize (((@@ (vouivre grad) *grad*) #t) - ((@@ (vouivre grad) *promises*) (cons '() #f))) - (let* ((internal (apply f (map-indexed wrap xs))) + (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'fwd) + ((@@ (vouivre autodiff) *promises*) (cons '() #f))) + (let* ((internal (apply f (map-indexed (wrap axis) xs))) (fx (internal-forward internal)) (y (list-ref xs axis)) ; variable to differentiate w.r.t (pre-Jx (internal-jacobian internal)) @@ -289,6 +308,68 @@ adding the result to the destination buffer." (+ j (* n-y-dims i)))))))) a))))))) +(define* (rdiff f #:optional (axis 0)) + (lambda xs + (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'rev) + ((@@ (vouivre autodiff) *promises*) (cons '() #f))) + (let* ((internal (apply f (map-indexed (wrap axis) xs))) + (fx (internal-forward internal)) + (y (list-ref xs axis)) ; variable to differentiate w.r.t + (y-dims (dims-of y)) + (pre-Jx (internal-jacobian internal)) + (Jx (cond + ;; TODO: implement 'input case and test 'zero and 'input + ((eq? pre-Jx 'zero) + (lambda (i) + (lambda (j) + 0))) + ((eq? pre-Jx 'input) + (error "TBD.")) + (else + (let ((pre-Jx (pre-Jx #f))) + (lambda (i) + (let ((row-jac (pre-Jx i))) + (lambda (j) + (array-ref row-jac j))))))))) + (parameterize ((*n-y-dims* (apply * y-dims))) + (cond + ((and (number? fx) (number? y)) + ((Jx 0) 0)) + ((and (number? fx) (array? y)) + (let* ((a (apply make-array *unspecified* y-dims)) + (ac (array-contents a)) + (Jx (Jx 0))) + (do-times + (*n-y-dims*) + (lambda (j) + (array-set! ac (Jx j) + j))) + a)) + ((and (array? fx) (number? y)) + (let* ((fx-dims (array-dimensions fx)) + (a (apply make-array *unspecified* fx-dims)) + (ac (array-contents a))) + (do-times + (apply * fx-dims) + (lambda (i) + (array-set! ac ((Jx i) 0) + i))) + a)) + (else + (let* ((fx-dims (array-dimensions fx)) + (n-fx-dims (apply * fx-dims)) + (a (apply make-array *unspecified* (append fx-dims y-dims))) + (ac (array-contents a))) + (do-times + n-fx-dims + (lambda (i) + (let ((Jx (Jx i))) + (do-times + (*n-y-dims*) + (lambda (j) + (array-set! ac (Jx j) + (+ j (* (*n-y-dims*) i)))))))) + a)))))))) ;; In the comment that follows: ;; @@ -317,47 +398,126 @@ adding the result to the destination buffer." ;; NOTE: In cases where an argument isn't meant to be differentiable its ;; corresponding generator should be `#f'. (define (differentiable-wrapper generators proc* arg . more) + (define (precompute-data naked-args) + (if (procedure? proc*) + '() + (map (lambda (g) + (apply g naked-args)) + (cdr proc*)))) (let* ((args (cons arg more)) (proc (if (procedure? proc*) proc* (car proc*))) (naked-args (map unwrap-fwd args)) (out (apply proc naked-args))) - (ifn (*grad*) - out - (let* ((data (if (procedure? proc*) - '() - (map (lambda (g) - (apply g naked-args)) - (cdr proc*)))) - (n-out-dims (apply * (dims-of out))) - (buf (make-array *unspecified* n-out-dims))) - (make-internal - out - (fold - (lambda (generator arg prev) + (case (*differentiation-mode*) + ((#f) + out) + ((fwd) + (let* ((data (precompute-data naked-args)) + (n-out-dims (apply * (dims-of out))) + (buf (make-array *unspecified* n-out-dims))) + (make-internal + out + (fold + (lambda (generator arg prev) + (if (or (not (internal? arg)) + (eq? 'zero (internal-jacobian arg))) + prev + (let ((Jx (internal-jacobian arg)) + (n-fwd-dims (apply * (dims-of (unwrap-fwd arg))))) + (if (eq? Jx 'input) + (if (eq? prev 'zero) + (delay + (movg buf n-out-dims + generator naked-args data (*j*))) + (delay + (addg (force prev) n-out-dims + generator naked-args data (*j*)))) + (if (eq? prev 'zero) + (delay + (movc buf n-out-dims (force Jx) n-fwd-dims + generator naked-args data)) + (delay + (addc (force prev) n-out-dims + (force Jx) n-fwd-dims + generator naked-args data))))))) + 'zero generators args)))) + ((rev) + (let ((data (precompute-data naked-args)) + (n-out-dims (apply * (dims-of out)))) + (make-internal + out + (fold + (lambda (generator arg prev) + (let ((generator (transpose-generator generator))) (if (or (not (internal? arg)) (eq? 'zero (internal-jacobian arg))) prev - (let ((Jx (internal-jacobian arg)) - (n-fwd-dims (apply * (dims-of (unwrap-fwd arg))))) + (let* ((Jx (internal-jacobian arg)) + (n-fwd-dims (apply * (dims-of (unwrap-fwd arg))))) (if (eq? Jx 'input) (if (eq? prev 'zero) - (delay - (mov buf n-out-dims - generator naked-args data (*j*))) - (delay - (add (force prev) n-out-dims - generator naked-args data (*j*)))) + (lambda (buf?) + (let ((dst-buf (make-array *unspecified* + n-fwd-dims))) + (if buf? + (lambda (buf) + (movc dst-buf n-fwd-dims + buf n-out-dims + generator naked-args data)) + (lambda (i) + (movg dst-buf n-fwd-dims + generator naked-args data + i))))) + (lambda (buf?) + (let ((prev (prev buf?))) + (if buf? + (lambda (buf) + (addc (prev buf) n-fwd-dims + buf n-out-dims + generator naked-args data)) + (lambda (i) + (addg (prev i) n-fwd-dims + generator naked-args data + i)))))) (if (eq? prev 'zero) - (delay - (movc buf n-out-dims (force Jx) n-fwd-dims - generator naked-args data)) - (delay - (addc (force prev) n-out-dims - (force Jx) n-fwd-dims - generator naked-args data))))))) - 'zero generators args)))))) + (lambda (buf?) + (let ((Jx (Jx #t)) + (dst-buf (make-array *unspecified* + n-fwd-dims))) + (if buf? + (lambda (buf) + (Jx + (movc dst-buf n-fwd-dims buf + n-out-dims + generator naked-args data))) + (lambda (i) + (Jx + (movg dst-buf n-fwd-dims + generator naked-args data + i)))))) + (lambda (buf?) + (let ((prev (prev buf?)) + (Jx (Jx #t)) + (dst-buf (make-array *unspecified* + n-fwd-dims))) + (if buf? + (lambda (buf) + (add (prev buf) + (Jx + (movc dst-buf n-fwd-dims + buf n-out-dims + generator naked-args data)) + (*n-y-dims*))) + (lambda (i) + (add (prev i) + (Jx + (movg dst-buf n-fwd-dims + generator naked-args data + i)) + (*n-y-dims*)))))))))))) + 'zero generators args))))))) (define (ewise1 f) (lambda (xs i j) -- 2.39.2