From: admin Date: Thu, 23 Nov 2023 07:17:16 +0000 (+0900) Subject: Implement reverse mode automatic differentiation X-Git-Tag: v0.2.0~9 X-Git-Url: https://git.vouivredigital.com/?a=commitdiff_plain;h=ec5567a74aa49b25499cf1fd908d2ce51409aa28;p=vouivre.git Implement reverse mode automatic differentiation --- diff --git a/autodiff-tests.scm b/autodiff-tests.scm new file mode 100644 index 0000000..f1fc168 --- /dev/null +++ b/autodiff-tests.scm @@ -0,0 +1,351 @@ +(define-module (vouivre autodiff tests) + #:use-module ((vouivre autodiff) #:prefix v:) + #:use-module (ice-9 receive) + #:use-module (srfi srfi-1) + #:use-module (srfi srfi-64) + #:use-module (vouivre misc) + #:export + (apply-diff + a~ + const-generator + differentiable-func-generator + lambda-const-call + ndiff + n~ + random-array + random-array-shape + random-func1 + random-func2 + random-func2-rank&dims>0 + random-input + random-list-element + random-non-empty-array + random-shape + random-shared + random-shared-array-rank&dims>0 + random-shared-contractible + with-generators + ~)) + +(define f1s (list v:abs v:cos v:exp v:identity v:sin)) +(define f2s (list v:+ v:- v:* v:max v:min)) + +(define (with-generators% generators equal proc1 proc2 . more) + "Check that all procedures return the same value according to `equal' when +evaluated on arguments produced by the generators (the number of generators +being the number of arguments to each procedure). Returns #t when all 100 trials agree; otherwise escapes through `break' with #f followed by the generated arguments (and the results, when they were computed)." + (let ((times 100) + (procs (cons proc1 (cons proc2 more)))) + (call/cc + (lambda (break) + (do ((i 0 (1+ i))) + ((= i times) #t) + (let ((zs (map-in-order (lambda (g) (g)) generators))) + (with-exception-handler + (lambda (e) + (break #f zs)) + (lambda () + (let* ((rs (map (lambda (f) (apply f zs)) procs)) + (head (car rs))) + (unless (every (lambda (x) (equal x head)) + (cdr rs)) + (break #f zs rs)))) + #:unwind? #t))))))) + +(define-syntax-rule (with-generators (g1 g2 ...) equal expected given more ...) + (with-generators% (list g1 g2 ...) 
equal expected given more ...)) + +(define (lambda-const-call f . consts) + (lambda _ + (apply f consts))) + +(define* (random-array-shape ; rank and every dim are drawn from [min, max); equal bounds yield exactly min + #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5)) + (list-tabulate (+ min-rank (random (max 1 (- max-rank min-rank)))) ; `random' rejects a bound of 0, so clamp the span to 1 + (lambda _ (+ min-dim (random (max 1 (- max-dim min-dim))))))) + +(define (random-shape) + (if (= 0 (random 2)) + 0 + (random-array-shape))) + +(define* (random-array #:optional shape) + (apply produce-typed-array + (lambda _ (random:uniform)) + v:*atype* (or shape (random-array-shape)))) + +(define (random-non-empty-array) + "Random array of at least one element." + (random-array (random-array-shape 0 5 1 5))) + +(define* (random-input #:optional shape) + (let ((shape (or shape (random-shape)))) + (if (eqv? 0 shape) ; shape 0 stands for a scalar; `eqv?' is the portable test on numbers + (random:uniform) + (random-array shape)))) + +(define (random-shared) + (let ((shape (random-shape))) + (values + (lambda () + (random-input shape)) + (lambda () + (let ((x (random-input + (random-list-element + (list 0 (if (list? shape) + shape + (random-shape))))))) + (set! shape (random-shape)) + x))))) + +(define (random-shared-array-rank&dims>0) + (let ((shape (random-array-shape 1 5 1 5))) + (values + (lambda () + (random-array shape)) + (lambda () + (let ((x (random-array shape))) + (set! shape (random-array-shape 1 5 1 5)) + x))))) + +(define (random-list-element lst) + (list-ref lst (random (length lst)))) + +(define (const-generator generator) ; despite its name, `generator' is the constant value the returned thunk yields + (lambda () + generator)) + +(define (differentiable-func-generator lst . 
input-generators) + (lambda () + (random-list-element + (cons + (apply + lambda-const-call + (random-list-element lst) + (map (lambda (g) (g)) + input-generators)) + lst)))) + +(define random-func1 + (differentiable-func-generator f1s random-input)) +(define random-func2 + (receive (gx gy) (random-shared) + (differentiable-func-generator f2s gx gy))) + +(define* (n~ x y #:optional (error 1e-4)) + (and + (>= y (- x error)) + (<= y (+ x error)))) + +(define* (a~ x y #:optional (error 1e-4)) + (and + (equal? (array-dimensions x) + (array-dimensions y)) + (call/cc + (lambda (break) + (array-for-each + (lambda (x y) + (unless (~ x y error) + (break #f))) + x y) + #t)))) + +(define* (~ x y #:optional (error 1e-4)) + (cond + ((and (number? x) (number? y)) + (n~ x y error)) + ((and (array? x) (array? y)) + (a~ x y error)) + (else #f))) + +(define* (ndiff f #:optional (axis 0) (step 1e-6)) + "Differentiation using a numerical centered difference approximation." + (define (axis-add xs dh . indices) + "Add `dh' to the number or array at the given `axis' of `xs', +and, when it's an array, at the given index." + (map-indexed + (lambda (x i) + (ifn (= i axis) + x + (if (number? x) + (+ x dh) + (array-map-indexed + (lambda (x . indices_) + (ifn (equal? indices indices_) + x + (+ x dh))) + x)))) + xs)) + (lambda xs + ;; We need the output shape and the input shape along the + ;; differentiated axis. + (let ((fxs (apply f xs)) + (x (list-ref xs axis))) + (cond + ((and (number? fxs) + (number? x)) + (/ (- (apply f (axis-add xs step)) + (apply f (axis-add xs (- step)))) + (* 2 step))) + ((and (number? fxs) + (array? x)) + (apply + produce-typed-array + (lambda indices + (/ (- (apply f (apply axis-add xs step indices)) + (apply f (apply axis-add xs (- step) indices))) + (* 2 step))) + v:*atype* + (array-dimensions x))) + ((and (array? fxs) + (number? x)) + ((v:extend /) + ((v:extend -) + (apply f (axis-add xs step)) + (apply f (axis-add xs (- step)))) + (* 2 step))) + ((and (array? 
fxs) + (array? x)) + (let ((a (apply + make-typed-array v:*atype* *unspecified* + (append (array-dimensions fxs) + (array-dimensions x))))) + (for-indices-in-range + (lambda indices-in + (let ((dfxs ((v:extend /) + ((v:extend -) + (apply f (apply axis-add xs step indices-in)) + (apply f (apply axis-add xs (- step) indices-in))) + (* 2 step)))) + (for-indices-in-range + (lambda indices-out + (apply + array-set! + a + (apply array-ref dfxs indices-out) + (append indices-out indices-in))) + (list-zeros (array-rank fxs)) + (array-dimensions fxs)))) + (list-zeros (array-rank x)) + (array-dimensions x)) + a)))))) + +(define* (apply-diff differentiator #:optional (axis 0)) + "Apply a differentiator (`ndiff', `fdiff', `rdiff') to a function and its +arguments (this is a convenience function)." + (lambda (f . args) + (apply (differentiator f axis) args))) + +(test-begin "autodiff") + +;; not differentiating +(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity)) +(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp)) +(test-assert + (receive (gx gy) (random-shared) + (with-generators (gx gy) ~ (v:extend *) v:*))) + +;; differentiation in one variable +(test-assert + (with-generators + (random-func1 random-input) + ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff))) + +;; `v:mean' only takes non-empty arrays so we treat it separately +(test-assert + (with-generators + ((differentiable-func-generator (list v:mean) random-non-empty-array) + random-non-empty-array) + ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff))) + +;; differentiation in two variables +(test-assert + (receive (gx gy) (random-shared) + (with-generators + (random-func2 gx gy) + ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0)))) +(test-assert + (receive (gx gy) (random-shared) + (with-generators + (random-func2 gx gy) + ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1)))) + +;; `v:amap2' only takes arrays 
of rank > 0 and batch-size > 0 so we treat it +;; separately +(define random-func2-rank&dims>0 + (receive (gx gy) (random-shared-array-rank&dims>0) + (differentiable-func-generator f2s gx gy))) +(test-assert + (receive (gx gy) (random-shared-array-rank&dims>0) + (with-generators + ((const-generator v:amap2) random-func2-rank&dims>0 gx gy) + ;; NOTE: for `v:amap2' the differentiable axes are 1 and 2. + ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1)))) +(test-assert + (receive (gx gy) (random-shared-array-rank&dims>0) + (with-generators + ((const-generator v:amap2) random-func2-rank&dims>0 gx gy) + ~ (apply-diff ndiff 2) (apply-diff v:fdiff 2) (apply-diff v:rdiff 2)))) +(let* ((z #(1 2 3)) + (f (lambda (a) + (v:amap2 (lambda (x y) + (v:* a a)) + #(10 20 30) + #(40 50 60)))) + (e ((ndiff f) z))) + (test-assert (~ e ((v:fdiff f) z))) + (test-assert (~ e ((v:rdiff f) z)))) + +;; `v:adot' +(define (random-shared-contractible) + "Returns three generators: the first two generate arrays that are contractible +according to the number generated by the third one." + (let* ((n (random 5)) + (sa (random-array-shape n)) + (sb (append (reverse (take (reverse sa) + n)) + (random-array-shape 0 (- 5 n))))) + (values + (lambda () + (random-array sa)) + (lambda () + (random-array sb)) + (lambda () + (let ((tmp n)) + (set! n (random 5)) + (set! sa (random-array-shape n)) + (set! 
sb (append (reverse (take (reverse sa) + n)) + (random-array-shape 0 (- 5 n)))) + tmp))))) +(test-assert + (receive (gx gy gz) (random-shared-contractible) + (with-generators + ((const-generator v:adot) gx gy gz) + ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0)))) +(test-assert + (receive (gx gy gz) (random-shared-contractible) + (with-generators + ((const-generator v:adot) gx gy gz) + ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1)))) + +;; let binding re-entry +(test-assert + (with-generators + ((const-generator + (lambda (x) + (let ((c (v:maximum x))) + (v:+ c (v:- x c))))) + random-non-empty-array) + ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff))) + +;; chain rule +(test-assert + (with-generators + (random-func1 random-func1 random-input) + ~ + (lambda (f g x) ((ndiff (compose f g)) x)) + (lambda (f g x) ((v:fdiff (compose f g)) x)) + (lambda (f g x) ((v:rdiff (compose f g)) x)))) + +(test-end "autodiff") diff --git a/autodiff.scm b/autodiff.scm new file mode 100644 index 0000000..a6666ad --- /dev/null +++ b/autodiff.scm @@ -0,0 +1,859 @@ +(define-module (vouivre autodiff) + #:use-module (ice-9 receive) + #:use-module (srfi srfi-1) + #:use-module (srfi srfi-9) + #:use-module (vouivre misc) + #:use-module (vouivre promises) + #:export + (*atype* + adot + amap2 + contract-arrays + differentiable-wrapper + dot + do-times + ewise1 + ewise2 + extend + fdiff + rdiff + make-batch + make-internal + maximum + mean + rank-of + sum) + #:replace + ((i:sqrt . sqrt) + (i:exp . exp) + (i:expt . expt) + (i:log . log) + (i:sin . sin) + (i:cos . cos) + (i:tan . tan) + (i:+ . +) + (i:- . -) + (i:* . *) + (i:/ . /) + (i:max . max) + (i:min . min) + (i:abs . abs) + (i:identity . identity) + (i:array-ref . array-ref) + (i:array-cell-ref . 
array-cell-ref)) + #:re-export + (fold + reduce)) + +;;;; array utilities + +(define (rel->abs indices dimensions) + (let rec ((s 0) + (p 1) + (is (reverse indices)) + (ds (reverse dimensions))) + (if (null? is) + s + (rec (+ s (* p (car is))) + (* p (car ds)) + (cdr is) + (cdr ds))))) + +(define(do-times n proc) + (let rec ((i 0)) + (unless (= i n) + (proc i) + (rec (1+ i))))) + +(define (contract-arrays a b n) + (let* ((dims-a (array-dimensions a)) + (dims-b (array-dimensions b)) + (free-dims-a (take dims-a (- (array-rank a) n))) + (free-dims-b (drop dims-b n)) + (bound-dims (take dims-b n)) + (n-free-dims-a (apply * free-dims-a)) + (n-free-dims-b (apply * free-dims-b)) + (n-bound-dims (apply * bound-dims)) + (s 0) + (r (apply make-typed-array *atype* *unspecified* (append free-dims-a + free-dims-b))) + (ac (array-contents a)) + (bc (array-contents b)) + (rc (array-contents r))) + (do-times + n-free-dims-a + (lambda (i) + (let ((i-k (* n-bound-dims i)) + (i-j (* n-free-dims-b i))) + (do-times + n-free-dims-b + (lambda (j) + (set! s 0) + (do-times + n-bound-dims + (lambda (k) + (set! s (+ s (* (array-ref ac (+ i-k k)) + (array-ref bc (+ (* n-free-dims-b k) j))))))) + (array-set! rc s (+ i-j j))))))) + r)) + +;;;; utilities that work on both numbers and arrays + +(define (extend f) + "Extend a function of one or more scalars to apply to numbers/arrays +element-wise. All arrays must have the same dimension." + (define (apply-elemwise f indices args) + (apply f (map (lambda (x) + (if (number? x) + x + (apply array-ref x indices))) + args))) + (lambda xs + (if-let (x (find array? xs)) + (apply + produce-typed-array + (lambda is + (apply-elemwise f is xs)) + *atype* + (array-dimensions x)) + (apply f xs)))) + +(define (dot x y n) + (cond + ((and (number? x) (number? y)) + (* x y)) + ((and (array? x) (array? y)) + (contract-arrays x y n)) + ((and (array? x) (number? y)) + ((extend *) x y)) + ((and (number? x) (array? 
y)) + ((extend *) x y)) + (else (error "can't dot because of invalid types or ranks" x y n)))) + +(define (rank-of x) + (if (number? x) + 0 + (array-rank x))) + +;;;; differentiation + +(define-record-type internal + (make-internal forward jacobian) + internal? + (forward internal-forward) + (jacobian internal-jacobian)) + +;;(define *atype* 'f32) +(define *atype* #t) +(define *differentiation-mode* (make-parameter #f)) +(define *n-y-dims* (make-parameter #f)) +(define *j* (make-parameter #f)) + +(define-syntax-rule (w/j val body ...) + (parameterize ((*j* val)) + body ...)) + +(define (wrap axis) + (lambda (x i) + (if (= i axis) + (make-internal x 'input) + x))) + +(define (unwrap-fwd x) + (if (internal? x) + (internal-forward x) + x)) + +(define (unwrap-jac x) + (if (internal? x) + (internal-jacobian x) + x)) + +(define (dims-of x) + (if (number? x) + '() + (array-dimensions x))) + +(define (add dst-buf src-buf n-dims) + (do-times + n-dims + (lambda (i) + (array-set! + dst-buf + (+ (array-ref dst-buf i) + (array-ref src-buf i)) + i))) + dst-buf) + +(define (movg dst-buf n-dst-dims generator naked-inputs data j) + (do-times + n-dst-dims + (lambda (i) + (array-set! + dst-buf + (apply generator naked-inputs i j data) + i))) + dst-buf) + +(define (addg dst-buf n-dst-dims generator naked-inputs data j) + (do-times + n-dst-dims + (lambda (i) + (array-set! + dst-buf + (+ (array-ref dst-buf i) + (apply generator naked-inputs i j data)) + i))) + dst-buf) + +(define (movc dst-buf n-dst-dims src-buf n-src-dims + generator naked-inputs data) + "Contract the Jacobian column produced by the generator with the source buffer +storing the result in the destination buffer." + (let ((s 0)) + (do-times + n-dst-dims + (lambda (i) + (set! s 0) + (do-times + n-src-dims + (lambda (k) + (set! s (+ s (* (apply generator naked-inputs i k data) + (array-ref src-buf k)))))) + (array-set! 
dst-buf s i)))) + dst-buf) + +(define (addc dst-buf n-dst-dims src-buf n-src-dims + generator naked-inputs data) + "Contract the Jacobian column produced by the generator with the source buffer +adding the result to the destination buffer." + (let ((s 0)) + (do-times + n-dst-dims + (lambda (i) + (set! s (array-ref dst-buf i)) + (do-times + n-src-dims + (lambda (k) + (set! s (+ s (* (apply generator naked-inputs i k data) + (array-ref src-buf k)))))) + (array-set! dst-buf s i)))) + dst-buf) + +(define (transpose-generator generator) + (lambda (xs i j . data) + (apply generator xs j i data))) + +(define* (fdiff f #:optional (axis 0)) + (lambda xs + (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'fwd) + ((@@ (vouivre autodiff) *promises*) (cons '() #f))) + (let* ((internal (apply f (map-indexed (wrap axis) xs))) + (fx (internal-forward internal)) + (y (list-ref xs axis)) ; variable to differentiate w.r.t + (pre-Jx (internal-jacobian internal)) + (Jx (cond + ;; TODO: implement 'input case and test 'zero and 'input + ((eq? pre-Jx 'zero) + (lambda (j) + (lambda (i) + 0))) + ((eq? pre-Jx 'input) + (error "TBD.")) + (else + (lambda (j) + (reset-promises (car (*promises*))) + (let ((column-jac (w/j j (force pre-Jx)))) + (lambda (i) + (array-ref column-jac i)))))))) + (cond + ((and (number? fx) (number? y)) + ((Jx 0) 0)) + ((and (number? fx) (array? y)) + (let* ((y-dims (array-dimensions y)) + (a (apply make-array *unspecified* y-dims)) + (ac (array-contents a))) + (do-times + (apply * y-dims) + (lambda (j) + (array-set! ac ((Jx j) 0) + j))) + a)) + ((and (array? fx) (number? y)) + (let* ((fx-dims (array-dimensions fx)) + (a (apply make-array *unspecified* fx-dims)) + (ac (array-contents a)) + (Jx (Jx 0))) + (do-times + (apply * fx-dims) + (lambda (i) + (array-set! 
ac (Jx i) + i))) + a)) + (else + (let* ((fx-dims (array-dimensions fx)) + (y-dims (array-dimensions y)) + (n-fx-dims (apply * fx-dims)) + (n-y-dims (apply * y-dims)) + (a (apply make-array *unspecified* (append fx-dims y-dims))) + (ac (array-contents a))) + (do-times + n-y-dims + (lambda (j) + (let ((Jx (Jx j))) + (do-times + n-fx-dims + (lambda (i) + (array-set! ac (Jx i) + (+ j (* n-y-dims i)))))))) + a))))))) + +(define* (rdiff f #:optional (axis 0)) + (lambda xs + (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'rev) + ((@@ (vouivre autodiff) *promises*) (cons '() #f))) + (let* ((internal (apply f (map-indexed (wrap axis) xs))) + (fx (internal-forward internal)) + (y (list-ref xs axis)) ; variable to differentiate w.r.t + (y-dims (dims-of y)) + (pre-Jx (internal-jacobian internal)) + (Jx (cond + ;; TODO: implement 'input case and test 'zero and 'input + ((eq? pre-Jx 'zero) + (lambda (i) + (lambda (j) + 0))) + ((eq? pre-Jx 'input) + (error "TBD.")) + (else + (let ((pre-Jx (pre-Jx #f))) + (lambda (i) + (let ((row-jac (pre-Jx i))) + (lambda (j) + (array-ref row-jac j))))))))) + (parameterize ((*n-y-dims* (apply * y-dims))) + (cond + ((and (number? fx) (number? y)) + ((Jx 0) 0)) + ((and (number? fx) (array? y)) + (let* ((a (apply make-array *unspecified* y-dims)) + (ac (array-contents a)) + (Jx (Jx 0))) + (do-times + (*n-y-dims*) + (lambda (j) + (array-set! ac (Jx j) + j))) + a)) + ((and (array? fx) (number? y)) + (let* ((fx-dims (array-dimensions fx)) + (a (apply make-array *unspecified* fx-dims)) + (ac (array-contents a))) + (do-times + (apply * fx-dims) + (lambda (i) + (array-set! ac ((Jx i) 0) + i))) + a)) + (else + (let* ((fx-dims (array-dimensions fx)) + (n-fx-dims (apply * fx-dims)) + (a (apply make-array *unspecified* (append fx-dims y-dims))) + (ac (array-contents a))) + (do-times + n-fx-dims + (lambda (i) + (let ((Jx (Jx i))) + (do-times + (*n-y-dims*) + (lambda (j) + (array-set! 
ac (Jx j) + (+ j (* (*n-y-dims*) i)))))))) + a)))))))) + +;; In the comment that follows: +;; +;; `n' is the number of arguments to `proc'. +;; `generators' is not a `Vec' but a `List'; we only use the former to illustrate +;; its length. +;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension. +;; `I' is the type of multi-indices indexing the output of `proc'. +;; `J' is the type of multi-indices indexing the input array being differentiated. +;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J'). +;; `Array I' is the type of arrays indexed by multi-indices of `I'. +;; `[X]' means that `X' is boxed in an internal as when returned by +;; `differentiable-wrapper' with the array being `X' and the promise that +;; given a |J| we will get the change of `X' with a change of +;; the differentiated argument at multi-index `J'. +;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number)) +;; (→ X1 ... Xn (Array I))* +;; [X1] ... [Xn] +;; (Internal (Array I) (Promise |J| (Array |I|))))) +;; +;; (*) We extend this definition to allow `proc' to be a list of procedures +;; the head of which is as described above and the remaining elements +;; are procedures of the same arguments but returning values that are +;; then fed as extra data to the generators. +;; +;; NOTE: In cases where an argument isn't meant to be differentiable its +;; corresponding generator should be `#f'. +(define (differentiable-wrapper generators proc* arg . more) + (define (precompute-data naked-args) + (if (procedure? proc*) + '() + (map (lambda (g) + (apply g naked-args)) + (cdr proc*)))) + (let* ((args (cons arg more)) + (proc (if (procedure? 
proc*) + proc* + (car proc*))) + (naked-args (map unwrap-fwd args)) + (out (apply proc naked-args))) + (case (*differentiation-mode*) + ((#f) + out) + ((fwd) + (let* ((data (precompute-data naked-args)) + (n-out-dims (apply * (dims-of out))) + (buf (make-array *unspecified* n-out-dims))) + (make-internal + out + (fold + (lambda (generator arg prev) + (if (or (not (internal? arg)) + (eq? 'zero (internal-jacobian arg))) + prev + (let ((Jx (internal-jacobian arg)) + (n-fwd-dims (apply * (dims-of (unwrap-fwd arg))))) + (if (eq? Jx 'input) + (if (eq? prev 'zero) + (delay + (movg buf n-out-dims + generator naked-args data (*j*))) + (delay + (addg (force prev) n-out-dims + generator naked-args data (*j*)))) + (if (eq? prev 'zero) + (delay + (movc buf n-out-dims (force Jx) n-fwd-dims + generator naked-args data)) + (delay + (addc (force prev) n-out-dims + (force Jx) n-fwd-dims + generator naked-args data))))))) + 'zero generators args)))) + ((rev) + (let ((data (precompute-data naked-args)) + (n-out-dims (apply * (dims-of out)))) + (make-internal + out + (fold + (lambda (generator arg prev) + (let ((generator (transpose-generator generator))) + (if (or (not (internal? arg)) + (eq? 'zero (internal-jacobian arg))) + prev + (let* ((Jx (internal-jacobian arg)) + (n-fwd-dims (apply * (dims-of (unwrap-fwd arg))))) + (if (eq? Jx 'input) + (if (eq? prev 'zero) + (lambda (buf?) + (let ((dst-buf (make-array *unspecified* + n-fwd-dims))) + (if buf? + (lambda (buf) + (movc dst-buf n-fwd-dims + buf n-out-dims + generator naked-args data)) + (lambda (i) + (movg dst-buf n-fwd-dims + generator naked-args data + i))))) + (lambda (buf?) + (let ((prev (prev buf?))) + (if buf? + (lambda (buf) + (addc (prev buf) n-fwd-dims + buf n-out-dims + generator naked-args data)) + (lambda (i) + (addg (prev i) n-fwd-dims + generator naked-args data + i)))))) + (if (eq? prev 'zero) + (lambda (buf?) + (let ((Jx (Jx #t)) + (dst-buf (make-array *unspecified* + n-fwd-dims))) + (if buf? 
+ (lambda (buf) + (Jx + (movc dst-buf n-fwd-dims buf + n-out-dims + generator naked-args data))) + (lambda (i) + (Jx + (movg dst-buf n-fwd-dims + generator naked-args data + i)))))) + (lambda (buf?) + (let ((prev (prev buf?)) + (Jx (Jx #t)) + (dst-buf (make-array *unspecified* + n-fwd-dims))) + (if buf? + (lambda (buf) + (add (prev buf) + (Jx + (movc dst-buf n-fwd-dims + buf n-out-dims + generator naked-args data)) + (*n-y-dims*))) + (lambda (i) + (add (prev i) + (Jx + (movg dst-buf n-fwd-dims + generator naked-args data + i)) + (*n-y-dims*)))))))))))) + 'zero generators args))))))) + +(define (ewise1 f) + (lambda (xs i j) + (let ((x (car xs))) + (if (number? x) + (f x) + (ifn (= i j) + 0 + (f (array-ref (array-contents x) + j))))))) + +(define (ewise2 proc axis) + (lambda (xs i j) + (let ((x (car xs)) + (y (cadr xs))) + (cond + ((and (number? x) (number? y)) + (proc x y)) + ((and (number? x) (array? y)) + (if (= axis 0) + (proc x (array-ref (array-contents y) + i)) + (ifn (= i j) + 0 + (proc x (array-ref (array-contents y) + j))))) + ((and (array? x) (number? y)) + (if (= axis 1) + (proc (array-ref (array-contents x) + i) + y) + (ifn (= i j) + 0 + (proc (array-ref (array-contents x) + j) + y)))) + (else + (ifn (= i j) + 0 + (proc (array-ref (array-contents x) + j) + (array-ref (array-contents y) + j)))))))) + +(define (i:identity x) + "Differentiable identity." + (differentiable-wrapper + (list (ewise1 (lambda _ 1))) + identity + x)) + +(define (i:sqrt x) + "Differentiable square root." + (differentiable-wrapper + (list (ewise1 (lambda (x) (/ 1 2 (sqrt x))))) + (extend sqrt) + x)) + +(define (i:exp x) + "Differentiable exponential." + (differentiable-wrapper + (list (ewise1 exp)) + (extend exp) + x)) + +(define (i:expt x y) + "Differentiable power." + (differentiable-wrapper + (list (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0) + (ewise2 (lambda (x y) (* (expt x y) (log x))) 1)) + (extend expt) + x y)) + +(define (i:log x) + "Differentiable logarithm." 
+ (differentiable-wrapper + (list (ewise1 (lambda (x) (/ x)))) + (extend log) + x)) + +(define (i:sin x) + "Differentiable sine." + (differentiable-wrapper + (list (ewise1 cos)) + (extend sin) + x)) + +(define (i:cos x) + "Differentiable cosine." + (differentiable-wrapper + (list (ewise1 (lambda (x) (- (sin x))))) + (extend cos) + x)) + +(define (i:tan x) + "Differentiable tangent." + (differentiable-wrapper + (list (ewise1 (lambda (x) (/ (expt (cos x) 2))))) + (extend tan) + x)) + +(define (i:+ x y) + "Differentiable element-wise addition." + (differentiable-wrapper + (list + (ewise2 (lambda _ +1) 0) + (ewise2 (lambda _ +1) 1)) + (extend +) + x y)) + +(define (i:- x y) + "Differentiable element-wise subtraction." + (differentiable-wrapper + (list + (ewise2 (lambda _ +1) 0) + (ewise2 (lambda _ -1) 1)) + (extend -) + x y)) + +(define (i:* x y) + "Differentiable element-wise multiplication." + (differentiable-wrapper + (list + (ewise2 (lambda (x y) y) 0) + (ewise2 (lambda (x y) x) 1)) + (extend *) + x y)) + +(define (i:/ x y) + "Differentiable element-wise division." + (differentiable-wrapper + (list + (ewise2 (lambda (x y) (/ y)) 0) + (ewise2 (lambda (x y) (- (/ x y y))) 1)) + (extend /) + x y)) + +(define (i:max x y) + "Differentiable element-wise maximum." + (define (dmax x y) + (cond + ((> x y) + 1) + ((= x y) + 1/2) + (else + 0))) + (differentiable-wrapper + (list + (ewise2 dmax 0) + (ewise2 (flip dmax) 1)) + (extend max) + x y)) + +(define (i:min x y) + "Differentiable element-wise minimum." + (define (dmin x y) + (cond + ((< x y) + 1) + ((= x y) + 1/2) + (else + 0))) + (differentiable-wrapper + (list + (ewise2 dmin 0) + (ewise2 (flip dmin) 1)) + (extend min) + x y)) + +(define (i:abs x) + "Differentiable absolute value. At 0 the derivative is taken to be 0, the mean of the one-sided derivatives -1 and +1 (the same tie rule `max' and `min' use)." + (differentiable-wrapper + (list (ewise1 (lambda (x) + (cond ((> x 0) + +1) + ((= x 0) + 0) + ((< x 0) + -1))))) + (extend abs) + x)) + +(define (mean x) + "Differentiable mean on arrays." 
+ (differentiable-wrapper + (list + (lambda (xs i j one-over-n) + one-over-n)) + (let ((n 0)) + (list + (lambda (x) + (let ((sum 0)) + (array-for-each + (lambda (x) + (set! sum (+ sum x)) + (set! n (1+ n))) + x) + (/ sum n))) + (lambda _ (/ n)))) + x)) + +(define (i:array-ref x . indices) + "Differentiable array-ref w.r.t `x'." + (apply + differentiable-wrapper + (cons + (lambda (xs i j abs-index) + (if (= j abs-index) + 1 + 0)) + (map not indices)) + (list + array-ref + (lambda (x . indices) + (rel->abs indices (array-dimensions x)))) + x indices)) + +(define (i:array-cell-ref x . indices) + (apply + differentiable-wrapper + (cons + (lambda (xs i j abs-index n-rst-dims) + (receive (j-ref j-rst) (euclidean/ j n-rst-dims) + (if (and (= j-ref abs-index) + (= j-rst i)) + 1 + 0))) + (map not indices)) + (list + array-cell-ref + (lambda (x . indices) + (rel->abs indices (take (array-dimensions x) + (length indices)))) + (lambda (x . indices) + (apply * (drop (array-dimensions x) + (length indices))))) + x indices)) + +(define (make-batch elem . more) + (let ((batch-size (1+ (length more)))) + (apply + differentiable-wrapper + (list-tabulate + batch-size + (lambda (b) + (lambda (xs i j n-rest-dims) + (receive (i-batch i-rest) (euclidean/ i n-rest-dims) + (if (and (= i-batch b) + (= i-rest j)) + 1 + 0 + ))))) + (list + (lambda (elem . more) + (let ((a (apply make-typed-array *atype* *unspecified* batch-size + (dims-of elem)))) + (for-each + (lambda (x b) + (array-cell-set! a x b)) + (cons elem more) + (list-tabulate batch-size identity)) + a)) + (lambda (elem . more) + (apply * (dims-of elem)))) + elem more))) + +(define (maximum x) + "Differentiable maximum on arrays. Ties go to the first maximal element; when no maximal element is found (empty or all-NaN array) the generator yields 0." + (differentiable-wrapper + (list + (lambda (xs i j max-index) + (if (eqv? j max-index) + 1 + 0))) + (let ((max-index #f)) + (list + (lambda (x) + (let ((m (- (inf))) + (i 0)) + (array-for-each + (lambda (x) + (when (or (< m x) (and (not max-index) (= m x))) + (set! m x) + (set! max-index i)) + (set! 
i (1+ i))) + x) + m)) + (lambda _ max-index))) + x)) + +(define (sum x) + "Differentiable sum on arrays." + (differentiable-wrapper + (list (lambda _ 1)) + (lambda (x) + (let ((sum 0)) + (array-for-each + (lambda (x) + (set! sum (+ sum x))) + x) + sum)) + x)) + +(define (adot x y n) + (differentiable-wrapper + (list + (lambda (xs i j n-free-dims-y n-bound-dims) + (receive (i-x i-y) (euclidean/ i n-free-dims-y) + (receive (j-free j-bound) (euclidean/ j n-bound-dims) + (ifn (= i-x j-free) + 0 + (array-ref (array-contents (cadr xs)) + (+ i-y (* n-free-dims-y j-bound))))))) + (lambda (xs i j n-free-dims-y n-bound-dims) + (receive (i-x i-y) (euclidean/ i n-free-dims-y) + (receive (j-bound j-free) (euclidean/ j n-free-dims-y) + (ifn (= i-y j-free) + 0 + (array-ref (array-contents (car xs)) + (+ j-bound (* n-bound-dims i-x))))))) + #f) + (list + contract-arrays + (lambda (x y n) + (apply * (drop (array-dimensions y) + n))) + (lambda (x y n) + (apply * (take (array-dimensions y) + n)))) + x y n)) + +(define (amap2 f x y) + (apply make-batch + (list-tabulate (car (dims-of (unwrap-fwd x))) + (lambda (b) + (f (i:array-cell-ref x b) + (i:array-cell-ref y b)))))) diff --git a/grad-tests.scm b/grad-tests.scm deleted file mode 100644 index bdf8fd4..0000000 --- a/grad-tests.scm +++ /dev/null @@ -1,347 +0,0 @@ -(define-module (vouivre grad tests) - #:use-module ((vouivre grad) #:prefix v:) - #:use-module (ice-9 receive) - #:use-module (srfi srfi-1) - #:use-module (srfi srfi-64) - #:use-module (vouivre misc) - #:export - (apply-grad - apply-grad-amap2 - a~ - differentiable-func-generator - lambda-const-call - ngrad - n~ - random-array - random-array-shape - random-func1 - random-func2 - random-func2-rank&dims>0 - random-input - random-list-element - random-non-empty-array - random-shape - random-shared - random-shared-array-rank&dims>0 - random-shared-contractible - with-generators - ~)) - -(define f1s (list v:exp v:identity)) -(define f2s (list v:+ v:- v:* v:max v:min)) - 
-(define-syntax-rule (with-generators (g1 g2 ...) equal expected given) - (let ((times 100) - (fx expected) - (fy given) - (generators (list g1 g2 ...))) - (call/cc - (lambda (break) - (do ((i 0 (1+ i))) - ((= i times) #t) - (let ((zs (map-in-order (lambda (g) (g)) generators))) - (with-exception-handler - (lambda (e) - (break #f zs)) - (lambda () - (let ((r1 (apply fx zs)) - (r2 (apply fy zs))) - (unless (equal r1 r2) - (break #f zs r1 r2)))) - #:unwind? #t))))))) - -(define (lambda-const-call f . consts) - (lambda _ - (apply f consts))) - -(define* (random-array-shape - #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5)) - (list-tabulate (+ min-rank (random (- max-rank min-rank))) - (lambda _ (+ min-dim (random (- max-dim min-dim)))))) - -(define (random-shape) - (if (= 0 (random 2)) - 0 - (random-array-shape))) - -(define* (random-array #:optional shape) - (apply produce-typed-array - (lambda _ (random:uniform)) - v:*atype* (or shape (random-array-shape)))) - -(define (random-non-empty-array) - "Random array of at least one element." - (random-array (random-array-shape 0 5 1 5))) - -(define* (random-input #:optional shape) - (let ((shape (or shape (random-shape)))) - (if (eq? 0 shape) - (random:uniform) - (random-array shape)))) - -(define (random-shared) - (let ((shape (random-shape))) - (values - (lambda () - (random-input shape)) - (lambda () - (let ((x (random-input - (random-list-element - (list 0 (if (list? shape) - shape - (random-shape))))))) - (set! shape (random-shape)) - x))))) - -(define (random-shared-array-rank&dims>0) - (let ((shape (random-array-shape 1 5 1 5))) - (values - (lambda () - (random-array shape)) - (lambda () - (let ((x (random-array shape))) - (set! shape (random-array-shape 1 5 1 5)) - x))))) - -(define (random-list-element lst) - (list-ref lst (random (length lst)))) - -(define (differentiable-func-generator lst . 
input-generators) - (lambda () - (random-list-element - (cons - (apply - lambda-const-call - (random-list-element lst) - (map (lambda (g) (g)) - input-generators)) - lst)))) - -(define random-func1 - (differentiable-func-generator f1s random-input)) -(define random-func2 - (receive (gx gy) (random-shared) - (differentiable-func-generator f2s gx gy))) - -(define* (n~ x y #:optional (error 1e-4)) - (and - (>= y (- x error)) - (<= y (+ x error)))) - -(define* (a~ x y #:optional (error 1e-4)) - (and - (equal? (array-dimensions x) - (array-dimensions y)) - (call/cc - (lambda (break) - (array-for-each - (lambda (x y) - (unless (~ x y error) - (break #f))) - x y) - #t)))) - -(define* (~ x y #:optional (error 1e-4)) - (cond - ((and (number? x) (number? y)) - (n~ x y error)) - ((and (array? x) (array? y)) - (a~ x y error)) - (else #f))) - -(define* (ngrad f #:optional (axis 0) (step 1e-6)) - "Gradient using a numerical centered difference approximation." - (define (axis-add xs dh . indices) - "Add `dh' to the number or array at the given `axis' of `xs', -and, when it's an array, at the given index." - (map-indexed - (lambda (x i) - (ifn (= i axis) - x - (if (number? x) - (+ x dh) - (array-map-indexed - (lambda (x . indices_) - (ifn (equal? indices indices_) - x - (+ x dh))) - x)))) - xs)) - (lambda xs - ;; We need the output shape and the input shape along the - ;; differentiated axis. - (let ((fxs (apply f xs)) - (x (list-ref xs axis))) - (cond - ((and (number? fxs) - (number? x)) - (/ (- (apply f (axis-add xs step)) - (apply f (axis-add xs (- step)))) - (* 2 step))) - ((and (number? fxs) - (array? x)) - (apply - produce-typed-array - (lambda indices - (/ (- (apply f (apply axis-add xs step indices)) - (apply f (apply axis-add xs (- step) indices))) - (* 2 step))) - v:*atype* - (array-dimensions x))) - ((and (array? fxs) - (number? x)) - ((v:extend /) - ((v:extend -) - (apply f (axis-add xs step)) - (apply f (axis-add xs (- step)))) - (* 2 step))) - ((and (array? 
fxs) - (array? x)) - (let ((a (apply - make-typed-array v:*atype* *unspecified* - (append (array-dimensions fxs) - (array-dimensions x))))) - (for-indices-in-range - (lambda indices-in - (let ((dfxs ((v:extend /) - ((v:extend -) - (apply f (apply axis-add xs step indices-in)) - (apply f (apply axis-add xs (- step) indices-in))) - (* 2 step)))) - (for-indices-in-range - (lambda indices-out - (apply - array-set! - a - (apply array-ref dfxs indices-out) - (append indices-out indices-in))) - (list-zeros (array-rank fxs)) - (array-dimensions fxs)))) - (list-zeros (array-rank x)) - (array-dimensions x)) - a)))))) - -(test-begin "grad") - -;; extended operations -(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity)) -(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp)) -(test-assert - (receive (gx gy) (random-shared) - (with-generators (gx gy) ~ (v:extend *) v:*))) - -(define* (apply-grad grad-func #:optional (axis 0)) - (lambda (f . args) - (apply (grad-func f axis) args))) - -(test-assert - (with-generators - (random-func1 random-input) - ~ (apply-grad ngrad) (apply-grad v:grad))) - -;; `v:mean' only takes non-empty arrays so we treat it separately -(test-assert - (with-generators - ((differentiable-func-generator (list v:mean) random-non-empty-array) - random-non-empty-array) - ~ (apply-grad ngrad) (apply-grad v:grad))) - -(test-assert - (receive (gx gy) (random-shared) - (with-generators - (random-func2 gx gy) - ~ (apply-grad ngrad 0) (apply-grad v:grad 0)))) -(test-assert - (receive (gx gy) (random-shared) - (with-generators - (random-func2 gx gy) - ~ (apply-grad ngrad 1) (apply-grad v:grad 1)))) - -;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it -;; separately -(define* (apply-grad-amap2 grad-func #:optional (axis 0)) - (lambda (f x y) - ((grad-func - (lambda (x y) - (v:amap2 f x y)) - axis) - x y))) -(define random-func2-rank&dims>0 - (receive (gx gy) (random-shared-array-rank&dims>0) - 
(differentiable-func-generator f2s gx gy))) -(test-assert - (receive (gx gy) (random-shared-array-rank&dims>0) - (with-generators - (random-func2-rank&dims>0 gx gy) - ~ (apply-grad-amap2 ngrad 0) (apply-grad-amap2 v:grad 0)))) -(test-assert - (receive (gx gy) (random-shared-array-rank&dims>0) - (with-generators - (random-func2-rank&dims>0 gx gy) - ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1)))) -(test-assert - (~ ((ngrad v:amap2 1) v:* #(1 2 3) #(10 20 30)) - ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30)))) -(test-assert - (let ((x #(10 20 30)) - (y #(10 20 30))) - (~ ((ngrad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)) - ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))))) - -;; `v:adot' -(define (random-shared-contractible) - "Returns three generators: the first two generate arrays that are contractible -according to the number generated by the third one." - (let* ((n (random 5)) - (sa (random-array-shape n)) - (sb (append (reverse (take (reverse sa) - n)) - (random-array-shape 0 (- 5 n))))) - (values - (lambda () - (random-array sa)) - (lambda () - (random-array sb)) - (lambda () - (let ((tmp n)) - (set! n (random 5)) - (set! sa (random-array-shape n)) - (set! sb (append (reverse (take (reverse sa) - n)) - (random-array-shape 0 (- 5 n)))) - tmp))))) -(test-assert - (receive (gx gy gz) (random-shared-contractible) - (with-generators - (gx gy gz) - ~ - (lambda (a b n) ((ngrad v:adot 0) a b n)) - (lambda (a b n) ((v:grad v:adot 0) a b n))))) -(test-assert - (receive (gx gy gz) (random-shared-contractible) - (with-generators - (gx gy gz) - ~ - (lambda (a b n) ((ngrad v:adot 1) a b n)) - (lambda (a b n) ((v:grad v:adot 1) a b n))))) - -;; chain rule -(test-assert - (with-generators - (random-func1 random-func1 random-input) - ~ - (lambda (f g x) - (let* ((gx (g x)) - (r - (v:dot ((v:grad f) gx) - ((v:grad g) x) - (v:rank-of gx)))) - (ifn (and (number? (f (g x))) - (number? x) - (array? 
r)) - r - (array-ref r)))) - (lambda (f g x) - ((v:grad (compose f g)) x)))) - -(test-end "grad") diff --git a/grad.scm b/grad.scm deleted file mode 100644 index 9e4ec76..0000000 --- a/grad.scm +++ /dev/null @@ -1,699 +0,0 @@ -(define-module (vouivre grad) - #:use-module (ice-9 receive) - #:use-module (srfi srfi-1) - #:use-module (srfi srfi-9) - #:use-module (vouivre misc) - #:use-module (vouivre promises) - #:export - (*atype* - adot - amap2 - contract-arrays - differentiable-wrapper - dot - do-times - ewise1 - ewise2 - extend - grad - make-batch - make-internal - maximum - mean - rank-of - sum) - #:replace - ((i:sqrt . sqrt) - (i:exp . exp) - (i:expt . expt) - (i:log . log) - (i:sin . sin) - (i:cos . cos) - (i:tan . tan) - (i:+ . +) - (i:- . -) - (i:* . *) - (i:/ . /) - (i:max . max) - (i:min . min) - (i:abs . abs) - (i:identity . identity) - (i:array-ref . array-ref) - (i:array-cell-ref . array-cell-ref)) - #:re-export - (fold - reduce)) - -;;;; array utilities - -(define (rel->abs indices dimensions) - (let rec ((s 0) - (p 1) - (is (reverse indices)) - (ds (reverse dimensions))) - (if (null? is) - s - (rec (+ s (* p (car is))) - (* p (car ds)) - (cdr is) - (cdr ds))))) - -(define(do-times n proc) - (let rec ((i 0)) - (unless (= i n) - (proc i) - (rec (1+ i))))) - -(define (contract-arrays a b n) - (let* ((dims-a (array-dimensions a)) - (dims-b (array-dimensions b)) - (free-dims-a (take dims-a (- (array-rank a) n))) - (free-dims-b (drop dims-b n)) - (bound-dims (take dims-b n)) - (n-free-dims-a (apply * free-dims-a)) - (n-free-dims-b (apply * free-dims-b)) - (n-bound-dims (apply * bound-dims)) - (s 0) - (r (apply make-typed-array *atype* *unspecified* (append free-dims-a - free-dims-b))) - (ac (array-contents a)) - (bc (array-contents b)) - (rc (array-contents r))) - (do-times - n-free-dims-a - (lambda (i) - (let ((i-k (* n-bound-dims i)) - (i-j (* n-free-dims-b i))) - (do-times - n-free-dims-b - (lambda (j) - (set! 
s 0) - (do-times - n-bound-dims - (lambda (k) - (set! s (+ s (* (array-ref ac (+ i-k k)) - (array-ref bc (+ (* n-free-dims-b k) j))))))) - (array-set! rc s (+ i-j j))))))) - r)) - -;;;; utilities that work on both numbers and arrays - -(define (extend f) - "Extend a function of one or more scalars to apply to numbers/arrays -element-wise. All arrays must have the same dimension." - (define (apply-elemwise f indices args) - (apply f (map (lambda (x) - (if (number? x) - x - (apply array-ref x indices))) - args))) - (lambda xs - (if-let (x (find array? xs)) - (apply - produce-typed-array - (lambda is - (apply-elemwise f is xs)) - *atype* - (array-dimensions x)) - (apply f xs)))) - -(define (dot x y n) - (cond - ((and (number? x) (number? y)) - (* x y)) - ((and (array? x) (array? y)) - (contract-arrays x y n)) - ((and (array? x) (number? y)) - ((extend *) x y)) - ((and (number? x) (array? y)) - ((extend *) x y)) - (else (error "can't dot because of invalid types or ranks" x y n)))) - -(define (rank-of x) - (if (number? x) - 0 - (array-rank x))) - -;;;; differentiation - -(define-record-type internal - (make-internal forward jacobian) - internal? - (forward internal-forward) - (jacobian internal-jacobian)) - -;;(define *atype* 'f32) -(define *atype* #t) -(define *grad* (make-parameter #f)) -(define *j* (make-parameter #f)) - -(define-syntax-rule (w/j val body ...) - (parameterize ((*j* val)) - body ...)) - -(define (unwrap-fwd x) - (if (internal? x) - (internal-forward x) - x)) - -(define (unwrap-jac x) - (if (internal? x) - (internal-jacobian x) - x)) - -(define (dims-of x) - (if (number? x) - '() - (array-dimensions x))) - -(define (mov dst-buf n-dst-dims generator naked-inputs data j) - (do-times - n-dst-dims - (lambda (i) - (array-set! - dst-buf - (apply generator naked-inputs i j data) - i))) - dst-buf) - -(define (add dst-buf n-dst-dims generator naked-inputs data j) - (do-times - n-dst-dims - (lambda (i) - (array-set! 
- dst-buf - (+ (array-ref dst-buf i) - (apply generator naked-inputs i j data)) - i))) - dst-buf) - -(define (movc dst-buf n-dst-dims src-buf n-src-dims - generator naked-inputs data) - "Contract the Jacobian column produced by the generator with the source buffer -storing the result in the destination buffer." - (let ((s 0)) - (do-times - n-dst-dims - (lambda (i) - (set! s 0) - (do-times - n-src-dims - (lambda (k) - (set! s (+ s (* (apply generator naked-inputs i k data) - (array-ref src-buf k)))))) - (array-set! dst-buf s i)))) - dst-buf) - -(define (addc dst-buf n-dst-dims src-buf n-src-dims - generator naked-inputs data) - "Contract the Jacobian column produced by the generator with the source buffer -adding the result to the destination buffer." - (let ((s 0)) - (do-times - n-dst-dims - (lambda (i) - (set! s (array-ref dst-buf i)) - (do-times - n-src-dims - (lambda (k) - (set! s (+ s (* (apply generator naked-inputs i k data) - (array-ref src-buf k)))))) - (array-set! dst-buf s i)))) - dst-buf) - -(define* (grad f #:optional (axis 0)) - (define (wrap x i) - (if (= i axis) - (make-internal x 'input) - x)) - (lambda xs - (parameterize (((@@ (vouivre grad) *grad*) #t) - ((@@ (vouivre grad) *promises*) (cons '() #f))) - (let* ((internal (apply f (map-indexed wrap xs))) - (fx (internal-forward internal)) - (y (list-ref xs axis)) ; variable to differentiate w.r.t - (pre-Jx (internal-jacobian internal)) - (Jx (cond - ;; TODO: implement 'input case and test 'zero and 'input - ((eq? pre-Jx 'zero) - (lambda (j) - (lambda (i) - 0))) - ((eq? pre-Jx 'input) - (error "TBD.")) - (else - (lambda (j) - (reset-promises (car (*promises*))) - (let ((column-jac (w/j j (force pre-Jx)))) - (lambda (i) - (array-ref column-jac i)))))))) - (cond - ((and (number? fx) (number? y)) - ((Jx 0) 0)) - ((and (number? fx) (array? 
y)) - (let* ((y-dims (array-dimensions y)) - (a (apply make-array *unspecified* y-dims)) - (ac (array-contents a))) - (do-times - (apply * y-dims) - (lambda (j) - (array-set! ac ((Jx j) 0) - j))) - a)) - ((and (array? fx) (number? y)) - (let* ((fx-dims (array-dimensions fx)) - (a (apply make-array *unspecified* fx-dims)) - (ac (array-contents a)) - (Jx (Jx 0))) - (do-times - (apply * fx-dims) - (lambda (i) - (array-set! ac (Jx i) - i))) - a)) - (else - (let* ((fx-dims (array-dimensions fx)) - (y-dims (array-dimensions y)) - (n-fx-dims (apply * fx-dims)) - (n-y-dims (apply * y-dims)) - (a (apply make-array *unspecified* (append fx-dims y-dims))) - (ac (array-contents a))) - (do-times - n-y-dims - (lambda (j) - (let ((Jx (Jx j))) - (do-times - n-fx-dims - (lambda (i) - (array-set! ac (Jx i) - (+ j (* n-y-dims i)))))))) - a))))))) - - -;; In the comment that follows: -;; -;; `n' is the number of arguments to `proc'. -;; `generators is not a `Vec' but a `List' we only use the former to illustrate -;; its length. -;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension. -;; `I' is the type of multi-indices indexing the output of `function'. -;; `J' is the type of multi-indices indexing the input array being differentiated. -;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J'). -;; `Array I' is the type of arrays indexed by multi-indices of `I'. -;; `[X]' means that `X' is boxed in an internal as when returned by -;; `differentiable-wrapper' with the array being `X' and the promise that -;; given a |J| we will get the change of `X' with a change of the -;; the differentiated argument at multi-index `J'. -;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number)) -;; (→ X1 ... Xn (Array I))* -;; [X1] ... 
[Xn] -;; (Internal (Array I) (Promise |J| (Array |I|))))) -;; -;; (*) We extend this definition to allow `proc' to be a list of procedures -;; the head of which is as described above and the remaining elements -;; are procedures of the same arguments but returning values that are -;; then fed as extra data to the generators. -;; -;; NOTE: In cases where an argument isn't meant to be differentiable its -;; corresponding generator should be `#f'. -(define (differentiable-wrapper generators proc* arg . more) - (let* ((args (cons arg more)) - (proc (if (procedure? proc*) - proc* - (car proc*))) - (naked-args (map unwrap-fwd args)) - (out (apply proc naked-args))) - (ifn (*grad*) - out - (let* ((data (if (procedure? proc*) - '() - (map (lambda (g) - (apply g naked-args)) - (cdr proc*)))) - (n-out-dims (apply * (dims-of out))) - (buf (make-array *unspecified* n-out-dims))) - (make-internal - out - (fold - (lambda (generator arg prev) - (if (or (not (internal? arg)) - (eq? 'zero (internal-jacobian arg))) - prev - (let ((Jx (internal-jacobian arg)) - (n-fwd-dims (apply * (dims-of (unwrap-fwd arg))))) - (if (eq? Jx 'input) - (if (eq? prev 'zero) - (delay - (mov buf n-out-dims - generator naked-args data (*j*))) - (delay - (add (force prev) n-out-dims - generator naked-args data (*j*)))) - (if (eq? prev 'zero) - (delay - (movc buf n-out-dims (force Jx) n-fwd-dims - generator naked-args data)) - (delay - (addc (force prev) n-out-dims - (force Jx) n-fwd-dims - generator naked-args data))))))) - 'zero generators args)))))) - -(define (ewise1 f) - (lambda (xs i j) - (let ((x (car xs))) - (if (number? x) - (f x) - (ifn (= i j) - 0 - (f (array-ref (array-contents x) - j))))))) - -(define (ewise2 proc axis) - (lambda (xs i j) - (let ((x (car xs)) - (y (cadr xs))) - (cond - ((and (number? x) (number? y)) - (proc x y)) - ((and (number? x) (array? 
y)) - (if (= axis 0) - (proc x (array-ref (array-contents y) - i)) - (ifn (= i j) - 0 - (proc x (array-ref (array-contents y) - j))))) - ((and (array? x) (number? y)) - (if (= axis 1) - (proc (array-ref (array-contents x) - i) - y) - (ifn (= i j) - 0 - (proc (array-ref (array-contents x) - j) - y)))) - (else - (ifn (= i j) - 0 - (proc (array-ref (array-contents x) - j) - (array-ref (array-contents y) - j)))))))) - -(define (i:identity x) - "Differentiable identity." - (differentiable-wrapper - (list (ewise1 (lambda _ 1))) - identity - x)) - -(define (i:sqrt x) - "Differentiable square root." - (differentiable-wrapper - (list (ewise1 (lambda (x) (/ 1 2 (sqrt x))))) - (extend sqrt) - x)) - -(define (i:exp x) - "Differentiable exponential." - (differentiable-wrapper - (list (ewise1 exp)) - (extend exp) - x)) - -(define (i:expt x y) - "Differentiable power." - (differentiable-wrapper - (list (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0) - (ewise2 (lambda (x y) (* (expt x y) (log x))) 1)) - (extend expt) - x y)) - -(define (i:log x) - "Differentiable logarithm." - (differentiable-wrapper - (list (ewise1 (lambda (x) (/ x)))) - (extend log) - x)) - -(define (i:sin x) - "Differentiable sine." - (differentiable-wrapper - (list (ewise1 cos)) - (extend sin) - x)) - -(define (i:cos x) - "Differentiable cosine." - (differentiable-wrapper - (list (ewise1 (lambda (x) (- (sin x))))) - (extend cos) - x)) - -(define (i:tan x) - "Differentiable tangent." - (differentiable-wrapper - (list (ewise1 (lambda (x) (/ (expt (cos x) 2))))) - (extend tan) - x)) - -(define (i:+ x y) - "Differentiable element-wise addition." - (differentiable-wrapper - (list - (ewise2 (lambda _ +1) 0) - (ewise2 (lambda _ +1) 1)) - (extend +) - x y)) - -(define (i:- x y) - "Differentiable element-wise subtraction." - (differentiable-wrapper - (list - (ewise2 (lambda _ +1) 0) - (ewise2 (lambda _ -1) 1)) - (extend -) - x y)) - -(define (i:* x y) - "Differentiable element-wise multiplication." 
- (differentiable-wrapper - (list - (ewise2 (lambda (x y) y) 0) - (ewise2 (lambda (x y) x) 1)) - (extend *) - x y)) - -(define (i:/ x y) - "Differentiable element-wise division." - (differentiable-wrapper - (list - (ewise2 (lambda (x y) (/ y)) 0) - (ewise2 (lambda (x y) (- (/ x y y))) 1)) - (extend /) - x y)) - -(define (i:max x y) - "Differentiable element-wise maximum." - (define (dmax x y) - (cond - ((> x y) - 1) - ((= x y) - 1/2) - (else - 0))) - (differentiable-wrapper - (list - (ewise2 dmax 0) - (ewise2 (flip dmax) 1)) - (extend max) - x y)) - -(define (i:min x y) - "Differentiable element-wise minimum." - (define (dmin x y) - (cond - ((< x y) - 1) - ((= x y) - 1/2) - (else - 0))) - (differentiable-wrapper - (list - (ewise2 dmin 0) - (ewise2 (flip dmin) 1)) - (extend min) - x y)) - -(define (i:abs x) - "Differentiable absolute." - (differentiable-wrapper - (list (ewise1 (lambda (x) - (cond ((> x 0) - +1) - ((= x 0) - 1/2) - ((< x 0) - -1))))) - (extend abs) - x)) - -(define (mean x) - "Differentiable mean on arrays." - (differentiable-wrapper - (list - (lambda (xs i j one-over-n) - one-over-n)) - (let ((n 0)) - (list - (lambda (x) - (let ((sum 0)) - (array-for-each - (lambda (x) - (set! sum (+ sum x)) - (set! n (1+ n))) - x) - (/ sum n))) - (lambda _ (/ n)))) - x)) - -(define (i:array-ref x . indices) - "Differentiable array-ref w.r.t `x'." - (apply - differentiable-wrapper - (cons - (lambda (xs i j abs-index) - (if (= j abs-index) - 1 - 0)) - (map not indices)) - (list - array-ref - (lambda (x . indices) - (rel->abs indices (array-dimensions x)))) - x indices)) - -(define (i:array-cell-ref x . indices) - (apply - differentiable-wrapper - (cons - (lambda (xs i j abs-index n-rst-dims) - (receive (j-ref j-rst) (euclidean/ j n-rst-dims) - (if (and (= j-ref abs-index) - (= j-rst i)) - 1 - 0))) - (map not indices)) - (list - array-cell-ref - (lambda (x . indices) - (rel->abs indices (take (array-dimensions x) - (length indices)))) - (lambda (x . 
indices) - (apply * (drop (array-dimensions x) - (length indices))))) - x indices)) - -(define (make-batch elem . more) - (let ((batch-size (1+ (length more)))) - (apply - differentiable-wrapper - (list-tabulate - batch-size - (lambda (b) - (lambda (xs i j n-rest-dims) - (receive (i-batch i-rest) (euclidean/ i n-rest-dims) - (if (and (= i-batch b) - (= i-rest j)) - 1 - 0 - ))))) - (list - (lambda (elem . more) - (let ((a (apply make-typed-array *atype* *unspecified* batch-size - (dims-of elem)))) - (for-each - (lambda (x b) - (array-cell-set! a x b)) - (cons elem more) - (list-tabulate batch-size identity)) - a)) - (lambda (elem . more) - (apply * (dims-of elem)))) - elem more))) - -(define (maximum x) - "Differentiable maximum on arrays." - (differentiable-wrapper - (list - (lambda (xs i j max-index) - (if (= j max-index) - 1 - 0))) - (let ((max-index 'TBD)) - (list - (lambda (x) - (let ((m (- (inf))) - (i 0)) - (array-for-each - (lambda (x) - (when (< m x) - (set! m x) - (set! max-index i)) - (set! i (1+ i))) - x) - m)) - (lambda _ max-index))) - x)) - -(define (sum x) - "Differentiable sum on arrays." - (differentiable-wrapper - (list (lambda _ 1)) - (lambda (x) - (let ((sum 0)) - (array-for-each - (lambda (x) - (set! 
sum (+ sum x))) - x) - sum)) - x)) - -(define (adot x y n) - (differentiable-wrapper - (list - (lambda (xs i j n-free-dims-y n-bound-dims) - (receive (i-x i-y) (euclidean/ i n-free-dims-y) - (receive (j-free j-bound) (euclidean/ j n-bound-dims) - (ifn (= i-x j-free) - 0 - (array-ref (array-contents (cadr xs)) - (+ i-y (* n-free-dims-y j-bound))))))) - (lambda (xs i j n-free-dims-y n-bound-dims) - (receive (i-x i-y) (euclidean/ i n-free-dims-y) - (receive (j-bound j-free) (euclidean/ j n-free-dims-y) - (ifn (= i-y j-free) - 0 - (array-ref (array-contents (car xs)) - (+ j-bound (* n-bound-dims i-x))))))) - #f) - (list - contract-arrays - (lambda (x y n) - (apply * (drop (array-dimensions y) - n))) - (lambda (x y n) - (apply * (take (array-dimensions y) - n)))) - x y n)) - -(define (amap2 f x y) - (apply make-batch - (list-tabulate (car (dims-of (unwrap-fwd x))) - (lambda (b) - (f (i:array-cell-ref x b) - (i:array-cell-ref y b))))))