(define-module (vouivre grad)
+ #:use-module (ice-9 receive)
#:use-module (srfi srfi-1)
#:use-module (srfi srfi-9)
#:use-module (vouivre misc)
grad
internal-jacobian
mean
- mirror
- one
rank-of)
#:replace
((i:* . *)
(i:identity . identity)
(i:max . max)))
-;;;; misc utilities
-
-(define (one . args)
- 1)
-
;;;; array utilities
(define (contracted-dims a b n)
((extend *) x y))
(else (error "can't dot because of invalid types or ranks" x y n))))
-(define (mirror f x . args)
- "Create a scalar/array of shape [x]:[x] where element with index `is:is' has
-the value of `f' evaluated at the `is' elements of `args' and all other elements
-being zero."
- (cond
- ((number? x)
- (apply f args))
- ((array? x)
- (let ((n (array-rank x))
- (dims (array-dimensions x)))
- (apply
- produce-array
- (lambda indices
- (if (equal? (take indices n)
- (drop indices n))
- (apply f (map (lambda (arg)
- (apply array-ref arg (drop indices n)))
- args))
- 0))
- (append dims dims))))
- (else (error "expected array or number, got " x))))
-
(define (rank-of x)
(if (number? x)
0
;;;; differentiation
(define-record-type internal
- (make-internal jacobian forward)
+ (make-internal forward jacobian)
internal?
- (jacobian internal-jacobian set-internal-jacobian!)
- (forward internal-forward))
+ (forward internal-forward)
+ (jacobian internal-jacobian))
(define *grad* (make-parameter #f))
(define* (grad f #:optional (axis 0))
(define (wrap x i)
(if (= i axis)
- (make-internal (mirror one x) x)
+ (make-internal x #f)
x))
(lambda xs
- ((if (*grad*) identity internal-jacobian)
- (let ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity))))
- (parameterize (((@@ (vouivre grad) *grad*) (list-ref wrapped-xs axis)))
+ (let* ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity)))
+ (y (internal-forward (list-ref wrapped-xs axis))))
+ ((if (*grad*)
+ identity
+ (lambda (boxed)
+ (let ((fx (internal-forward boxed))
+ (gfx (internal-jacobian boxed)))
+ (cond
+ ((and (number? fx) (number? y))
+ (gfx '()))
+ ((and (number? fx) (array? y))
+ (apply produce-array
+ (lambda js (gfx js))
+ (array-dimensions y)))
+ ((and (array? fx) (number? y))
+ (gfx '()))
+ (else
+ (let* ((dims-out (array-dimensions fx))
+ (dims-in (array-dimensions y))
+ (a (apply make-array 0 (append dims-out dims-in))))
+ (for-indices-in-range
+ (lambda js
+ (let ((out (gfx js)))
+ (for-indices-in-range
+ (lambda is
+ (apply
+ array-set!
+ a
+ (apply array-ref out is)
+ (append is js)))
+ (list-zeros (array-rank fx))
+ dims-out)))
+ (list-zeros (array-rank y))
+ dims-in)
+ a)
+ )))))
+ (parameterize (((@@ (vouivre grad) *grad*) y))
(apply f wrapped-xs))))))
-(define (grad-input)
- (internal-forward (*grad*)))
+(define (unbox-fwd x)
+ (if (internal? x)
+ (internal-forward x)
+ x))
-(define (differentiable-wrapper jacobian-generator function input . more)
- ;; NOTE: Both the jacobian generator and the function act on naked inputs
- ;; (numbers or arrays not inside an internal object). The generator
- ;; returns a list of jacobians, each one expressing the change of the
- ;; function's output when changing the corresponding input. In cases
- ;; where an argument isn't meant to be differentiable its corresponding
- ;; element in the list should be `#f'.
+;; `n' is the number of arguments to `function'.
+;; `jacobian-generators is not a `Vec' but a `List' we only use the former to
+;; show its length.
+;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
+;; `I' is the type of multi-indices indexing the output of `function'.
+;; `J' is the type of multi-indices indexing the input array being differentiated.
+;; `Array I' is the type of arrays indexed by multi-indices of `I'.
+;; `[X]' means that `X' is boxed in an internal as when returned by
+;; `differentiable-wrapper' with the array being `X' and the function
+;; taking a multi-index of `J' to return an array of the same shape as `X'.
+;; (∷ (→ (Vec n (→ X1 ... Xn I J Number))
+;; (→ X1 ... Xn (Array I))
+;; [X1] ... [Xn]
+;; (Internal (Array I) (→ J (Array I)))))
+(define (differentiable-wrapper jacobian-generators function input . more)
+ ;; NOTE: Both the jacobian generators and the function act on naked inputs
+ ;; (numbers or arrays not inside an internal object). The generators
+ ;; additionaly take indices -- one for each dimension of what we are
+ ;; differentiating with respect to (the content of `(*grad*)'). A
+ ;; generator returns a jacobian column expressing the change of the
+ ;; function's output when changing `(*grad*)' along the given indices.
+ ;; In cases where an argument isn't meant to be differentiable its
+ ;; corresponding generator should be `#f'.
(let* ((inputs (cons input more))
- (naked-inputs
- (map
- (lambda (x)
- (if (internal? x)
- (internal-forward x)
- x))
- inputs))
+ (naked-inputs (map unbox-fwd inputs))
(output (apply function naked-inputs)))
- (if (*grad*)
- (make-internal
- (or
- (fold
- (lambda (Jx x prev)
- (if (not (internal? x))
- prev
- ((if (not prev)
- identity
- (lambda (x) ((extend +) prev x)))
- (dot Jx
- (internal-jacobian x)
- (rank-of (internal-forward x))))))
- #f
- (apply jacobian-generator naked-inputs)
- inputs)
- (zeros-out:in output (grad-input)))
- output)
- output)))
+ (ifn (*grad*)
+ output
+ (make-internal
+ output
+ (lambda (js)
+ (or
+ (fold
+ (lambda (jacobian-generator input prev)
+ (ifn (internal? input)
+ prev
+ ((ifn prev identity (lambda (x) ((extend +) prev x)))
+ (ifn (internal-jacobian input)
+ (if (number? output)
+ (apply jacobian-generator
+ (append naked-inputs js))
+ (apply
+ produce-array
+ (lambda is
+ (apply jacobian-generator
+ (append naked-inputs is js)))
+ (array-dimensions output)))
+ (let ((b ((internal-jacobian input) js))
+ (fwd (internal-forward input)))
+ (cond
+ ((and (number? output) (number? fwd))
+ (* (apply jacobian-generator naked-inputs)
+ b))
+ ((and (number? output) (array? fwd))
+ (array-ref
+ (contract-arrays
+ (apply
+ produce-array
+ (lambda ks
+ (apply jacobian-generator
+ (append naked-inputs ks)))
+ (array-dimensions fwd))
+ b (array-rank fwd))))
+ ((and (array? output) (number? fwd))
+ (apply
+ produce-array
+ (lambda is
+ ((extend *)
+ (apply jacobian-generator
+ (append naked-inputs is))
+ b))
+ (array-dimensions output)))
+ (else
+ (apply
+ produce-array
+ (lambda is
+ (array-ref
+ (contract-arrays
+ (apply
+ produce-array
+ (lambda ks
+ (apply jacobian-generator
+ (append naked-inputs is ks)))
+ (array-dimensions fwd))
+ b (array-rank fwd))))
+ (array-dimensions output)))))))))
+ #f jacobian-generators inputs)
+ (if (number? output)
+ 0
+ (apply make-array 0 (array-dimensions output)))))))))
+
+(define (ewise1 f)
+ (lambda (x . indices)
+ (if (number? x)
+ (f x)
+ (receive (is js) (split-at indices (array-rank x))
+ (if (equal? is js)
+ (f (apply array-ref x is))
+ 0)))))
+
+(define (ewise2 proc axis)
+ (lambda (x y . indices)
+ (cond
+ ((and (number? x) (number? y))
+ (proc x y))
+ ((and (number? x) (array? y))
+ (if (= axis 0)
+ (proc x (apply array-ref y indices))
+ (receive (is js) (split-at indices (array-rank y))
+ (ifn (equal? is js)
+ 0
+ (proc x (apply array-ref y is))))))
+ ((and (array? x) (number? y))
+ (if (= axis 1)
+ (proc (apply array-ref x indices) y)
+ (receive (is js) (split-at indices (array-rank x))
+ (ifn (equal? is js)
+ 0
+ (proc (apply array-ref x is)
+ y)))))
+ (else
+ (receive (is js) (split-at indices (array-rank x))
+ (ifn (equal? is js)
+ 0
+ (proc (apply array-ref x is)
+ (apply array-ref y is))))))))
(define (i:identity x)
"Differentiable identity."
(differentiable-wrapper
- (lambda (x) (list (mirror one x)))
+ (list (ewise1 (lambda _ 1)))
identity
x))
(define (i:exp x)
"Differentiable exponential."
(differentiable-wrapper
- (lambda (x) (list (mirror exp x x)))
+ (list (ewise1 exp))
(extend exp)
x))
(define (i:* x y)
"Differentiable element-wise multiplication."
(differentiable-wrapper
- (lambda (x y)
- (list
- (cond
- ((and (number? y) (number? x))
- y)
- ((and (number? y) (array? x))
- (mirror (lambda _ y) x))
- ((and (array? y) (number? x))
- y)
- (else
- (mirror identity x y)))
- (cond
- ((and (number? x) (number? y))
- x)
- ((and (number? x) (array? y))
- (mirror (lambda _ x) y))
- ((and (array? x) (number? y))
- x)
- (else
- (mirror identity y x)))))
+ (list
+ (ewise2 (lambda (x y) y) 0)
+ (ewise2 (lambda (x y) x) 1))
(extend *)
x y))
(define (i:- x y)
"Differentiable element-wise subtraction."
(differentiable-wrapper
- (lambda (x y)
- (list
- (cond
- ((and (number? y) (number? x))
- +1)
- ((and (number? y) (array? x))
- (mirror one x))
- ((and (array? y) (number? x))
- (apply make-array 1 (array-dimensions y)))
- (else
- (mirror one x)))
- (cond
- ((and (number? x) (number? y))
- -1)
- ((and (number? x) (array? y))
- (mirror (lambda _ -1) y))
- ((and (array? x) (number? y))
- (apply make-array -1 (array-dimensions x)))
- (else
- (mirror (lambda _ -1) y)))))
+ (list
+ (ewise2 (lambda _ +1) 0)
+ (ewise2 (lambda _ -1) 1))
(extend -)
x y))
(else
0)))
(differentiable-wrapper
- (lambda (x y)
- (list
- (cond
- ((and (number? y) (number? x))
- (dmax x y))
- ((and (number? y) (array? x))
- (mirror (lambda (x) (dmax x y)) x x))
- ((and (array? y) (number? x))
- (array-map (lambda (y) (dmax x y)) y))
- (else
- (mirror (lambda (x y) (dmax x y)) x x y)))
- (cond
- ((and (number? x) (number? y))
- (dmax y x))
- ((and (number? x) (array? y))
- (mirror (lambda (y) (dmax y x)) y y))
- ((and (array? x) (number? y))
- (array-map (lambda (x) (dmax y x)) x))
- (else
- (mirror (lambda (y x) (dmax y x)) y y x)))))
+ (list
+ (ewise2 dmax 0)
+ (ewise2 (flip dmax) 1))
(extend max)
x y))
(define (mean x)
"Differentiable mean on arrays."
(differentiable-wrapper
- (lambda (x)
- (let ((dims (array-dimensions x)))
- (list
- (apply
- make-array
- (/ 1 (apply * dims))
- dims))))
+ (list
+ (lambda (x . indices)
+ (/ 1 (apply * (array-dimensions x)))))
(lambda (x)
(let ((sum 0)
(count 0))
(/ sum count)))
x))
+;; ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30))
+;; (let ((x #2((1 2) (3 4) (5 6))) (y #2((1 2) (3 4) (5 6)))) ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)))
(define (amap2 f x y)
(define (unbox-with proc x)
(ifn (internal? x)
x
(proc x)))
+(define (dims-of x)
+ (if (number? x)
+ '()
+ (array-dimensions x)))
(define (boxed-ref x i)
(ifn (internal? x)
(array-cell-ref x i)
- (make-internal (array-cell-ref (internal-jacobian x)
+ (make-internal (array-cell-ref (internal-forward x)
i)
- (array-cell-ref (internal-forward x)
- i))))
+ (if (internal-jacobian x)
+ (lambda (js)
+ (array-cell-ref
+ ((internal-jacobian x)
+ js)
+ i))
+ (lambda (js)
+ (let* ((x (internal-forward x))
+ (xi (array-cell-ref x i)))
+ (if (number? xi)
+ (if (= i (car js))
+ 1
+ 0)
+ (let ((a (apply make-array 0 (dims-of xi))))
+ (for-indices-in-range
+ (lambda out
+ (when (and (= i (car js))
+ (equal? out (cdr js)))
+ (apply
+ array-set!
+ a 1
+ out)))
+ (list-zeros (rank-of xi))
+ (dims-of xi))
+ a))))))))
(define (boxed-fi f i x y)
(f (boxed-ref x i)
(boxed-ref y i)))
- (define (dims-of x)
- (if (number? x)
- '()
- (array-dimensions x)))
(if (internal? f)
(error "unsupported application of `amap2' to internal object" f)
(let ((bs (first (array-dimensions (unbox-with internal-forward x))))
(array-cell-set! fwd (boxed-fi f i x y) i))
'(0) (list bs))
fwd)
- (let ((jac (apply make-array 0 bs (dims-of (internal-jacobian f0))))
+ (let ((jac (make-array 0 bs))
(fwd (apply make-array 0 bs (dims-of (internal-forward f0)))))
(for-indices-in-range
(lambda (i)
- (let ((fi (boxed-fi f i x y)))
- (array-cell-set! jac (internal-jacobian fi) i)
+ (let* ((fi (boxed-fi f i x y))
+ (Jfi (internal-jacobian fi)))
+ (array-set! jac (internal-jacobian fi) i)
(array-cell-set! fwd (internal-forward fi) i)))
'(0) (list bs))
- (make-internal jac fwd))))))
+ (make-internal
+ fwd
+ (lambda (js)
+ (let ((a (apply make-array 0 bs (dims-of (internal-forward f0)))))
+ (for-indices-in-range
+ (lambda (batch)
+ (array-cell-set!
+ a ((array-ref jac batch) js)
+ batch))
+ '(0) (list bs))
+ a))))))))
(define (adot x y n)
(differentiable-wrapper
- (lambda (x y n)
- (let ((bound-dims (contracted-dims x y n)))
- (list
- (let ((r (apply make-array 0 (append bound-dims (array-dimensions x)))))
- (for-indices-in-range
- (lambda free-indices
- (let ((free-indices-x (take free-indices (- (array-rank x) n)))
- (free-indices-y (drop free-indices (- (array-rank x) n))))
- (for-indices-in-range
- (lambda bound-indices
- (apply
- array-set!
- r
- (apply array-ref y (append bound-indices free-indices-y))
- (append free-indices free-indices-x bound-indices)))
- (list-zeros n)
- (take (array-dimensions y) n))))
- (list-zeros (length bound-dims))
- bound-dims)
- r)
- (let ((r (apply make-array 0 (append bound-dims (array-dimensions y)))))
- (for-indices-in-range
- (lambda free-indices
- (let ((free-indices-x (take free-indices (- (array-rank x) n)))
- (free-indices-y (drop free-indices (- (array-rank x) n))))
- (for-indices-in-range
- (lambda bound-indices
- (apply
- array-set!
- r
- (apply array-ref x (append free-indices-x bound-indices))
- (append free-indices bound-indices free-indices-y)))
- (list-zeros n)
- (take (array-dimensions y) n))))
- (list-zeros (length bound-dims))
- bound-dims)
- r)
- #f)))
+ (list
+ (lambda (x y n . indices)
+ (let* ((free-rank-x (- (array-rank x)
+ n))
+ (free-rank-y (- (array-rank y)
+ n))
+ (out (take indices (+ free-rank-x free-rank-y)))
+ (in (drop indices (+ free-rank-x free-rank-y)))
+ (is-free1 (take out free-rank-x))
+ (js-free (drop out free-rank-x))
+ (is-free2 (take in free-rank-x))
+ (is-bound (drop in free-rank-x)))
+ (ifn (equal? is-free1 is-free2)
+ 0
+ (apply array-ref y (append is-bound js-free)))))
+ (lambda (x y n . indices)
+ (let* ((free-rank-x (- (array-rank x)
+ n))
+ (free-rank-y (- (array-rank y)
+ n))
+ (out (take indices (+ free-rank-x free-rank-y)))
+ (in (drop indices (+ free-rank-x free-rank-y)))
+ (is-free (take out free-rank-x))
+ (js-free1 (drop out free-rank-x))
+ (js-free2 (drop in n))
+ (js-bound (take in n)))
+ (ifn (equal? js-free1 js-free2)
+ 0
+ (apply array-ref x (append is-free js-bound)))))
+ #f)
contract-arrays x y n))