--- /dev/null
+(define-module (vouivre grad tests)
+ #:use-module ((vouivre grad) #:prefix v:)
+ #:use-module (ice-9 receive)
+ #:use-module (srfi srfi-1)
+ #:use-module (srfi srfi-64)
+ #:use-module (vouivre misc)
+ #:export
+ (apply-grad
+ apply-grad-amap2
+ a~
+ differentiable-func-generator
+ lambda-const-call
+ ngrad
+ n~
+ random-array
+ random-array-shape
+ random-func1
+ random-func2
+ random-input
+ random-list-element
+ random-non-empty-array
+ random-shape
+ random-shared
+ random-shared-array-rank&dims>0
+ with-generators
+ ~))
+
+(define f1s (list v:exp v:identity))
+(define f2s (list v:* v:- v:max))
+
+(define-syntax-rule (with-generators (g1 g2 ...) equal expected given)
+ (let ((times 100)
+ (fx expected)
+ (fy given)
+ (generators (list g1 g2 ...)))
+ (call/cc
+ (lambda (break)
+ (do ((i 0 (1+ i)))
+ ((= i times) #t)
+ (let ((zs (map-in-order (lambda (g) (g)) generators)))
+ (let ((r1 (apply fx zs))
+ (r2 (apply fy zs)))
+ (unless (equal r1 r2)
+ (break #f zs r1 r2)))))))))
+
+(define (lambda-const-call f . consts)
+ (lambda _
+ (apply f consts)))
+
+(define* (random-array-shape
+ #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5))
+ (list-tabulate (+ min-rank (random (- max-rank min-rank)))
+ (lambda _ (+ min-dim (random (- max-dim min-dim))))))
+
+(define (random-shape)
+ (if (= 0 (random 2))
+ 0
+ (random-array-shape)))
+
+(define* (random-array #:optional shape)
+ (let ((a (apply make-array 0 (or shape (random-array-shape)))))
+ (array-map! a random:uniform)
+ a))
+
+(define (random-non-empty-array)
+ "Random array of at least one element."
+ (random-array (random-array-shape 0 5 1 5)))
+
+(define* (random-input #:optional shape)
+ (let ((shape (or shape (random-shape))))
+ (if (eq? 0 shape)
+ (random:uniform)
+ (random-array shape))))
+
+(define (random-shared)
+ (let ((shape (random-shape)))
+ (values
+ (lambda ()
+ (random-input shape))
+ (lambda ()
+ (let ((x (random-input
+ (random-list-element
+ (list 0 (if (list? shape)
+ shape
+ (random-shape)))))))
+ (set! shape (random-shape))
+ x)))))
+
+(define (random-shared-array-rank&dims>0)
+ (let ((shape (random-array-shape 1 5 1 5)))
+ (values
+ (lambda ()
+ (random-array shape))
+ (lambda ()
+ (let ((x (random-array shape)))
+ (set! shape (random-array-shape 1 5 1 5))
+ x)))))
+
+(define (random-list-element lst)
+ (list-ref lst (random (length lst))))
+
+(define (differentiable-func-generator lst . input-generators)
+ (lambda ()
+ (random-list-element
+ (cons
+ (apply
+ lambda-const-call
+ (random-list-element lst)
+ (map (lambda (g) (g))
+ input-generators))
+ lst))))
+
+(define random-func1
+ (differentiable-func-generator f1s random-input))
+(define random-func2
+ (receive (gx gy) (random-shared)
+ (differentiable-func-generator f2s gx gy)))
+
+(define* (n~ x y #:optional (error 1e-4))
+ (and
+ (>= y (- x error))
+ (<= y (+ x error))))
+
+(define* (a~ x y #:optional (error 1e-4))
+ (and
+ (equal? (array-dimensions x)
+ (array-dimensions y))
+ (call/cc
+ (lambda (break)
+ (array-for-each
+ (lambda (x y)
+ (unless (~ x y error)
+ (break #f)))
+ x y)
+ #t))))
+
+(define* (~ x y #:optional (error 1e-4))
+ (cond
+ ((and (number? x) (number? y))
+ (n~ x y error))
+ ((and (array? x) (array? y))
+ (a~ x y error))
+ (else #f)))
+
+(define* (ngrad f #:optional (axis 0) (step 1e-6))
+ "Gradient using a numerical centered difference approximation."
+ (define (axis-add xs dh . indices)
+ "Add `dh' to the number or array at the given `axis' of `xs',
+and, when it's an array, at the given index."
+ (map-indexed
+ (lambda (x i)
+ (ifn (= i axis)
+ x
+ (if (number? x)
+ (+ x dh)
+ (array-map-indexed
+ (lambda (x . indices_)
+ (ifn (equal? indices indices_)
+ x
+ (+ x dh)))
+ x))))
+ xs))
+ (lambda xs
+ ;; We need the output shape and the input shape along the
+ ;; differentiated axis.
+ (let ((fxs (apply f xs))
+ (x (list-ref xs axis)))
+ (cond
+ ((and (number? fxs)
+ (number? x))
+ (/ (- (apply f (axis-add xs step))
+ (apply f (axis-add xs (- step))))
+ (* 2 step)))
+ ((and (number? fxs)
+ (array? x))
+ (apply
+ produce-array
+ (lambda indices
+ (/ (- (apply f (apply axis-add xs step indices))
+ (apply f (apply axis-add xs (- step) indices)))
+ (* 2 step)))
+ (array-dimensions x)))
+ ((and (array? fxs)
+ (number? x))
+ ((v:extend /)
+ ((v:extend -)
+ (apply f (axis-add xs step))
+ (apply f (axis-add xs (- step))))
+ (* 2 step)))
+ ((and (array? fxs)
+ (array? x))
+ (let ((a (apply
+ make-array
+ 0
+ (append (array-dimensions fxs)
+ (array-dimensions x)))))
+ (for-indices-in-range
+ (lambda indices-in
+ (let ((dfxs ((v:extend /)
+ ((v:extend -)
+ (apply f (apply axis-add xs step indices-in))
+ (apply f (apply axis-add xs (- step) indices-in)))
+ (* 2 step))))
+ (for-indices-in-range
+ (lambda indices-out
+ (apply
+ array-set!
+ a
+ (apply array-ref dfxs indices-out)
+ (append indices-out indices-in)))
+ (list-zeros (array-rank fxs))
+ (array-dimensions fxs))))
+ (list-zeros (array-rank x))
+ (array-dimensions x))
+ a))))))
+
+(test-begin "grad")
+
+;; extended operations
+(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity))
+(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp))
+(test-assert
+ (receive (gx gy) (random-shared)
+ (with-generators (gx gy) ~ (v:extend *) v:*)))
+
+(define* (apply-grad grad-func #:optional (axis 0))
+ (lambda (f . args)
+ (apply (grad-func f axis) args)))
+
+(test-assert
+ (with-generators
+ (random-func1 random-input)
+ ~ (apply-grad ngrad) (apply-grad v:grad)))
+
+;; `v:mean' only takes non-empty arrays so we treat it separately
+(test-assert
+ (with-generators
+ ((differentiable-func-generator (list v:mean) random-non-empty-array)
+ random-non-empty-array)
+ ~ (apply-grad ngrad) (apply-grad v:grad)))
+
+(test-assert
+ (receive (gx gy) (random-shared)
+ (with-generators
+ (random-func2 gx gy)
+ ~ (apply-grad ngrad 0) (apply-grad v:grad 0))))
+(test-assert
+ (receive (gx gy) (random-shared)
+ (with-generators
+ (random-func2 gx gy)
+ ~ (apply-grad ngrad 1) (apply-grad v:grad 1))))
+
+;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it
+;; separately
+(define* (apply-grad-amap2 grad-func #:optional (axis 0))
+ (lambda (f x y)
+ ((grad-func
+ (lambda (x y)
+ (v:amap2 f x y))
+ axis)
+ x y)))
+(define random-func2-rank&dims>0
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+ (differentiable-func-generator f2s gx gy)))
+(test-assert
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+ (with-generators
+ (random-func2-rank&dims>0 gx gy)
+ ~ (apply-grad-amap2 ngrad 0) (apply-grad-amap2 v:grad 0))))
+(test-assert
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+ (with-generators
+ (random-func2-rank&dims>0 gx gy)
+ ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1))))
+
+;; `v:adot'
+(define (random-shared-contractible)
+ "Returns three generators: the first two generate arrays that are contractible
+according to the number generated by the third one."
+ (let* ((n (random 5))
+ (sa (random-array-shape n))
+ (sb (append (reverse (take (reverse sa)
+ n))
+ (random-array-shape 0 (- 5 n)))))
+ (values
+ (lambda ()
+ (random-array sa))
+ (lambda ()
+ (random-array sb))
+ (lambda ()
+ (let ((tmp n))
+ (set! n (random 5))
+ (set! sa (random-array-shape n))
+ (set! sb (append (reverse (take (reverse sa)
+ n))
+ (random-array-shape 0 (- 5 n))))
+ tmp)))))
+(test-assert
+ (receive (gx gy gz) (random-shared-contractible)
+ (with-generators
+ (gx gy gz)
+ ~
+ (lambda (a b n) ((ngrad v:adot 0) a b n))
+ (lambda (a b n) ((v:grad v:adot 0) a b n)))))
+
+;; chain rule
+(test-assert
+ (with-generators
+ (random-func1 random-func1 random-input)
+ ~
+ (lambda (f g x)
+ (let* ((gx (g x))
+ (r
+ (v:dot ((v:grad f) gx)
+ ((v:grad g) x)
+ (v:rank-of gx))))
+ (ifn (and (number? (f (g x)))
+ (number? x)
+ (array? r))
+ r
+ (array-ref r))))
+ (lambda (f g x)
+ ((v:grad (compose f g)) x))))
+
+(test-end "grad")
--- /dev/null
+(define-module (vouivre grad)
+ #:use-module (srfi srfi-1)
+ #:use-module (srfi srfi-9)
+ #:use-module (vouivre misc)
+ #:export
+ (adot
+ amap2
+ differentiable-wrapper
+ dot
+ extend
+ grad
+ internal-jacobian
+ mean
+ mirror
+ one
+ rank-of)
+ #:replace
+ ((i:* . *)
+ (i:- . -)
+ (i:exp . exp)
+ (i:fold . fold)
+ (i:identity . identity)
+ (i:max . max)))
+
+;;;; misc utilities
+
+(define (one . args)
+ 1)
+
+;;;; array utilities
+
+(define (contracted-dims a b n)
+ (let ((dims-a (array-dimensions a))
+ (dims-b (array-dimensions b)))
+ (if (or (> n (array-rank a))
+ (> n (array-rank b)))
+ (error "can't contract arrays with size lower than" n)
+ (append (reverse (drop (reverse dims-a)
+ n))
+ (drop dims-b n)))))
+
+(define (contract-arrays a b n)
+ (let* ((dims (contracted-dims a b n))
+ (dims-b (array-dimensions b))
+ (nb-dims-a (array-rank a))
+ (nb-dims-b (array-rank b))
+ (nb-fix-dims-a (- nb-dims-a n))
+ (r (apply make-array 0 dims)))
+ (for-indices-in-range
+ (lambda r-indices
+ (apply
+ array-set!
+ r
+ (let ((s 0))
+ (for-indices-in-range
+ (lambda free-indices
+ (set! s (+ s (* (apply
+ array-ref
+ a
+ (append (take r-indices nb-fix-dims-a)
+ free-indices))
+ (apply
+ array-ref
+ b
+ (append free-indices
+ (drop r-indices nb-fix-dims-a)))))))
+ (list-zeros n)
+ (take dims-b n))
+ s)
+ r-indices))
+ (list-zeros (length dims))
+ dims)
+ r))
+
+;;;; utilities that work on both numbers and arrays
+
+(define (extend f)
+ "Extend a function of one or more scalars to apply to numbers/arrays
+element-wise. All arrays must have the same dimension."
+ (define (apply-elemwise f indices args)
+ (apply f (map (lambda (x)
+ (if (number? x)
+ x
+ (apply array-ref x indices)))
+ args)))
+ (lambda xs
+ (if-let (x (find array? xs))
+ (apply
+ produce-array
+ (lambda is
+ (apply-elemwise f is xs))
+ (array-dimensions x))
+ (apply f xs))))
+
+(define (dot x y n)
+ (cond
+ ((and (number? x) (number? y))
+ (* x y))
+ ((and (array? x) (array? y))
+ (contract-arrays x y n))
+ ((and (array? x) (number? y))
+ ((extend *) x y))
+ ((and (number? x) (array? y))
+ ((extend *) x y))
+ (else (error "can't dot because of invalid types or ranks" x y n))))
+
+(define (mirror f x . args)
+ "Create a scalar/array of shape [x]:[x] where element with index `is:is' has
+the value of `f' evaluated at the `is' elements of `args' and all other elements
+being zero."
+ (cond
+ ((number? x)
+ (apply f args))
+ ((array? x)
+ (let ((n (array-rank x))
+ (dims (array-dimensions x)))
+ (apply
+ produce-array
+ (lambda indices
+ (if (equal? (take indices n)
+ (drop indices n))
+ (apply f (map (lambda (arg)
+ (apply array-ref arg (drop indices n)))
+ args))
+ 0))
+ (append dims dims))))
+ (else (error "expected array or number, got " x))))
+
+(define (rank-of x)
+ (if (number? x)
+ 0
+ (array-rank x)))
+
+(define (zeros-out:in x y)
+ "Zeros in the shape of [x]:[y]."
+ (cond
+ ((and (number? x)
+ (number? y))
+ 0)
+ ((and (array? x)
+ (array? y))
+ (apply make-array 0 (append (array-dimensions x)
+ (array-dimensions y))))
+ ((and (number? x)
+ (array? y))
+ (apply make-array 0 (array-dimensions y)))
+ ((and (array? x)
+ (number? y))
+ (apply make-array 0 (array-dimensions x)))
+ (else (error "undefined." x y))))
+
+;;;; differentiation
+
+(define-record-type internal
+ (make-internal jacobian forward)
+ internal?
+ (jacobian internal-jacobian set-internal-jacobian!)
+ (forward internal-forward))
+
+(define *grad* (make-parameter #f))
+
+(define* (grad f #:optional (axis 0))
+ (define (wrap x i)
+ (if (= i axis)
+ (make-internal (mirror one x) x)
+ x))
+ (lambda xs
+ ((if (*grad*) identity internal-jacobian)
+ (let ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity))))
+ (parameterize (((@@ (vouivre grad) *grad*) (list-ref wrapped-xs axis)))
+ (apply f wrapped-xs))))))
+
+(define (grad-input)
+ (internal-forward (*grad*)))
+
+(define (differentiable-wrapper jacobian-generator function input . more)
+ ;; NOTE: Both the jacobian generator and the function act on naked inputs
+ ;; (numbers or arrays not inside an internal object). The generator
+ ;; returns a list of jacobians, each one expressing the change of the
+ ;; function's output when changing the corresponding input. In cases
+ ;; where an argument isn't meant to be differentiable its corresponding
+ ;; element in the list should be `#f'.
+ (let* ((inputs (cons input more))
+ (naked-inputs
+ (map
+ (lambda (x)
+ (if (internal? x)
+ (internal-forward x)
+ x))
+ inputs))
+ (output (apply function naked-inputs)))
+ (if (*grad*)
+ (make-internal
+ (or
+ (fold
+ (lambda (Jx x prev)
+ (if (not (internal? x))
+ prev
+ ((if (not prev)
+ identity
+ (lambda (x) ((extend +) prev x)))
+ (dot Jx
+ (internal-jacobian x)
+ (rank-of (internal-forward x))))))
+ #f
+ (apply jacobian-generator naked-inputs)
+ inputs)
+ (zeros-out:in output (grad-input)))
+ output)
+ output)))
+
+(define (i:identity x)
+ "Differentiable identity."
+ (differentiable-wrapper
+ (lambda (x) (list (mirror one x)))
+ identity
+ x))
+
+(define (i:exp x)
+ "Differentiable exponential."
+ (differentiable-wrapper
+ (lambda (x) (list (mirror exp x x)))
+ (extend exp)
+ x))
+
+(define (i:* x y)
+ "Differentiable element-wise multiplication."
+ (differentiable-wrapper
+ (lambda (x y)
+ (list
+ (cond
+ ((and (number? y) (number? x))
+ y)
+ ((and (number? y) (array? x))
+ (mirror (lambda _ y) x))
+ ((and (array? y) (number? x))
+ y)
+ (else
+ (mirror identity x y)))
+ (cond
+ ((and (number? x) (number? y))
+ x)
+ ((and (number? x) (array? y))
+ (mirror (lambda _ x) y))
+ ((and (array? x) (number? y))
+ x)
+ (else
+ (mirror identity y x)))))
+ (extend *)
+ x y))
+
+(define (i:- x y)
+ "Differentiable element-wise subtraction."
+ (differentiable-wrapper
+ (lambda (x y)
+ (list
+ (cond
+ ((and (number? y) (number? x))
+ +1)
+ ((and (number? y) (array? x))
+ (mirror one x))
+ ((and (array? y) (number? x))
+ (apply make-array 1 (array-dimensions y)))
+ (else
+ (mirror one x)))
+ (cond
+ ((and (number? x) (number? y))
+ -1)
+ ((and (number? x) (array? y))
+ (mirror (lambda _ -1) y))
+ ((and (array? x) (number? y))
+ (apply make-array -1 (array-dimensions x)))
+ (else
+ (mirror (lambda _ -1) y)))))
+ (extend -)
+ x y))
+
+(define (i:max x y)
+ "Differentiable element-wise maximum."
+ (define (dmax x y)
+ (cond
+ ((> x y)
+ 1)
+ ((= x y)
+ 1/2)
+ (else
+ 0)))
+ (differentiable-wrapper
+ (lambda (x y)
+ (list
+ (cond
+ ((and (number? y) (number? x))
+ (dmax x y))
+ ((and (number? y) (array? x))
+ (mirror (lambda (x) (dmax x y)) x x))
+ ((and (array? y) (number? x))
+ (array-map (lambda (y) (dmax x y)) y))
+ (else
+ (mirror (lambda (x y) (dmax x y)) x x y)))
+ (cond
+ ((and (number? x) (number? y))
+ (dmax y x))
+ ((and (number? x) (array? y))
+ (mirror (lambda (y) (dmax y x)) y y))
+ ((and (array? x) (number? y))
+ (array-map (lambda (x) (dmax y x)) x))
+ (else
+ (mirror (lambda (y x) (dmax y x)) y y x)))))
+ (extend max)
+ x y))
+
+(define (mean x)
+ "Differentiable mean on arrays."
+ (differentiable-wrapper
+ (lambda (x)
+ (let ((dims (array-dimensions x)))
+ (list
+ (apply
+ make-array
+ (/ 1 (apply * dims))
+ dims))))
+ (lambda (x)
+ (let ((sum 0)
+ (count 0))
+ (array-for-each
+ (lambda (x)
+ (set! sum (+ sum x))
+ (set! count (1+ count)))
+ x)
+ (/ sum count)))
+ x))
+
+(define (amap2 f x y)
+ (define (unbox-with proc x)
+ (ifn (internal? x)
+ x
+ (proc x)))
+ (define (boxed-ref x i)
+ (ifn (internal? x)
+ (array-cell-ref x i)
+ (make-internal (array-cell-ref (internal-jacobian x)
+ i)
+ (array-cell-ref (internal-forward x)
+ i))))
+ (define (boxed-fi f i x y)
+ (f (boxed-ref x i)
+ (boxed-ref y i)))
+ (define (dims-of x)
+ (if (number? x)
+ '()
+ (array-dimensions x)))
+ (if (internal? f)
+ (error "unsupported application of `amap2' to internal object" f)
+ (let ((bs (first (array-dimensions (unbox-with internal-forward x))))
+ (f0 (boxed-fi f 0 x y)))
+ (ifn (internal? f0)
+ (let ((fwd (apply make-array 0 bs (dims-of f0))))
+ (for-indices-in-range
+ (lambda (i)
+ (array-cell-set! fwd (boxed-fi f i x y) i))
+ '(0) (list bs))
+ fwd)
+ (let ((jac (apply make-array 0 bs (dims-of (internal-jacobian f0))))
+ (fwd (apply make-array 0 bs (dims-of (internal-forward f0)))))
+ (for-indices-in-range
+ (lambda (i)
+ (let ((fi (boxed-fi f i x y)))
+ (array-cell-set! jac (internal-jacobian fi) i)
+ (array-cell-set! fwd (internal-forward fi) i)))
+ '(0) (list bs))
+ (make-internal jac fwd))))))
+
+(define (adot x y n)
+ (differentiable-wrapper
+ (lambda (x y n)
+ (let ((bound-dims (contracted-dims x y n)))
+ (list
+ (let ((r (apply make-array 0 (append bound-dims (array-dimensions x)))))
+ (for-indices-in-range
+ (lambda free-indices
+ (let ((free-indices-x (take free-indices (- (array-rank x) n)))
+ (free-indices-y (drop free-indices (- (array-rank x) n))))
+ (for-indices-in-range
+ (lambda bound-indices
+ (apply
+ array-set!
+ r
+ (apply array-ref y (append bound-indices free-indices-y))
+ (append free-indices free-indices-x bound-indices)))
+ (list-zeros n)
+ (take (array-dimensions y) n))))
+ (list-zeros (length bound-dims))
+ bound-dims)
+ r)
+ (let ((r (apply make-array 0 (append bound-dims (array-dimensions y)))))
+ (for-indices-in-range
+ (lambda free-indices
+ (let ((free-indices-x (take free-indices (- (array-rank x) n)))
+ (free-indices-y (drop free-indices (- (array-rank x) n))))
+ (for-indices-in-range
+ (lambda bound-indices
+ (apply
+ array-set!
+ r
+ (apply array-ref x (append free-indices-x bound-indices))
+ (append free-indices bound-indices free-indices-y)))
+ (list-zeros n)
+ (take (array-dimensions y) n))))
+ (list-zeros (length bound-dims))
+ bound-dims)
+ r)
+ #f)))
+ contract-arrays x y n))