--- /dev/null
+(define-module (vouivre autodiff tests)
+ #:use-module ((vouivre autodiff) #:prefix v:)
+ #:use-module (ice-9 receive)
+ #:use-module (srfi srfi-1)
+ #:use-module (srfi srfi-64)
+ #:use-module (vouivre misc)
+ #:export
+ (apply-diff
+ a~
+ const-generator
+ differentiable-func-generator
+ lambda-const-call
+ ndiff
+ n~
+ random-array
+ random-array-shape
+ random-func1
+ random-func2
+ random-func2-rank&dims>0
+ random-input
+ random-list-element
+ random-non-empty-array
+ random-shape
+ random-shared
+ random-shared-array-rank&dims>0
+ random-shared-contractible
+ with-generators
+ ~))
+
+;; Pools of differentiable primitives, grouped by arity (one and two
+;; arguments); `random-func1' and `random-func2' draw their functions from them.
+(define f1s (list v:abs v:cos v:exp v:identity v:sin))
+(define f2s (list v:+ v:- v:* v:max v:min))
+
+(define (with-generators% generators equal proc1 proc2 . more)
+  "Return #t when, over 100 rounds, every procedure returns a value that is
+`equal' to the value returned by `proc1'.  Each round draws one argument per
+generator, in order (so the number of generators is the number of arguments
+of each procedure).  On the first disagreement, escape with the values #f,
+the arguments and the results; on an exception, with #f and the arguments."
+  (define rounds 100)
+  (define candidates (cons* proc1 proc2 more))
+  (define (draw-arguments)
+    (map-in-order (lambda (generate) (generate)) generators))
+  (call/cc
+   (lambda (return)
+     (let loop ((round 0))
+       (if (= round rounds)
+           #t
+           (let ((inputs (draw-arguments)))
+             (with-exception-handler
+                 (lambda (condition)
+                   (return #f inputs))
+               (lambda ()
+                 (let* ((results (map (lambda (candidate)
+                                        (apply candidate inputs))
+                                      candidates))
+                        (reference (car results)))
+                   (unless (every (lambda (result) (equal result reference))
+                                  (cdr results))
+                     (return #f inputs results))))
+               #:unwind? #t)
+             (loop (1+ round))))))))
+
+;; Syntactic front-end to `with-generators%': gathers the parenthesised
+;; generators into a list and forwards the remaining operands unchanged.
+(define-syntax-rule (with-generators (g1 g2 ...) equal expected given more ...)
+ (with-generators% (list g1 g2 ...) equal expected given more ...))
+
+(define (lambda-const-call f . consts)
+  "Return a procedure that ignores whatever arguments it receives and applies
+`f' to `consts' each time it is called."
+  (lambda ignored
+    (apply f consts)))
+
+(define* (random-array-shape
+          #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5))
+  "Return a random list of dimensions.  Its length (the rank) lies in
+[min-rank, max-rank) and every dimension lies in [min-dim, max-dim)."
+  (define (random-in lower upper)
+    ;; `upper' is exclusive.
+    (+ lower (random (- upper lower))))
+  (let ((rank (random-in min-rank max-rank)))
+    (list-tabulate rank
+                   (lambda (position)
+                     (random-in min-dim max-dim)))))
+
+(define (random-shape)
+  "Return either 0, which stands for a scalar, or a random array shape; the
+choice is made by drawing from `(random 2)'."
+  (case (random 2)
+    ((0) 0)
+    (else (random-array-shape))))
+
+(define* (random-array #:optional shape)
+  "Build, through `produce-typed-array' with type `v:*atype*', an array whose
+elements come from `random:uniform'.  The dimensions are `shape' when given,
+otherwise a fresh random array shape."
+  (let ((dimensions (or shape (random-array-shape)))
+        (draw (lambda ignored (random:uniform))))
+    (apply produce-typed-array draw v:*atype* dimensions)))
+
+(define (random-non-empty-array)
+  "Random array holding at least one element: every dimension is at least 1."
+  (let ((shape (random-array-shape 0 5 1 5)))
+    (random-array shape)))
+
+(define* (random-input #:optional shape)
+  "Return a random number when the shape is 0 and a random array of that
+shape otherwise.  A missing `shape' is drawn with `random-shape'."
+  (let ((chosen (or shape (random-shape))))
+    (cond ((eq? 0 chosen) (random:uniform))
+          (else (random-array chosen)))))
+
+(define (random-shared)
+  "Return two generators that share a hidden shape.  The first yields an input
+of the current shape.  The second yields either a scalar or an input of the
+current shape (of a fresh random shape when the current one is the scalar
+marker 0), and then redraws the shared shape for the next round."
+  (let ((shape (random-shape)))
+    (define (first-input)
+      (random-input shape))
+    (define (second-input)
+      (let* ((array-shape (if (list? shape)
+                              shape
+                              (random-shape)))
+             (picked (random-list-element (list 0 array-shape)))
+             (input (random-input picked)))
+        (set! shape (random-shape))
+        input))
+    (values first-input second-input)))
+
+(define (random-shared-array-rank&dims>0)
+  "Return two generators of random arrays sharing one shape of rank at least 1
+with every dimension at least 1.  The shape is redrawn after each call of
+the second generator."
+  (define (fresh-shape)
+    (random-array-shape 1 5 1 5))
+  (let ((shape (fresh-shape)))
+    (values
+     (lambda ()
+       (random-array shape))
+     (lambda ()
+       (let ((result (random-array shape)))
+         (set! shape (fresh-shape))
+         result)))))
+
+(define (random-list-element lst)
+  "Return the element of `lst' found at a position drawn with `random'."
+  (let ((position (random (length lst))))
+    (list-ref lst position)))
+
+(define (const-generator generator)
+  "Return a thunk that always yields `generator' itself; the value is
+returned as is, never called."
+  (define (constant-thunk)
+    generator)
+  constant-thunk)
+
+(define (differentiable-func-generator lst . input-generators)
+  "Return a generator of functions.  Each call yields, at random, either a
+member of `lst' or a constant function that ignores its arguments and applies
+a random member of `lst' to inputs drawn from `input-generators'."
+  (lambda ()
+    (let* ((chosen (random-list-element lst))
+           (inputs (map (lambda (generate) (generate))
+                        input-generators))
+           (constant (apply lambda-const-call chosen inputs)))
+      (random-list-element (cons constant lst)))))
+
+;; Generator of one-argument test functions drawn from `f1s'.
+(define random-func1
+ (differentiable-func-generator f1s random-input))
+;; Generator of two-argument test functions drawn from `f2s'; the inputs of its
+;; constant-function variants come from a `random-shared' generator pair.
+(define random-func2
+ (receive (gx gy) (random-shared)
+ (differentiable-func-generator f2s gx gy)))
+
+(define* (n~ x y #:optional (error 1e-4))
+  "True when the number `y' lies within the absolute tolerance `error' of the
+number `x'."
+  (<= (- x error) y (+ x error)))
+
+(define* (a~ x y #:optional (error 1e-4))
+  "True when the arrays `x' and `y' have equal dimensions and all their
+corresponding elements are approximately equal, per `~', within `error'."
+  (define (same-shape?)
+    (equal? (array-dimensions x)
+            (array-dimensions y)))
+  (define (elements-close?)
+    (call/cc
+     (lambda (return)
+       (array-for-each
+        (lambda (ex ey)
+          (if (not (~ ex ey error))
+              (return #f)))
+        x y)
+       #t)))
+  (and (same-shape?)
+       (elements-close?)))
+
+(define* (~ x y #:optional (error 1e-4))
+  "Approximate equality: two numbers are compared with `n~', two arrays with
+`a~'; any other combination is never equal."
+  (let ((both (lambda (pred) (and (pred x) (pred y)))))
+    (cond ((both number?) (n~ x y error))
+          ((both array?) (a~ x y error))
+          (else #f))))
+
+(define* (ndiff f #:optional (axis 0) (step 1e-6))
+ "Differentiation using a numerical centered difference approximation."
+ ;; Returns a procedure taking the same arguments as `f'.  It approximates the
+ ;; derivative of `f' with respect to its argument number `axis' by
+ ;; (f(.. + step ..) - f(.. - step ..)) / (2 * step).
+ (define (axis-add xs dh . indices)
+ "Add `dh' to the number or array at the given `axis' of `xs',
+and, when it's an array, at the given index."
+ (map-indexed
+ (lambda (x i)
+ (ifn (= i axis)
+ x
+ (if (number? x)
+ (+ x dh)
+ (array-map-indexed
+ (lambda (x . indices_)
+ (ifn (equal? indices indices_)
+ x
+ (+ x dh)))
+ x))))
+ xs))
+ (lambda xs
+ ;; We need the output shape and the input shape along the
+ ;; differentiated axis.
+ (let ((fxs (apply f xs))
+ (x (list-ref xs axis)))
+ (cond
+ ;; number -> number: the result is a single number.
+ ((and (number? fxs)
+ (number? x))
+ (/ (- (apply f (axis-add xs step))
+ (apply f (axis-add xs (- step))))
+ (* 2 step)))
+ ;; array -> number: one difference quotient per input element; the result
+ ;; has the dimensions of the input.
+ ((and (number? fxs)
+ (array? x))
+ (apply
+ produce-typed-array
+ (lambda indices
+ (/ (- (apply f (apply axis-add xs step indices))
+ (apply f (apply axis-add xs (- step) indices)))
+ (* 2 step)))
+ v:*atype*
+ (array-dimensions x)))
+ ;; number -> array: the quotient is taken element-wise on the outputs.
+ ((and (array? fxs)
+ (number? x))
+ ((v:extend /)
+ ((v:extend -)
+ (apply f (axis-add xs step))
+ (apply f (axis-add xs (- step))))
+ (* 2 step)))
+ ;; array -> array: the result has the output dimensions followed by the
+ ;; input dimensions; it is filled one perturbed input element at a time.
+ ((and (array? fxs)
+ (array? x))
+ (let ((a (apply
+ make-typed-array v:*atype* *unspecified*
+ (append (array-dimensions fxs)
+ (array-dimensions x)))))
+ (for-indices-in-range
+ (lambda indices-in
+ (let ((dfxs ((v:extend /)
+ ((v:extend -)
+ (apply f (apply axis-add xs step indices-in))
+ (apply f (apply axis-add xs (- step) indices-in)))
+ (* 2 step))))
+ (for-indices-in-range
+ (lambda indices-out
+ (apply
+ array-set!
+ a
+ (apply array-ref dfxs indices-out)
+ (append indices-out indices-in)))
+ (list-zeros (array-rank fxs))
+ (array-dimensions fxs))))
+ (list-zeros (array-rank x))
+ (array-dimensions x))
+ a))))))
+
+(define* (apply-diff differentiator #:optional (axis 0))
+  "Convenience wrapper: return a procedure that takes a function followed by
+its arguments, differentiates the function along `axis' with `differentiator'
+(`ndiff', `fdiff' or `rdiff') and applies the derivative to the arguments."
+  (lambda (f . args)
+    (let ((derivative (differentiator f axis)))
+      (apply derivative args))))
+
+(test-begin "autodiff")
+
+;; not differentiating: outside of `fdiff'/`rdiff' the differentiable
+;; primitives must agree with the plain extended procedures
+(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity))
+(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp))
+(test-assert
+ (receive (gx gy) (random-shared)
+ (with-generators (gx gy) ~ (v:extend *) v:*)))
+
+;; differentiation in one variable: the numerical (`ndiff'), forward
+;; (`v:fdiff') and reverse (`v:rdiff') derivatives must agree
+(test-assert
+ (with-generators
+ (random-func1 random-input)
+ ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff)))
+
+;; `v:mean' only takes non-empty arrays so we treat it separately
+(test-assert
+ (with-generators
+ ((differentiable-func-generator (list v:mean) random-non-empty-array)
+ random-non-empty-array)
+ ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff)))
+
+;; differentiation in two variables, first along axis 0 then along axis 1
+(test-assert
+ (receive (gx gy) (random-shared)
+ (with-generators
+ (random-func2 gx gy)
+ ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0))))
+(test-assert
+ (receive (gx gy) (random-shared)
+ (with-generators
+ (random-func2 gx gy)
+ ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1))))
+
+;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it
+;; separately
+;; Generator of two-argument functions whose constant-function variants are
+;; fed arrays produced by `random-shared-array-rank&dims>0'.
+(define random-func2-rank&dims>0
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+ (differentiable-func-generator f2s gx gy)))
+(test-assert
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+ (with-generators
+ ((const-generator v:amap2) random-func2-rank&dims>0 gx gy)
+ ;; NOTE: for `v:amap2' the differentiable axes are 1 and 2.
+ ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1))))
+(test-assert
+ (receive (gx gy) (random-shared-array-rank&dims>0)
+ (with-generators
+ ((const-generator v:amap2) random-func2-rank&dims>0 gx gy)
+ ~ (apply-diff ndiff 2) (apply-diff v:fdiff 2) (apply-diff v:rdiff 2))))
+;; The mapped procedure ignores its batch elements and uses the closed-over,
+;; differentiated `a' instead.
+(let* ((z #(1 2 3))
+ (f (lambda (a)
+ (v:amap2 (lambda (x y)
+ (v:* a a))
+ #(10 20 30)
+ #(40 50 60))))
+ (e ((ndiff f) z)))
+ (test-assert (~ e ((v:fdiff f) z)))
+ (test-assert (~ e ((v:rdiff f) z))))
+
+;; `v:adot'
+(define (random-shared-contractible)
+  "Return three generators.  The first two yield arrays such that the last
+`n' dimensions of the first equal the first `n' dimensions of the second; the
+third yields that `n' and then redraws `n' and both shapes for the next
+round, so it has to be called after the other two."
+  (let ((n #f)
+        (sa #f)
+        (sb #f))
+    (define (redraw!)
+      (set! n (random 5))
+      (set! sa (random-array-shape n))
+      (set! sb (append (take-right sa n)
+                       (random-array-shape 0 (- 5 n)))))
+    (redraw!)
+    (values
+     (lambda ()
+       (random-array sa))
+     (lambda ()
+       (random-array sb))
+     (lambda ()
+       (let ((current n))
+         (redraw!)
+         current)))))
+;; `v:adot' differentiated along its first array argument (axis 0); the
+;; third generator supplies the number of contracted dimensions
+(test-assert
+ (receive (gx gy gz) (random-shared-contractible)
+ (with-generators
+ ((const-generator v:adot) gx gy gz)
+ ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0))))
+;; same, along the second array argument (axis 1)
+(test-assert
+ (receive (gx gy gz) (random-shared-contractible)
+ (with-generators
+ ((const-generator v:adot) gx gy gz)
+ ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1))))
+
+;; let binding re-entry: the intermediate value `c' is used twice
+(test-assert
+ (with-generators
+ ((const-generator
+ (lambda (x)
+ (let ((c (v:maximum x)))
+ (v:+ c (v:- x c)))))
+ random-non-empty-array)
+ ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff)))
+
+;; chain rule: differentiate the composition of two random functions
+(test-assert
+ (with-generators
+ (random-func1 random-func1 random-input)
+ ~
+ (lambda (f g x) ((ndiff (compose f g)) x))
+ (lambda (f g x) ((v:fdiff (compose f g)) x))
+ (lambda (f g x) ((v:rdiff (compose f g)) x))))
+
+(test-end "autodiff")
--- /dev/null
+(define-module (vouivre autodiff)
+ #:use-module (ice-9 receive)
+ #:use-module (srfi srfi-1)
+ #:use-module (srfi srfi-9)
+ #:use-module (vouivre misc)
+ #:use-module (vouivre promises)
+ #:export
+ (*atype*
+ adot
+ amap2
+ contract-arrays
+ differentiable-wrapper
+ dot
+ do-times
+ ewise1
+ ewise2
+ extend
+ fdiff
+ rdiff
+ make-batch
+ make-internal
+ maximum
+ mean
+ rank-of
+ sum)
+ #:replace
+ ((i:sqrt . sqrt)
+ (i:exp . exp)
+ (i:expt . expt)
+ (i:log . log)
+ (i:sin . sin)
+ (i:cos . cos)
+ (i:tan . tan)
+ (i:+ . +)
+ (i:- . -)
+ (i:* . *)
+ (i:/ . /)
+ (i:max . max)
+ (i:min . min)
+ (i:abs . abs)
+ (i:identity . identity)
+ (i:array-ref . array-ref)
+ (i:array-cell-ref . array-cell-ref))
+ #:re-export
+ (fold
+ reduce))
+
+;;;; array utilities
+
+(define (rel->abs indices dimensions)
+  "Convert the multi-index `indices' into an absolute offset for the given
+`dimensions'; the last index varies fastest (row-major order)."
+  (do ((rev-indices (reverse indices) (cdr rev-indices))
+       (rev-dims (reverse dimensions) (cdr rev-dims))
+       ;; `stride' is the number of elements skipped by one step of the
+       ;; index currently at the head of `rev-indices'.
+       (stride 1 (* stride (car rev-dims)))
+       (offset 0 (+ offset (* stride (car rev-indices)))))
+      ((null? rev-indices) offset)))
+
+(define (do-times n proc)
+  "Call `proc' on each of the integers 0, 1, ..., n-1, in increasing order."
+  (do ((i 0 (1+ i)))
+      ((= i n))
+    (proc i)))
+
+(define (contract-arrays a b n)
+  "Contract the last `n' dimensions of array `a' with the first `n' dimensions
+of array `b'.  The result is a new array whose dimensions are the remaining
+dimensions of `a' followed by the remaining dimensions of `b'."
+  (let* ((a-dims (array-dimensions a))
+         (b-dims (array-dimensions b))
+         (a-free (take a-dims (- (array-rank a) n)))
+         (b-free (drop b-dims n))
+         (rows (apply * a-free))
+         (cols (apply * b-free))
+         (depth (apply * (take b-dims n)))
+         (result (apply make-typed-array *atype* *unspecified*
+                        (append a-free b-free)))
+         (a-flat (array-contents a))
+         (b-flat (array-contents b))
+         (result-flat (array-contents result)))
+    ;; All three arrays are addressed through their flat contents: `a' as
+    ;; rows x depth, `b' as depth x cols and the result as rows x cols.
+    (define (entry row col)
+      (let accumulate ((k 0)
+                       (acc 0))
+        (if (= k depth)
+            acc
+            (accumulate
+             (1+ k)
+             (+ acc (* (array-ref a-flat (+ (* depth row) k))
+                       (array-ref b-flat (+ (* cols k) col))))))))
+    (do-times
+     rows
+     (lambda (row)
+       (do-times
+        cols
+        (lambda (col)
+          (array-set! result-flat (entry row col) (+ (* cols row) col))))))
+    result))
+
+;;;; utilities that work on both numbers and arrays
+
+(define (extend f)
+  "Lift `f', a function of one or more scalars, so that it applies
+element-wise to any mix of numbers and arrays.  All the arrays must have the
+same dimensions; numbers are reused at every position."
+  (lambda xs
+    (if-let (template (find array? xs))
+      (let ((element-at
+             (lambda (x indices)
+               (if (number? x)
+                   x
+                   (apply array-ref x indices)))))
+        (apply produce-typed-array
+               (lambda indices
+                 (apply f (map (lambda (x) (element-at x indices))
+                               xs)))
+               *atype*
+               (array-dimensions template)))
+      (apply f xs))))
+
+(define (dot x y n)
+  "Product of `x' and `y': ordinary multiplication for two numbers, the
+contraction over `n' dimensions for two arrays and the element-wise scaling
+when exactly one operand is an array.  Anything else signals an error."
+  (let ((scalar-x? (number? x))
+        (scalar-y? (number? y)))
+    (cond
+     ((and scalar-x? scalar-y?)
+      (* x y))
+     ((and (array? x) (array? y))
+      (contract-arrays x y n))
+     ((or (and (array? x) scalar-y?)
+          (and scalar-x? (array? y)))
+      ((extend *) x y))
+     (else
+      (error "can't dot because of invalid types or ranks" x y n)))))
+
+(define (rank-of x)
+  "Rank of `x': 0 for a number, the `array-rank' for anything else."
+  (cond ((number? x) 0)
+        (else (array-rank x))))
+
+;;;; differentiation
+
+;; Box pairing a forward value with its Jacobian information.  In this file the
+;; `jacobian' field holds the symbol `input' (see `wrap'), the symbol `zero', or
+;; what `differentiable-wrapper' builds: a promise in forward mode, a procedure
+;; in reverse mode.
+(define-record-type internal
+ (make-internal forward jacobian)
+ internal?
+ (forward internal-forward)
+ (jacobian internal-jacobian))
+
+;;(define *atype* 'f32)
+;; Array type handed to `make-typed-array'/`produce-typed-array' when result
+;; arrays are built in this module.
+(define *atype* #t)
+;; Current mode: #f outside of differentiation, `fwd' inside `fdiff' and `rev'
+;; inside `rdiff'; `differentiable-wrapper' dispatches on it.
+(define *differentiation-mode* (make-parameter #f))
+;; Number of elements of the differentiated input; bound by `rdiff' and read by
+;; `differentiable-wrapper' in `rev' mode.
+(define *n-y-dims* (make-parameter #f))
+;; Absolute index of the input element being differentiated against; bound
+;; through `w/j' in `fdiff' and read by `differentiable-wrapper' in `fwd' mode.
+(define *j* (make-parameter #f))
+
+;; Evaluate `body ...' with the parameter `*j*' bound to `val'.
+(define-syntax-rule (w/j val body ...)
+ (parameterize ((*j* val))
+ body ...))
+
+(define (wrap axis)
+  "Return a procedure of a value and its position.  The value is boxed in an
+internal tagged `input' when the position equals `axis' and is returned
+untouched otherwise."
+  (lambda (x i)
+    (cond ((= i axis) (make-internal x 'input))
+          (else x))))
+
+(define (unwrap-fwd x)
+  "Forward value of `x': the `forward' field of an internal, else `x' itself."
+  (if (not (internal? x))
+      x
+      (internal-forward x)))
+
+(define (unwrap-jac x)
+  "Jacobian part of `x': the `jacobian' field of an internal, else `x' itself."
+  (if (not (internal? x))
+      x
+      (internal-jacobian x)))
+
+(define (dims-of x)
+  "Dimensions of `x': the empty list for a number, the `array-dimensions' for
+anything else."
+  (cond ((number? x) '())
+        (else (array-dimensions x))))
+
+(define (add dst-buf src-buf n-dims)
+  "Add, in place, the first `n-dims' elements of `src-buf' to those of
+`dst-buf'; return `dst-buf'."
+  (do-times
+   n-dims
+   (lambda (k)
+     (let ((total (+ (array-ref dst-buf k)
+                     (array-ref src-buf k))))
+       (array-set! dst-buf total k))))
+  dst-buf)
+
+(define (movg dst-buf n-dst-dims generator naked-inputs data j)
+  "Store in `dst-buf', for each i below `n-dst-dims', the value of `generator'
+applied to `naked-inputs', i, `j' and the elements of `data'; return
+`dst-buf'."
+  (let ((entry (lambda (i)
+                 (apply generator naked-inputs i j data))))
+    (do-times
+     n-dst-dims
+     (lambda (i)
+       (array-set! dst-buf (entry i) i))))
+  dst-buf)
+
+(define (addg dst-buf n-dst-dims generator naked-inputs data j)
+  "Like `movg', except that each generated value is added to the element
+already stored in `dst-buf' instead of replacing it; return `dst-buf'."
+  (let ((entry (lambda (i)
+                 (apply generator naked-inputs i j data))))
+    (do-times
+     n-dst-dims
+     (lambda (i)
+       (let ((previous (array-ref dst-buf i)))
+         (array-set! dst-buf (+ previous (entry i)) i)))))
+  dst-buf)
+
+(define (movc dst-buf n-dst-dims src-buf n-src-dims
+              generator naked-inputs data)
+  "Contract the Jacobian entries produced by the generator with the source
+buffer and store the result in the destination buffer, which is returned."
+  (define (contracted i)
+    ;; Sum over k of generator(i, k) * src-buf[k], accumulated from k = 0.
+    (let accumulate ((k 0)
+                     (acc 0))
+      (if (= k n-src-dims)
+          acc
+          (accumulate
+           (1+ k)
+           (+ acc (* (apply generator naked-inputs i k data)
+                     (array-ref src-buf k)))))))
+  (do-times
+   n-dst-dims
+   (lambda (i)
+     (array-set! dst-buf (contracted i) i)))
+  dst-buf)
+
+(define (addc dst-buf n-dst-dims src-buf n-src-dims
+              generator naked-inputs data)
+  "Contract the Jacobian entries produced by the generator with the source
+buffer and add the result to the destination buffer, which is returned."
+  (define (accumulated i)
+    ;; Start from the value already stored at i, then add
+    ;; generator(i, k) * src-buf[k] for k = 0, 1, ...
+    (let accumulate ((k 0)
+                     (acc (array-ref dst-buf i)))
+      (if (= k n-src-dims)
+          acc
+          (accumulate
+           (1+ k)
+           (+ acc (* (apply generator naked-inputs i k data)
+                     (array-ref src-buf k)))))))
+  (do-times
+   n-dst-dims
+   (lambda (i)
+     (array-set! dst-buf (accumulated i) i)))
+  dst-buf)
+
+(define (transpose-generator generator)
+  "Return a generator that behaves like `generator' with its two index
+arguments swapped."
+  (lambda (xs row col . data)
+    (apply generator xs col row data)))
+
+;; Forward-mode differentiation.  Returns a procedure taking the arguments of
+;; `f'; it runs `f' in `fwd' mode with argument number `axis' boxed as an
+;; internal tagged `input' and assembles the Jacobian one input element
+;; (column `j') at a time.
+(define* (fdiff f #:optional (axis 0))
+ (lambda xs
+ (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'fwd)
+ ((@@ (vouivre autodiff) *promises*) (cons '() #f)))
+ (let* ((internal (apply f (map-indexed (wrap axis) xs)))
+ (fx (internal-forward internal))
+ (y (list-ref xs axis)) ; variable to differentiate w.r.t
+ (pre-Jx (internal-jacobian internal))
+ ;; `Jx' maps an absolute input index `j' to a procedure mapping an
+ ;; absolute output index `i' to the Jacobian entry.
+ (Jx (cond
+ ;; TODO: implement 'input case and test 'zero and 'input
+ ((eq? pre-Jx 'zero)
+ (lambda (j)
+ (lambda (i)
+ 0)))
+ ((eq? pre-Jx 'input)
+ (error "TBD."))
+ (else
+ (lambda (j)
+ ;; The promises are reset before every column and `pre-Jx' is
+ ;; forced with `*j*' bound to `j'.
+ (reset-promises (car (*promises*)))
+ (let ((column-jac (w/j j (force pre-Jx))))
+ (lambda (i)
+ (array-ref column-jac i))))))))
+ (cond
+ ;; number -> number: a single entry.
+ ((and (number? fx) (number? y))
+ ((Jx 0) 0))
+ ;; array -> number: the result has the dimensions of the input.
+ ((and (number? fx) (array? y))
+ (let* ((y-dims (array-dimensions y))
+ (a (apply make-array *unspecified* y-dims))
+ (ac (array-contents a)))
+ (do-times
+ (apply * y-dims)
+ (lambda (j)
+ (array-set! ac ((Jx j) 0)
+ j)))
+ a))
+ ;; number -> array: the result has the dimensions of the output.
+ ((and (array? fx) (number? y))
+ (let* ((fx-dims (array-dimensions fx))
+ (a (apply make-array *unspecified* fx-dims))
+ (ac (array-contents a))
+ (Jx (Jx 0)))
+ (do-times
+ (apply * fx-dims)
+ (lambda (i)
+ (array-set! ac (Jx i)
+ i)))
+ a))
+ ;; array -> array: output dimensions followed by input dimensions; entry
+ ;; (i, j) is stored at the flat position j + n-y-dims * i.
+ (else
+ (let* ((fx-dims (array-dimensions fx))
+ (y-dims (array-dimensions y))
+ (n-fx-dims (apply * fx-dims))
+ (n-y-dims (apply * y-dims))
+ (a (apply make-array *unspecified* (append fx-dims y-dims)))
+ (ac (array-contents a)))
+ (do-times
+ n-y-dims
+ (lambda (j)
+ (let ((Jx (Jx j)))
+ (do-times
+ n-fx-dims
+ (lambda (i)
+ (array-set! ac (Jx i)
+ (+ j (* n-y-dims i))))))))
+ a)))))))
+
+;; Reverse-mode differentiation.  Returns a procedure taking the arguments of
+;; `f'; it runs `f' in `rev' mode with argument number `axis' boxed as an
+;; internal tagged `input' and assembles the Jacobian one output element
+;; (row `i') at a time.
+(define* (rdiff f #:optional (axis 0))
+ (lambda xs
+ (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'rev)
+ ((@@ (vouivre autodiff) *promises*) (cons '() #f)))
+ (let* ((internal (apply f (map-indexed (wrap axis) xs)))
+ (fx (internal-forward internal))
+ (y (list-ref xs axis)) ; variable to differentiate w.r.t
+ (y-dims (dims-of y))
+ (pre-Jx (internal-jacobian internal))
+ ;; `Jx' maps an absolute output index `i' to a procedure mapping an
+ ;; absolute input index `j' to the Jacobian entry.
+ (Jx (cond
+ ;; TODO: implement 'input case and test 'zero and 'input
+ ((eq? pre-Jx 'zero)
+ (lambda (i)
+ (lambda (j)
+ 0)))
+ ((eq? pre-Jx 'input)
+ (error "TBD."))
+ (else
+ ;; Calling `pre-Jx' with #f selects the variant that is driven by
+ ;; an output index rather than by a buffer.
+ (let ((pre-Jx (pre-Jx #f)))
+ (lambda (i)
+ (let ((row-jac (pre-Jx i)))
+ (lambda (j)
+ (array-ref row-jac j)))))))))
+ ;; `*n-y-dims*' is read by `differentiable-wrapper' while rows are computed.
+ (parameterize ((*n-y-dims* (apply * y-dims)))
+ (cond
+ ;; number -> number: a single entry.
+ ((and (number? fx) (number? y))
+ ((Jx 0) 0))
+ ;; array -> number: the result has the dimensions of the input.
+ ((and (number? fx) (array? y))
+ (let* ((a (apply make-array *unspecified* y-dims))
+ (ac (array-contents a))
+ (Jx (Jx 0)))
+ (do-times
+ (*n-y-dims*)
+ (lambda (j)
+ (array-set! ac (Jx j)
+ j)))
+ a))
+ ;; number -> array: the result has the dimensions of the output.
+ ((and (array? fx) (number? y))
+ (let* ((fx-dims (array-dimensions fx))
+ (a (apply make-array *unspecified* fx-dims))
+ (ac (array-contents a)))
+ (do-times
+ (apply * fx-dims)
+ (lambda (i)
+ (array-set! ac ((Jx i) 0)
+ i)))
+ a))
+ ;; array -> array: output dimensions followed by input dimensions; entry
+ ;; (i, j) is stored at the flat position j + n-y-dims * i.
+ (else
+ (let* ((fx-dims (array-dimensions fx))
+ (n-fx-dims (apply * fx-dims))
+ (a (apply make-array *unspecified* (append fx-dims y-dims)))
+ (ac (array-contents a)))
+ (do-times
+ n-fx-dims
+ (lambda (i)
+ (let ((Jx (Jx i)))
+ (do-times
+ (*n-y-dims*)
+ (lambda (j)
+ (array-set! ac (Jx j)
+ (+ j (* (*n-y-dims*) i))))))))
+ a))))))))
+
+;; In the comment that follows:
+;;
+;; `n' is the number of arguments to `proc'.
+;; `generators' is not a `Vec' but a `List'; we only use the former to
+;; illustrate its length.
+;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
+;; `I' is the type of multi-indices indexing the output of `proc'.
+;; `J' is the type of multi-indices indexing the input array being differentiated.
+;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J').
+;; `Array I' is the type of arrays indexed by multi-indices of `I'.
+;; `[X]' means that `X' is boxed in an internal, as when returned by
+;; `differentiable-wrapper', with the array being `X' and the promise that
+;; given a |J| we will get the change of `X' with a change of
+;; the differentiated argument at multi-index `J'.
+;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number))
+;; (→ X1 ... Xn (Array I))*
+;; [X1] ... [Xn]
+;; (Internal (Array I) (Promise |J| (Array |I|)))))
+;;
+;; (*) We extend this definition to allow `proc' to be a list of procedures
+;; the head of which is as described above and the remaining elements
+;; are procedures of the same arguments but returning values that are
+;; then fed as extra data to the generators.
+;;
+;; NOTE: In cases where an argument isn't meant to be differentiable its
+;; corresponding generator should be `#f'.
+;; Apply `proc*' (or its head, when it is a list) to the unboxed arguments and,
+;; depending on `*differentiation-mode*', return either the plain result or an
+;; internal pairing that result with its Jacobian information.  The Jacobian is
+;; accumulated by folding over the `generators'/arguments pairs; arguments that
+;; are not internals, or whose Jacobian is `zero', contribute nothing.
+(define (differentiable-wrapper generators proc* arg . more)
+ ;; Extra data handed to every generator: the results of applying the tail of
+ ;; `proc*' to the unboxed arguments, or nothing when `proc*' is a procedure.
+ (define (precompute-data naked-args)
+ (if (procedure? proc*)
+ '()
+ (map (lambda (g)
+ (apply g naked-args))
+ (cdr proc*))))
+ (let* ((args (cons arg more))
+ (proc (if (procedure? proc*)
+ proc*
+ (car proc*)))
+ (naked-args (map unwrap-fwd args))
+ (out (apply proc naked-args)))
+ (case (*differentiation-mode*)
+ ;; Not differentiating: behave like `proc'.
+ ((#f)
+ out)
+ ;; Forward mode: the Jacobian is a delayed computation filling a buffer of
+ ;; `n-out-dims' entries for the input index currently held by `*j*'.
+ ((fwd)
+ (let* ((data (precompute-data naked-args))
+ (n-out-dims (apply * (dims-of out)))
+ (buf (make-array *unspecified* n-out-dims)))
+ (make-internal
+ out
+ (fold
+ (lambda (generator arg prev)
+ (if (or (not (internal? arg))
+ (eq? 'zero (internal-jacobian arg)))
+ prev
+ (let ((Jx (internal-jacobian arg))
+ (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
+ ;; `input' arguments take their entries straight from the generator
+ ;; (`movg'/`addg'); the others contract the generator with the forced
+ ;; upstream Jacobian (`movc'/`addc').  The first contribution writes
+ ;; `buf', later ones add to the buffer produced by `prev'.
+ (if (eq? Jx 'input)
+ (if (eq? prev 'zero)
+ (delay
+ (movg buf n-out-dims
+ generator naked-args data (*j*)))
+ (delay
+ (addg (force prev) n-out-dims
+ generator naked-args data (*j*))))
+ (if (eq? prev 'zero)
+ (delay
+ (movc buf n-out-dims (force Jx) n-fwd-dims
+ generator naked-args data))
+ (delay
+ (addc (force prev) n-out-dims
+ (force Jx) n-fwd-dims
+ generator naked-args data)))))))
+ 'zero generators args))))
+ ;; Reverse mode: the Jacobian is a procedure of a flag `buf?'.  Called with a
+ ;; true flag it returns a procedure of a buffer, which is contracted with the
+ ;; transposed generator; called with #f it returns a procedure of an output
+ ;; index `i'.
+ ((rev)
+ (let ((data (precompute-data naked-args))
+ (n-out-dims (apply * (dims-of out))))
+ (make-internal
+ out
+ (fold
+ (lambda (generator arg prev)
+ (let ((generator (transpose-generator generator)))
+ (if (or (not (internal? arg))
+ (eq? 'zero (internal-jacobian arg)))
+ prev
+ (let* ((Jx (internal-jacobian arg))
+ (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
+ (if (eq? Jx 'input)
+ ;; The argument is the differentiated input itself.
+ (if (eq? prev 'zero)
+ (lambda (buf?)
+ (let ((dst-buf (make-array *unspecified*
+ n-fwd-dims)))
+ (if buf?
+ (lambda (buf)
+ (movc dst-buf n-fwd-dims
+ buf n-out-dims
+ generator naked-args data))
+ (lambda (i)
+ (movg dst-buf n-fwd-dims
+ generator naked-args data
+ i)))))
+ (lambda (buf?)
+ (let ((prev (prev buf?)))
+ (if buf?
+ (lambda (buf)
+ (addc (prev buf) n-fwd-dims
+ buf n-out-dims
+ generator naked-args data))
+ (lambda (i)
+ (addg (prev i) n-fwd-dims
+ generator naked-args data
+ i))))))
+ ;; The argument is an intermediate value: the local result is handed to
+ ;; the argument's own Jacobian procedure, obtained with a true flag.
+ (if (eq? prev 'zero)
+ (lambda (buf?)
+ (let ((Jx (Jx #t))
+ (dst-buf (make-array *unspecified*
+ n-fwd-dims)))
+ (if buf?
+ (lambda (buf)
+ (Jx
+ (movc dst-buf n-fwd-dims buf
+ n-out-dims
+ generator naked-args data)))
+ (lambda (i)
+ (Jx
+ (movg dst-buf n-fwd-dims
+ generator naked-args data
+ i))))))
+ (lambda (buf?)
+ (let ((prev (prev buf?))
+ (Jx (Jx #t))
+ (dst-buf (make-array *unspecified*
+ n-fwd-dims)))
+ (if buf?
+ (lambda (buf)
+ (add (prev buf)
+ (Jx
+ (movc dst-buf n-fwd-dims
+ buf n-out-dims
+ generator naked-args data))
+ (*n-y-dims*)))
+ (lambda (i)
+ (add (prev i)
+ (Jx
+ (movg dst-buf n-fwd-dims
+ generator naked-args data
+ i))
+ (*n-y-dims*))))))))))))
+ 'zero generators args)))))))
+
+(define (ewise1 f)
+  "Build the Jacobian generator of a one-argument element-wise operation whose
+scalar derivative is `f'.  For a number the entry is `f' of that number; for
+an array the entry at (i, j) is `f' of element j when i = j and 0 otherwise."
+  (lambda (xs i j)
+    (let ((operand (car xs)))
+      (cond
+       ((number? operand)
+        (f operand))
+       (else
+        (ifn (= i j)
+             0
+             (f (array-ref (array-contents operand) j))))))))
+
+(define (ewise2 proc axis)
+  "Build the Jacobian generator of a two-argument element-wise operation with
+respect to its argument number `axis' (0 or 1); `proc' is the corresponding
+partial derivative as a function of the two scalar operands.  When the
+differentiated operand is a number and the other one an array, entry i uses
+element i of the array; otherwise only the entries with i = j are non-zero."
+  (lambda (xs i j)
+    (let* ((x (car xs))
+           (y (cadr xs))
+           (x-at (lambda (k) (array-ref (array-contents x) k)))
+           (y-at (lambda (k) (array-ref (array-contents y) k))))
+      (cond
+       ((and (number? x) (number? y))
+        (proc x y))
+       ((and (number? x) (array? y))
+        (if (= axis 0)
+            (proc x (y-at i))
+            (ifn (= i j)
+                 0
+                 (proc x (y-at j)))))
+       ((and (array? x) (number? y))
+        (if (= axis 1)
+            (proc (x-at i) y)
+            (ifn (= i j)
+                 0
+                 (proc (x-at j) y))))
+       (else
+        (ifn (= i j)
+             0
+             (proc (x-at j) (y-at j))))))))
+
+(define (i:identity x)
+  "Differentiable identity."
+  (let ((jacobian (ewise1 (lambda ignored 1))))
+    (differentiable-wrapper (list jacobian) identity x)))
+
+(define (i:sqrt x)
+  "Differentiable square root."
+  (define (derivative v)
+    ;; d/dv sqrt(v) = 1 / (2 sqrt(v))
+    (/ 1 2 (sqrt v)))
+  (differentiable-wrapper
+   (list (ewise1 derivative))
+   (extend sqrt)
+   x))
+
+(define (i:exp x)
+  "Differentiable exponential."
+  ;; The exponential is its own derivative.
+  (let ((jacobian (ewise1 exp)))
+    (differentiable-wrapper (list jacobian) (extend exp) x)))
+
+(define (i:expt x y)
+  "Differentiable power."
+  (define (d/dbase base exponent)
+    (* exponent (expt base (1- exponent))))
+  (define (d/dexponent base exponent)
+    (* (expt base exponent) (log base)))
+  (differentiable-wrapper
+   (list (ewise2 d/dbase 0)
+         (ewise2 d/dexponent 1))
+   (extend expt)
+   x y))
+
+(define (i:log x)
+  "Differentiable logarithm."
+  (define (derivative v)
+    ;; d/dv log(v) = 1 / v
+    (/ v))
+  (differentiable-wrapper
+   (list (ewise1 derivative))
+   (extend log)
+   x))
+
+(define (i:sin x)
+  "Differentiable sine."
+  ;; d/dv sin(v) = cos(v)
+  (let ((jacobian (ewise1 cos)))
+    (differentiable-wrapper (list jacobian) (extend sin) x)))
+
+(define (i:cos x)
+  "Differentiable cosine."
+  (define (derivative v)
+    ;; d/dv cos(v) = -sin(v)
+    (- (sin v)))
+  (differentiable-wrapper
+   (list (ewise1 derivative))
+   (extend cos)
+   x))
+
+(define (i:tan x)
+  "Differentiable tangent."
+  (define (derivative v)
+    ;; d/dv tan(v) = 1 / cos(v)^2
+    (/ (expt (cos v) 2)))
+  (differentiable-wrapper
+   (list (ewise1 derivative))
+   (extend tan)
+   x))
+
+(define (i:+ x y)
+  "Differentiable element-wise addition."
+  (let ((wrt-x (ewise2 (lambda ignored +1) 0))
+        (wrt-y (ewise2 (lambda ignored +1) 1)))
+    (differentiable-wrapper (list wrt-x wrt-y) (extend +) x y)))
+
+(define (i:- x y)
+  "Differentiable element-wise subtraction."
+  (let ((wrt-x (ewise2 (lambda ignored +1) 0))
+        (wrt-y (ewise2 (lambda ignored -1) 1)))
+    (differentiable-wrapper (list wrt-x wrt-y) (extend -) x y)))
+
+(define (i:* x y)
+  "Differentiable element-wise multiplication."
+  ;; Each partial derivative of a product is the other factor.
+  (let ((wrt-x (ewise2 (lambda (a b) b) 0))
+        (wrt-y (ewise2 (lambda (a b) a) 1)))
+    (differentiable-wrapper (list wrt-x wrt-y) (extend *) x y)))
+
+(define (i:/ x y)
+  "Differentiable element-wise division."
+  (define (d/dnumerator numerator denominator)
+    (/ denominator))
+  (define (d/ddenominator numerator denominator)
+    ;; d/db (a / b) = -a / b^2
+    (- (/ numerator denominator denominator)))
+  (differentiable-wrapper
+   (list (ewise2 d/dnumerator 0)
+         (ewise2 d/ddenominator 1))
+   (extend /)
+   x y))
+
+(define (i:max x y)
+  "Differentiable element-wise maximum."
+  (define (share-of-first a b)
+    ;; Derivative of (max a b) with respect to `a'; a tie is split evenly.
+    (if (> a b)
+        1
+        (if (= a b)
+            1/2
+            0)))
+  (let ((wrt-x (ewise2 share-of-first 0))
+        (wrt-y (ewise2 (flip share-of-first) 1)))
+    (differentiable-wrapper (list wrt-x wrt-y) (extend max) x y)))
+
+(define (i:min x y)
+  "Differentiable element-wise minimum."
+  (define (share-of-first a b)
+    ;; Derivative of (min a b) with respect to `a'; a tie is split evenly.
+    (if (< a b)
+        1
+        (if (= a b)
+            1/2
+            0)))
+  (let ((wrt-x (ewise2 share-of-first 0))
+        (wrt-y (ewise2 (flip share-of-first) 1)))
+    (differentiable-wrapper (list wrt-x wrt-y) (extend min) x y)))
+
+(define (i:abs x)
+  "Differentiable absolute value.  The derivative is the sign of the operand:
++1 above zero, -1 below zero and 0 at zero."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x)
+                   (cond ((> x 0)
+                          +1)
+                         ((< x 0)
+                          -1)
+                         ;; At the kink every slope in [-1, 1] is a valid
+                         ;; subgradient.  We return the symmetric one, which
+                         ;; is also what a centered difference yields:
+                         ;; (|h| - |-h|) / 2h = 0.
+                         ((= x 0)
+                          0)
+                         ;; Only reached when no comparison holds (NaN):
+                         ;; propagate the operand instead of returning an
+                         ;; unspecified value.
+                         (else
+                          x)))))
+   (extend abs)
+   x))
+
+(define (mean x)
+ "Differentiable mean on arrays."
+ (differentiable-wrapper
+ (list
+ ;; Every Jacobian entry is 1/n, received as the precomputed extra datum.
+ (lambda (xs i j one-over-n)
+ one-over-n))
+ ;; `n' is a fresh counter for each call of `mean'.  It is filled in by the
+ ;; first procedure while it sums the elements and read by the second one,
+ ;; which therefore relies on the first having been applied beforehand.
+ (let ((n 0))
+ (list
+ (lambda (x)
+ (let ((sum 0))
+ (array-for-each
+ (lambda (x)
+ (set! sum (+ sum x))
+ (set! n (1+ n)))
+ x)
+ (/ sum n)))
+ (lambda _ (/ n))))
+ x))
+
+(define (i:array-ref x . indices)
+ "Differentiable array-ref w.r.t `x'."
+ (apply
+ differentiable-wrapper
+ (cons
+ ;; Entry j is 1 exactly for the referenced element of `x' (given by its
+ ;; absolute index, precomputed below) and 0 elsewhere.
+ (lambda (xs i j abs-index)
+ (if (= j abs-index)
+ 1
+ 0))
+ ;; One #f generator per index argument: indices are not differentiable.
+ (map not indices))
+ (list
+ array-ref
+ ;; Extra datum for the generator: the absolute index of the element.
+ (lambda (x . indices)
+ (rel->abs indices (array-dimensions x))))
+ x indices))
+
+;; Differentiable `array-cell-ref' with respect to `x'.
+(define (i:array-cell-ref x . indices)
+ (apply
+ differentiable-wrapper
+ (cons
+ ;; `j' is split into the position of a cell (`j-ref') and a position inside
+ ;; that cell (`j-rst').  The entry is 1 when the cell is the referenced one
+ ;; and the inner position matches the output index `i', else 0.
+ (lambda (xs i j abs-index n-rst-dims)
+ (receive (j-ref j-rst) (euclidean/ j n-rst-dims)
+ (if (and (= j-ref abs-index)
+ (= j-rst i))
+ 1
+ 0)))
+ ;; One #f generator per index argument: indices are not differentiable.
+ (map not indices))
+ (list
+ array-cell-ref
+ ;; First extra datum: absolute index of the cell among the leading dimensions.
+ (lambda (x . indices)
+ (rel->abs indices (take (array-dimensions x)
+ (length indices))))
+ ;; Second extra datum: number of elements in one cell.
+ (lambda (x . indices)
+ (apply * (drop (array-dimensions x)
+ (length indices)))))
+ x indices))
+
+(define (make-batch elem . more)
+  "Differentiable stacking: return the array whose cells along a new leading
+dimension are `elem' followed by the elements of `more'."
+  (define batch-size (1+ (length more)))
+  (define (generator-for b)
+    ;; Jacobian generator with respect to element number `b': output index `i'
+    ;; is split into a batch position and a position inside the element.
+    (lambda (xs i j n-rest-dims)
+      (receive (i-batch i-rest) (euclidean/ i n-rest-dims)
+        (if (and (= i-batch b)
+                 (= i-rest j))
+            1
+            0))))
+  (define (stack first . others)
+    (let ((batch (apply make-typed-array *atype* *unspecified* batch-size
+                        (dims-of first))))
+      (for-each
+       (lambda (x b)
+         (array-cell-set! batch x b))
+       (cons first others)
+       (iota batch-size))
+      batch))
+  (define (element-size first . others)
+    ;; Extra datum for the generators: number of elements in one batch cell.
+    (apply * (dims-of first)))
+  (apply
+   differentiable-wrapper
+   (list-tabulate batch-size generator-for)
+   (list stack element-size)
+   elem more))
+
+(define (maximum x)
+  "Differentiable maximum on arrays.  The derivative is 1 at the position of
+the first largest element and 0 everywhere else."
+  (differentiable-wrapper
+   (list
+    (lambda (xs i j max-index)
+      (if (= j max-index)
+          1
+          0)))
+   ;; `max-index' is filled in by the forward procedure and handed to the
+   ;; generator by the second procedure.  It starts at 0 rather than at a
+   ;; placeholder symbol so that it is always a number: when no element
+   ;; exceeds -inf (for instance an array holding only -inf) the first
+   ;; position is a valid choice, and the generator's `=' cannot fail.
+   (let ((max-index 0))
+     (list
+      (lambda (x)
+        (let ((m (- (inf)))
+              (i 0))
+          (array-for-each
+           (lambda (x)
+             ;; Strict comparison: the first of several equal maxima wins.
+             (when (< m x)
+               (set! m x)
+               (set! max-index i))
+             (set! i (1+ i)))
+           x)
+          m))
+      (lambda _ max-index)))
+   x))
+
+(define (sum x)
+  "Differentiable sum on arrays."
+  (define (total array)
+    (let ((acc 0))
+      (array-for-each
+       (lambda (element)
+         (set! acc (+ acc element)))
+       array)
+      acc))
+  ;; Every element contributes with slope 1.
+  (differentiable-wrapper
+   (list (lambda ignored 1))
+   total
+   x))
+
+;; Differentiable `contract-arrays': contracts the last `n' dimensions of `x'
+;; with the first `n' dimensions of `y'.  `n' itself is not differentiable.
+(define (adot x y n)
+ (differentiable-wrapper
+ (list
+ ;; With respect to `x': the output index splits into a free position of `x'
+ ;; and a free position of `y', the input index into a free and a bound
+ ;; position of `x'.  The entry is the matching element of `y' when both
+ ;; free positions of `x' agree, else 0.
+ (lambda (xs i j n-free-dims-y n-bound-dims)
+ (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+ (receive (j-free j-bound) (euclidean/ j n-bound-dims)
+ (ifn (= i-x j-free)
+ 0
+ (array-ref (array-contents (cadr xs))
+ (+ i-y (* n-free-dims-y j-bound)))))))
+ ;; With respect to `y': the input index splits into a bound and a free
+ ;; position of `y'.  The entry is the matching element of `x' when both
+ ;; free positions of `y' agree, else 0.
+ (lambda (xs i j n-free-dims-y n-bound-dims)
+ (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+ (receive (j-bound j-free) (euclidean/ j n-free-dims-y)
+ (ifn (= i-y j-free)
+ 0
+ (array-ref (array-contents (car xs))
+ (+ j-bound (* n-bound-dims i-x)))))))
+ ;; No generator for `n'.
+ #f)
+ (list
+ contract-arrays
+ ;; First extra datum: number of elements spanned by the free dimensions of `y'.
+ (lambda (x y n)
+ (apply * (drop (array-dimensions y)
+ n)))
+ ;; Second extra datum: number of elements spanned by the contracted dimensions.
+ (lambda (x y n)
+ (apply * (take (array-dimensions y)
+ n))))
+ x y n))
+
+(define (amap2 f x y)
+  "Differentiable map of the two-argument `f' over the cells of `x' and `y'
+taken along their first dimension; the results are gathered with
+`make-batch'.  The number of cells is read from `x'."
+  (let ((batch-size (car (dims-of (unwrap-fwd x)))))
+    (apply make-batch
+           (list-tabulate batch-size
+                          (lambda (b)
+                            (f (i:array-cell-ref x b)
+                               (i:array-cell-ref y b)))))))
+++ /dev/null
-(define-module (vouivre grad tests)
- #:use-module ((vouivre grad) #:prefix v:)
- #:use-module (ice-9 receive)
- #:use-module (srfi srfi-1)
- #:use-module (srfi srfi-64)
- #:use-module (vouivre misc)
- #:export
- (apply-grad
- apply-grad-amap2
- a~
- differentiable-func-generator
- lambda-const-call
- ngrad
- n~
- random-array
- random-array-shape
- random-func1
- random-func2
- random-func2-rank&dims>0
- random-input
- random-list-element
- random-non-empty-array
- random-shape
- random-shared
- random-shared-array-rank&dims>0
- random-shared-contractible
- with-generators
- ~))
-
-(define f1s (list v:exp v:identity))
-(define f2s (list v:+ v:- v:* v:max v:min))
-
-(define-syntax-rule (with-generators (g1 g2 ...) equal expected given)
- (let ((times 100)
- (fx expected)
- (fy given)
- (generators (list g1 g2 ...)))
- (call/cc
- (lambda (break)
- (do ((i 0 (1+ i)))
- ((= i times) #t)
- (let ((zs (map-in-order (lambda (g) (g)) generators)))
- (with-exception-handler
- (lambda (e)
- (break #f zs))
- (lambda ()
- (let ((r1 (apply fx zs))
- (r2 (apply fy zs)))
- (unless (equal r1 r2)
- (break #f zs r1 r2))))
- #:unwind? #t)))))))
-
-(define (lambda-const-call f . consts)
- (lambda _
- (apply f consts)))
-
-(define* (random-array-shape
- #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5))
- (list-tabulate (+ min-rank (random (- max-rank min-rank)))
- (lambda _ (+ min-dim (random (- max-dim min-dim))))))
-
-(define (random-shape)
- (if (= 0 (random 2))
- 0
- (random-array-shape)))
-
-(define* (random-array #:optional shape)
- (apply produce-typed-array
- (lambda _ (random:uniform))
- v:*atype* (or shape (random-array-shape))))
-
-(define (random-non-empty-array)
- "Random array of at least one element."
- (random-array (random-array-shape 0 5 1 5)))
-
-(define* (random-input #:optional shape)
- (let ((shape (or shape (random-shape))))
- (if (eq? 0 shape)
- (random:uniform)
- (random-array shape))))
-
-(define (random-shared)
- (let ((shape (random-shape)))
- (values
- (lambda ()
- (random-input shape))
- (lambda ()
- (let ((x (random-input
- (random-list-element
- (list 0 (if (list? shape)
- shape
- (random-shape)))))))
- (set! shape (random-shape))
- x)))))
-
-(define (random-shared-array-rank&dims>0)
- (let ((shape (random-array-shape 1 5 1 5)))
- (values
- (lambda ()
- (random-array shape))
- (lambda ()
- (let ((x (random-array shape)))
- (set! shape (random-array-shape 1 5 1 5))
- x)))))
-
-(define (random-list-element lst)
- (list-ref lst (random (length lst))))
-
-(define (differentiable-func-generator lst . input-generators)
- (lambda ()
- (random-list-element
- (cons
- (apply
- lambda-const-call
- (random-list-element lst)
- (map (lambda (g) (g))
- input-generators))
- lst))))
-
-(define random-func1
- (differentiable-func-generator f1s random-input))
-(define random-func2
- (receive (gx gy) (random-shared)
- (differentiable-func-generator f2s gx gy)))
-
-(define* (n~ x y #:optional (error 1e-4))
- (and
- (>= y (- x error))
- (<= y (+ x error))))
-
-(define* (a~ x y #:optional (error 1e-4))
- (and
- (equal? (array-dimensions x)
- (array-dimensions y))
- (call/cc
- (lambda (break)
- (array-for-each
- (lambda (x y)
- (unless (~ x y error)
- (break #f)))
- x y)
- #t))))
-
-(define* (~ x y #:optional (error 1e-4))
- (cond
- ((and (number? x) (number? y))
- (n~ x y error))
- ((and (array? x) (array? y))
- (a~ x y error))
- (else #f)))
-
-;; Numerical reference used to check `v:grad'.  Returns a procedure taking the
-;; same arguments as `f' and approximating the derivative of `f' with respect
-;; to the argument at position `axis' as (f(x + step) - f(x - step)) / (2 step).
-;; The shape of the result depends on whether the output of `f' and the
-;; differentiated argument are numbers or arrays (see the four cases below).
-(define* (ngrad f #:optional (axis 0) (step 1e-6))
- "Gradient using a numerical centered difference approximation."
- (define (axis-add xs dh . indices)
- "Add `dh' to the number or array at the given `axis' of `xs',
-and, when it's an array, at the given index."
- (map-indexed
- (lambda (x i)
- (ifn (= i axis)
- x
- (if (number? x)
- (+ x dh)
- ;; only the element whose multi-index equals `indices' is perturbed
- (array-map-indexed
- (lambda (x . indices_)
- (ifn (equal? indices indices_)
- x
- (+ x dh)))
- x))))
- xs))
- (lambda xs
- ;; We need the output shape and the input shape along the
- ;; differentiated axis.
- (let ((fxs (apply f xs))
- (x (list-ref xs axis)))
- (cond
- ;; number -> number: the result is a number
- ((and (number? fxs)
- (number? x))
- (/ (- (apply f (axis-add xs step))
- (apply f (axis-add xs (- step))))
- (* 2 step)))
- ;; array -> number: the result has the dimensions of the input, one
- ;; pair of evaluations of `f' per input element
- ((and (number? fxs)
- (array? x))
- (apply
- produce-typed-array
- (lambda indices
- (/ (- (apply f (apply axis-add xs step indices))
- (apply f (apply axis-add xs (- step) indices)))
- (* 2 step)))
- v:*atype*
- (array-dimensions x)))
- ;; number -> array: the result has the dimensions of the output
- ((and (array? fxs)
- (number? x))
- ((v:extend /)
- ((v:extend -)
- (apply f (axis-add xs step))
- (apply f (axis-add xs (- step))))
- (* 2 step)))
- ;; array -> array: the result has the output dimensions followed by the
- ;; input dimensions
- ((and (array? fxs)
- (array? x))
- (let ((a (apply
- make-typed-array v:*atype* *unspecified*
- (append (array-dimensions fxs)
- (array-dimensions x)))))
- ;; outer loop over input indices; `for-indices-in-range' and
- ;; `list-zeros' come from (vouivre misc) -- presumably they iterate
- ;; over every multi-index between the two bounds; verify there
- (for-indices-in-range
- (lambda indices-in
- (let ((dfxs ((v:extend /)
- ((v:extend -)
- (apply f (apply axis-add xs step indices-in))
- (apply f (apply axis-add xs (- step) indices-in)))
- (* 2 step))))
- ;; copy the derivative of every output element into the slot
- ;; addressed by (output indices ++ input indices)
- (for-indices-in-range
- (lambda indices-out
- (apply
- array-set!
- a
- (apply array-ref dfxs indices-out)
- (append indices-out indices-in)))
- (list-zeros (array-rank fxs))
- (array-dimensions fxs))))
- (list-zeros (array-rank x))
- (array-dimensions x))
- a))))))
-
-;; Open the "grad" test group; the matching `test-end' closes it.
-(test-begin "grad")
-
-;; extended operations: outside of `v:grad' each differentiable procedure must
-;; agree, on random inputs, with the plain procedure extended element-wise by
-;; `v:extend'
-(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity))
-(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp))
-(test-assert
- (receive (gx gy) (random-shared)
- (with-generators (gx gy) ~ (v:extend *) v:*)))
-
-(define* (apply-grad grad-func #:optional (axis 0))
-  "Return a procedure taking a function `f' followed by its arguments and
-applying to those arguments the derivative of `f' along `axis', as built by
-`grad-func'."
-  (lambda (f . args)
-    (let ((df (grad-func f axis)))
-      (apply df args))))
-
-;; One-argument functions: the numerical and the automatic derivative must
-;; agree on random functions and random inputs.
-(test-assert
- (with-generators
- (random-func1 random-input)
- ~ (apply-grad ngrad) (apply-grad v:grad)))
-
-;; `v:mean' only takes non-empty arrays so we treat it separately
-(test-assert
- (with-generators
- ((differentiable-func-generator (list v:mean) random-non-empty-array)
- random-non-empty-array)
- ~ (apply-grad ngrad) (apply-grad v:grad)))
-
-;; Two-argument functions, differentiated with respect to the first argument
-;; (axis 0) and then the second one (axis 1).
-(test-assert
- (receive (gx gy) (random-shared)
- (with-generators
- (random-func2 gx gy)
- ~ (apply-grad ngrad 0) (apply-grad v:grad 0))))
-(test-assert
- (receive (gx gy) (random-shared)
- (with-generators
- (random-func2 gx gy)
- ~ (apply-grad ngrad 1) (apply-grad v:grad 1))))
-
-;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it
-;; separately
-(define* (apply-grad-amap2 grad-func #:optional (axis 0))
-  "Return a procedure taking a function `f' and two arrays, and applying to
-the arrays the derivative along `axis', as built by `grad-func', of the
-mapping of `f' over both arrays by `v:amap2'."
-  (lambda (f x y)
-    (define (batched a b)
-      (v:amap2 f a b))
-    ((grad-func batched axis) x y)))
-;; Generator of random two-argument functions taken from `f2s'; the constant
-;; variants are built from arrays whose rank and dimensions are all positive.
-(define random-func2-rank&dims>0
-  (call-with-values random-shared-array-rank&dims>0
-    (lambda (gx gy)
-      (differentiable-func-generator f2s gx gy))))
-;; `v:amap2' on random functions and random arrays, differentiated with
-;; respect to the first array (axis 0) and then the second one (axis 1).
-(test-assert
- (receive (gx gy) (random-shared-array-rank&dims>0)
- (with-generators
- (random-func2-rank&dims>0 gx gy)
- ~ (apply-grad-amap2 ngrad 0) (apply-grad-amap2 v:grad 0))))
-(test-assert
- (receive (gx gy) (random-shared-array-rank&dims>0)
- (with-generators
- (random-func2-rank&dims>0 gx gy)
- ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1))))
-;; Fixed example: `v:amap2' called directly, differentiated with respect to
-;; the argument at position 1 (the first array, position 0 being the function).
-(test-assert
- (~ ((ngrad v:amap2 1) v:* #(1 2 3) #(10 20 30))
- ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30))))
-;; Fixed example: the mapped function ignores its own arguments and only uses
-;; the differentiated variable `a', captured from the enclosing lambda.
-(test-assert
- (let ((x #(10 20 30))
- (y #(10 20 30)))
- (~ ((ngrad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))
- ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)))))
-
-;; `v:adot'
-(define (random-shared-contractible)
-  "Returns three generators: the first two generate arrays that are contractible
-according to the number generated by the third one.
-
-The third generator returns the current contraction count and then redraws the
-count and both shapes, so it has to be called last in every round."
-  (let ((n #f)
-        (sa #f)
-        (sb #f))
-    (define (redraw!)
-      ;; `sa' has a rank of at least `n' and `sb' starts with the last `n'
-      ;; dimensions of `sa', which makes the two arrays contractible over `n'
-      ;; dimensions.
-      (set! n (random 5))
-      (set! sa (random-array-shape n))
-      (set! sb (append (take-right sa n)
-                       (random-array-shape 0 (- 5 n)))))
-    (redraw!)
-    (values
-     (lambda ()
-       (random-array sa))
-     (lambda ()
-       (random-array sb))
-     (lambda ()
-       (let ((current n))
-         (redraw!)
-         current)))))
-;; `v:adot' differentiated with respect to its first array (axis 0) and then
-;; its second array (axis 1); the third argument is the number of contracted
-;; dimensions drawn by `random-shared-contractible'.
-(test-assert
- (receive (gx gy gz) (random-shared-contractible)
- (with-generators
- (gx gy gz)
- ~
- (lambda (a b n) ((ngrad v:adot 0) a b n))
- (lambda (a b n) ((v:grad v:adot 0) a b n)))))
-(test-assert
- (receive (gx gy gz) (random-shared-contractible)
- (with-generators
- (gx gy gz)
- ~
- (lambda (a b n) ((ngrad v:adot 1) a b n))
- (lambda (a b n) ((v:grad v:adot 1) a b n)))))
-
-;; chain rule: the derivative of `f' at (g x) contracted with the derivative of
-;; `g' at `x' over the rank of (g x) must equal the derivative of the
-;; composition of `f' and `g' at `x'
-(test-assert
- (with-generators
- (random-func1 random-func1 random-input)
- ~
- (lambda (f g x)
- (let* ((gx (g x))
- (r
- (v:dot ((v:grad f) gx)
- ((v:grad g) x)
- (v:rank-of gx))))
- ;; when both the input and the output are numbers but the contraction
- ;; produced an array, take its element with an index-free `array-ref'
- ;; so that it compares against a number
- (ifn (and (number? (f (g x)))
- (number? x)
- (array? r))
- r
- (array-ref r))))
- (lambda (f g x)
- ((v:grad (compose f g)) x))))
-
-;; Close the "grad" test group.
-(test-end "grad")
+++ /dev/null
-(define-module (vouivre grad)
- #:use-module (ice-9 receive)
- #:use-module (srfi srfi-1)
- #:use-module (srfi srfi-9)
- #:use-module (vouivre misc)
- #:use-module (vouivre promises)
- #:export
- (*atype*
- adot
- amap2
- contract-arrays
- differentiable-wrapper
- dot
- do-times
- ewise1
- ewise2
- extend
- grad
- make-batch
- make-internal
- maximum
- mean
- rank-of
- sum)
- #:replace
- ((i:sqrt . sqrt)
- (i:exp . exp)
- (i:expt . expt)
- (i:log . log)
- (i:sin . sin)
- (i:cos . cos)
- (i:tan . tan)
- (i:+ . +)
- (i:- . -)
- (i:* . *)
- (i:/ . /)
- (i:max . max)
- (i:min . min)
- (i:abs . abs)
- (i:identity . identity)
- (i:array-ref . array-ref)
- (i:array-cell-ref . array-cell-ref))
- #:re-export
- (fold
- reduce))
-
-;;;; array utilities
-
-(define (rel->abs indices dimensions)
-  "Convert the multi-index `indices' into a flat offset for an array of the
-given `dimensions', the last index varying the fastest."
-  (do ((is (reverse indices) (cdr is))
-       (ds (reverse dimensions) (cdr ds))
-       ;; number of flat positions spanned by one step of the current index
-       (stride 1 (* stride (car ds)))
-       (offset 0 (+ offset (* stride (car is)))))
-      ((null? is) offset)))
-
-(define (do-times n proc)
-  "Call `proc' on each of 0, 1, ..., `n' - 1 in increasing order."
-  (do ((i 0 (1+ i)))
-      ((= i n))
-    (proc i)))
-
-;; Contract the last `n' dimensions of the array `a' with the first `n'
-;; dimensions of the array `b'.  The result is a new array of type `*atype*'
-;; whose dimensions are the remaining (free) dimensions of `a' followed by
-;; those of `b'; each element is the sum, over the contracted positions, of
-;; the products of the matching elements of `a' and `b'.
-;; The number of contracted elements is taken from `b' only; the matching
-;; dimensions of `a' are not checked here.
-(define (contract-arrays a b n)
- (let* ((dims-a (array-dimensions a))
- (dims-b (array-dimensions b))
- (free-dims-a (take dims-a (- (array-rank a) n)))
- (free-dims-b (drop dims-b n))
- (bound-dims (take dims-b n))
- ;; element counts of the free and of the contracted parts
- (n-free-dims-a (apply * free-dims-a))
- (n-free-dims-b (apply * free-dims-b))
- (n-bound-dims (apply * bound-dims))
- ;; running sum, reset for every output element
- (s 0)
- (r (apply make-typed-array *atype* *unspecified* (append free-dims-a
- free-dims-b)))
- ;; the arrays are addressed below through flat offsets;
- ;; NOTE(review): this presumably requires arrays for which
- ;; `array-contents' yields a flat row-major view -- confirm for shared
- ;; or transposed arrays
- (ac (array-contents a))
- (bc (array-contents b))
- (rc (array-contents r)))
- (do-times
- n-free-dims-a
- (lambda (i)
- ;; offsets of row `i' in the contents of `a' and of `r'
- (let ((i-k (* n-bound-dims i))
- (i-j (* n-free-dims-b i)))
- (do-times
- n-free-dims-b
- (lambda (j)
- (set! s 0)
- (do-times
- n-bound-dims
- (lambda (k)
- (set! s (+ s (* (array-ref ac (+ i-k k))
- (array-ref bc (+ (* n-free-dims-b k) j)))))))
- (array-set! rc s (+ i-j j)))))))
- r))
-
-;;;; utilities that work on both numbers and arrays
-
-(define (extend f)
-  "Extend a function of one or more scalars to apply to numbers/arrays
-element-wise. All arrays must have the same dimension."
-  (lambda args
-    ;; numbers are passed as they are, arrays are read at `indices'
-    (define (element-at indices)
-      (lambda (arg)
-        (if (number? arg)
-            arg
-            (apply array-ref arg indices))))
-    (if-let (template (find array? args))
-      (apply produce-typed-array
-             (lambda indices
-               (apply f (map (element-at indices) args)))
-             *atype*
-             (array-dimensions template))
-      (apply f args))))
-
-(define (dot x y n)
-  "Product of `x' and `y': an ordinary product for two numbers, a contraction
-over `n' dimensions for two arrays and an element-wise scaling when exactly
-one of them is an array.  Any other combination signals an error."
-  (cond
-   ((and (number? x) (number? y))
-    (* x y))
-   ((and (array? x) (array? y))
-    (contract-arrays x y n))
-   ((or (and (array? x) (number? y))
-        (and (number? x) (array? y)))
-    ((extend *) x y))
-   (else
-    (error "can't dot because of invalid types or ranks" x y n))))
-
-(define (rank-of x)
-  "Rank of `x': 0 for a number, otherwise the rank of the array."
-  (cond
-   ((number? x) 0)
-   (else (array-rank x))))
-
-;;;; differentiation
-
-;; Value flowing through a differentiated computation: `forward' is the plain
-;; result and `jacobian' describes its dependence on the differentiated input.
-;; In this module `jacobian' is the symbol `input' (set by `grad' on the
-;; differentiated argument), the symbol `zero', or a promise built with
-;; `delay' by `differentiable-wrapper'.
-(define-record-type internal
- (make-internal forward jacobian)
- internal?
- (forward internal-forward)
- (jacobian internal-jacobian))
-
-;; Array type handed to `make-typed-array' and `produce-typed-array'
-;; throughout this module.
-;;(define *atype* 'f32)
-(define *atype* #t)
-;; True while `grad' is running; `differentiable-wrapper' only builds
-;; Jacobians when it is set.
-(define *grad* (make-parameter #f))
-;; Flat index of the input element being differentiated against; bound by
-;; `w/j' and read by `differentiable-wrapper' inside its promises.
-(define *j* (make-parameter #f))
-
-;; Evaluate `body ...' with the parameter `*j*' bound to `val'.
-(define-syntax-rule (w/j val body ...)
- (parameterize ((*j* val))
- body ...))
-
-(define (unwrap-fwd x)
-  "Forward value of `x' when it is an internal, `x' itself otherwise."
-  (cond
-   ((internal? x) (internal-forward x))
-   (else x)))
-
-(define (unwrap-jac x)
-  "Jacobian of `x' when it is an internal, `x' itself otherwise."
-  (cond
-   ((internal? x) (internal-jacobian x))
-   (else x)))
-
-(define (dims-of x)
-  "Dimensions of `x': the empty list for a number, otherwise the dimensions of
-the array."
-  (cond
-   ((number? x) '())
-   (else (array-dimensions x))))
-
-(define (mov dst-buf n-dst-dims generator naked-inputs data j)
-  "Overwrite the first `n-dst-dims' entries of the destination buffer with the
-entries the generator produces for column `j', and return the buffer."
-  (define (entry row)
-    (apply generator naked-inputs row j data))
-  (do-times
-   n-dst-dims
-   (lambda (row)
-     (array-set! dst-buf (entry row) row)))
-  dst-buf)
-
-(define (add dst-buf n-dst-dims generator naked-inputs data j)
-  "Add to the first `n-dst-dims' entries of the destination buffer the entries
-the generator produces for column `j', and return the buffer."
-  (define (entry row)
-    (apply generator naked-inputs row j data))
-  (do-times
-   n-dst-dims
-   (lambda (row)
-     (let ((previous (array-ref dst-buf row)))
-       (array-set! dst-buf (+ previous (entry row)) row))))
-  dst-buf)
-
-(define (movc dst-buf n-dst-dims src-buf n-src-dims
-              generator naked-inputs data)
-  "Contract the Jacobian column produced by the generator with the source buffer
-storing the result in the destination buffer."
-  (do-times
-   n-dst-dims
-   (lambda (row)
-     ;; every destination entry starts from zero
-     (let ((acc 0))
-       (do-times
-        n-src-dims
-        (lambda (col)
-          (set! acc (+ acc (* (apply generator naked-inputs row col data)
-                              (array-ref src-buf col))))))
-       (array-set! dst-buf acc row))))
-  dst-buf)
-
-(define (addc dst-buf n-dst-dims src-buf n-src-dims
-              generator naked-inputs data)
-  "Contract the Jacobian column produced by the generator with the source buffer
-adding the result to the destination buffer."
-  (do-times
-   n-dst-dims
-   (lambda (row)
-     ;; every destination entry starts from its current content
-     (let ((acc (array-ref dst-buf row)))
-       (do-times
-        n-src-dims
-        (lambda (col)
-          (set! acc (+ acc (* (apply generator naked-inputs row col data)
-                              (array-ref src-buf col))))))
-       (array-set! dst-buf acc row))))
-  dst-buf)
-
-;; Return a procedure taking the same arguments as `f' and computing the
-;; derivative of `f' with respect to the argument at position `axis'.
-;; The result is a number when both the output of `f' and the differentiated
-;; argument are numbers; otherwise it is an array whose dimensions are those
-;; of the output followed by those of the differentiated argument.
-;; `f' has to return an internal, i.e. be built from the differentiable
-;; procedures of this module; returning the differentiated argument untouched
-;; is not supported yet and signals the "TBD." error.
-(define* (grad f #:optional (axis 0))
- ;; mark the differentiated argument so the wrappers can recognize it
- (define (wrap x i)
- (if (= i axis)
- (make-internal x 'input)
- x))
- (lambda xs
- ;; `*grad*' switches `differentiable-wrapper' to building Jacobians;
- ;; `*promises*' is not defined in this file -- presumably it comes from
- ;; (vouivre promises) and records the promises created during the call so
- ;; that `reset-promises' can force them again; verify there
- (parameterize (((@@ (vouivre grad) *grad*) #t)
- ((@@ (vouivre grad) *promises*) (cons '() #f)))
- (let* ((internal (apply f (map-indexed wrap xs)))
- (fx (internal-forward internal))
- (y (list-ref xs axis)) ; variable to differentiate w.r.t
- (pre-Jx (internal-jacobian internal))
- ;; `Jx' maps a flat input index `j' to a procedure mapping a flat
- ;; output index `i' to the derivative of output `i' w.r.t. input `j'
- (Jx (cond
- ;; TODO: implement 'input case and test 'zero and 'input
- ((eq? pre-Jx 'zero)
- (lambda (j)
- (lambda (i)
- 0)))
- ((eq? pre-Jx 'input)
- (error "TBD."))
- (else
- (lambda (j)
- ;; the promises are reset before each column is forced with
- ;; `*j*' bound to the column index
- (reset-promises (car (*promises*)))
- (let ((column-jac (w/j j (force pre-Jx))))
- (lambda (i)
- (array-ref column-jac i))))))))
- (cond
- ;; number -> number: a single entry
- ((and (number? fx) (number? y))
- ((Jx 0) 0))
- ;; array -> number: one entry per input element
- ((and (number? fx) (array? y))
- (let* ((y-dims (array-dimensions y))
- (a (apply make-array *unspecified* y-dims))
- (ac (array-contents a)))
- (do-times
- (apply * y-dims)
- (lambda (j)
- (array-set! ac ((Jx j) 0)
- j)))
- a))
- ;; number -> array: one entry per output element, a single column
- ((and (array? fx) (number? y))
- (let* ((fx-dims (array-dimensions fx))
- (a (apply make-array *unspecified* fx-dims))
- (ac (array-contents a))
- (Jx (Jx 0)))
- (do-times
- (apply * fx-dims)
- (lambda (i)
- (array-set! ac (Jx i)
- i)))
- a))
- ;; array -> array: entry (i, j) is stored at flat offset
- ;; j + n-y-dims * i, i.e. output indices first
- (else
- (let* ((fx-dims (array-dimensions fx))
- (y-dims (array-dimensions y))
- (n-fx-dims (apply * fx-dims))
- (n-y-dims (apply * y-dims))
- (a (apply make-array *unspecified* (append fx-dims y-dims)))
- (ac (array-contents a)))
- (do-times
- n-y-dims
- (lambda (j)
- (let ((Jx (Jx j)))
- (do-times
- n-fx-dims
- (lambda (i)
- (array-set! ac (Jx i)
- (+ j (* n-y-dims i))))))))
- a)))))))
-
-
-;; In the comment that follows:
-;;
-;; `n' is the number of arguments to `proc'.
-;; `generators' is not a `Vec' but a `List'; we only use the former to
-;; illustrate its length.
-;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
-;; `I' is the type of multi-indices indexing the output of `function'.
-;; `J' is the type of multi-indices indexing the input array being differentiated.
-;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J').
-;; `Array I' is the type of arrays indexed by multi-indices of `I'.
-;; `[X]' means that `X' is boxed in an internal as when returned by
-;; `differentiable-wrapper' with the array being `X' and the promise that
-;; given a |J| we will get the change of `X' with a change of the
-;; differentiated argument at multi-index `J'.
-;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number))
-;; (→ X1 ... Xn (Array I))*
-;; [X1] ... [Xn]
-;; (Internal (Array I) (Promise |J| (Array |I|)))))
-;;
-;; (*) We extend this definition to allow `proc' to be a list of procedures
-;; the head of which is as described above and the remaining elements
-;; are procedures of the same arguments but returning values that are
-;; then fed as extra data to the generators.
-;;
-;; NOTE: In cases where an argument isn't meant to be differentiable its
-;; corresponding generator should be `#f'.
-;; Apply `proc*' (or its head when it is a list) to the forward values of the
-;; arguments.  Outside of `grad' the plain result is returned.  Inside `grad'
-;; the result is boxed in an internal together with a Jacobian that is either
-;; the symbol `zero' or a promise filling a buffer of one entry per output
-;; element.
-(define (differentiable-wrapper generators proc* arg . more)
- (let* ((args (cons arg more))
- (proc (if (procedure? proc*)
- proc*
- (car proc*)))
- (naked-args (map unwrap-fwd args))
- (out (apply proc naked-args)))
- (ifn (*grad*)
- out
- ;; extra data for the generators: the results of the remaining
- ;; procedures of `proc*' applied to the same forward values
- (let* ((data (if (procedure? proc*)
- '()
- (map (lambda (g)
- (apply g naked-args))
- (cdr proc*))))
- (n-out-dims (apply * (dims-of out)))
- ;; one buffer per call, shared by all the promises built below
- (buf (make-array *unspecified* n-out-dims)))
- (make-internal
- out
- ;; accumulate the contribution of every argument, starting from `zero'
- (fold
- (lambda (generator arg prev)
- ;; arguments that are plain values or whose Jacobian is `zero'
- ;; contribute nothing
- (if (or (not (internal? arg))
- (eq? 'zero (internal-jacobian arg)))
- prev
- (let ((Jx (internal-jacobian arg))
- (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
- ;; the differentiated input itself: the column `*j*' of the
- ;; generator is used directly
- (if (eq? Jx 'input)
- (if (eq? prev 'zero)
- (delay
- (mov buf n-out-dims
- generator naked-args data (*j*)))
- (delay
- (add (force prev) n-out-dims
- generator naked-args data (*j*))))
- ;; an intermediate value: contract the generator with the
- ;; forced Jacobian column of the argument
- (if (eq? prev 'zero)
- (delay
- (movc buf n-out-dims (force Jx) n-fwd-dims
- generator naked-args data))
- (delay
- (addc (force prev) n-out-dims
- (force Jx) n-fwd-dims
- generator naked-args data)))))))
- 'zero generators args))))))
-
-(define (ewise1 f)
-  "Turn the derivative `f' of a scalar function of one argument into a
-generator for its element-wise extension: `f' of the number for a number
-input, otherwise `f' of element `j' on the diagonal and 0 elsewhere."
-  (lambda (xs i j)
-    (let ((x (car xs)))
-      (cond
-       ((number? x)
-        (f x))
-       ((= i j)
-        (f (array-ref (array-contents x) j)))
-       (else 0)))))
-
-(define (ewise2 proc axis)
-  "Turn `proc', the partial derivative along `axis' of a scalar function of two
-arguments, into a generator for its element-wise extension.  When the
-differentiated argument is a number and the other one an array, every output
-element depends on it; otherwise only diagonal entries are non-zero."
-  (define (flat-ref a k)
-    (array-ref (array-contents a) k))
-  (lambda (xs i j)
-    (let ((x (car xs))
-          (y (cadr xs)))
-      (cond
-       ((and (number? x) (number? y))
-        (proc x y))
-       ((and (number? x) (array? y))
-        (cond ((= axis 0) (proc x (flat-ref y i)))
-              ((= i j) (proc x (flat-ref y j)))
-              (else 0)))
-       ((and (array? x) (number? y))
-        (cond ((= axis 1) (proc (flat-ref x i) y))
-              ((= i j) (proc (flat-ref x j) y))
-              (else 0)))
-       ((= i j)
-        (proc (flat-ref x j) (flat-ref y j)))
-       (else 0)))))
-
-(define (i:identity x)
-  "Differentiable identity; its derivative is 1 everywhere."
-  (let ((jacobian (ewise1 (lambda args 1))))
-    (differentiable-wrapper (list jacobian) identity x)))
-
-(define (i:sqrt x)
-  "Differentiable element-wise square root; the derivative is 1 / (2 sqrt(x))."
-  (define (dsqrt x)
-    (/ 1 2 (sqrt x)))
-  (differentiable-wrapper (list (ewise1 dsqrt))
-                          (extend sqrt)
-                          x))
-
-(define (i:exp x)
-  "Differentiable element-wise exponential; it is its own derivative."
-  (let ((jacobian (ewise1 exp)))
-    (differentiable-wrapper (list jacobian) (extend exp) x)))
-
-(define (i:expt x y)
-  "Differentiable element-wise power of `x' raised to `y'."
-  (let ((d/dx (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0))
-        (d/dy (ewise2 (lambda (x y) (* (expt x y) (log x))) 1)))
-    (differentiable-wrapper (list d/dx d/dy)
-                            (extend expt)
-                            x y)))
-
-(define (i:log x)
-  "Differentiable element-wise logarithm; the derivative is 1 / x."
-  (define (dlog x)
-    (/ x))
-  (differentiable-wrapper (list (ewise1 dlog))
-                          (extend log)
-                          x))
-
-(define (i:sin x)
-  "Differentiable element-wise sine; the derivative is the cosine."
-  (let ((jacobian (ewise1 cos)))
-    (differentiable-wrapper (list jacobian) (extend sin) x)))
-
-(define (i:cos x)
-  "Differentiable element-wise cosine; the derivative is minus the sine."
-  (define (dcos x)
-    (- (sin x)))
-  (differentiable-wrapper (list (ewise1 dcos))
-                          (extend cos)
-                          x))
-
-(define (i:tan x)
-  "Differentiable element-wise tangent; the derivative is 1 / cos(x)^2."
-  (define (dtan x)
-    (/ (expt (cos x) 2)))
-  (differentiable-wrapper (list (ewise1 dtan))
-                          (extend tan)
-                          x))
-
-(define (i:+ x y)
-  "Differentiable element-wise addition."
-  (let ((d/dx (ewise2 (lambda _ +1) 0))
-        (d/dy (ewise2 (lambda _ +1) 1)))
-    (differentiable-wrapper (list d/dx d/dy)
-                            (extend +)
-                            x y)))
-
-(define (i:- x y)
-  "Differentiable element-wise subtraction."
-  (let ((d/dx (ewise2 (lambda _ +1) 0))
-        (d/dy (ewise2 (lambda _ -1) 1)))
-    (differentiable-wrapper (list d/dx d/dy)
-                            (extend -)
-                            x y)))
-
-(define (i:* x y)
-  "Differentiable element-wise multiplication."
-  (let ((d/dx (ewise2 (lambda (x y) y) 0))
-        (d/dy (ewise2 (lambda (x y) x) 1)))
-    (differentiable-wrapper (list d/dx d/dy)
-                            (extend *)
-                            x y)))
-
-(define (i:/ x y)
-  "Differentiable element-wise division."
-  (let ((d/dx (ewise2 (lambda (x y) (/ y)) 0))
-        (d/dy (ewise2 (lambda (x y) (- (/ x y y))) 1)))
-    (differentiable-wrapper (list d/dx d/dy)
-                            (extend /)
-                            x y)))
-
-(define (i:max x y)
-  "Differentiable element-wise maximum.  The derivative goes entirely to the
-larger argument and is split evenly when both are equal."
-  (define (dmax x y)
-    (if (> x y)
-        1
-        (if (= x y)
-            1/2
-            0)))
-  (let ((d/dx (ewise2 dmax 0))
-        (d/dy (ewise2 (flip dmax) 1)))
-    (differentiable-wrapper (list d/dx d/dy)
-                            (extend max)
-                            x y)))
-
-(define (i:min x y)
-  "Differentiable element-wise minimum.  The derivative goes entirely to the
-smaller argument and is split evenly when both are equal."
-  (define (dmin x y)
-    (if (< x y)
-        1
-        (if (= x y)
-            1/2
-            0)))
-  (let ((d/dx (ewise2 dmin 0))
-        (d/dy (ewise2 (flip dmin) 1)))
-    (differentiable-wrapper (list d/dx d/dy)
-                            (extend min)
-                            x y)))
-
-(define (i:abs x)
-  "Differentiable element-wise absolute value.
-
-The derivative is the sign of the argument: +1 above zero and -1 below.  At
-zero, where the absolute value is not differentiable, it is 0, the mean of the
-two one-sided derivatives; this is also the value a centered difference
-yields.  A NaN argument yields a NaN derivative."
-  (differentiable-wrapper
-   (list (ewise1 (lambda (x)
-                   (cond ((> x 0)
-                          +1)
-                         ((< x 0)
-                          -1)
-                         ((= x 0)
-                          0)
-                         ;; only reached when no comparison holds, i.e. NaN
-                         (else
-                          x)))))
-   (extend abs)
-   x))
-
-;; The forward procedure counts the elements in `n' while summing them; the
-;; second procedure then hands 1/n to the generator as extra data, so it has
-;; to run after the forward one (as `differentiable-wrapper' does).
-;; An empty array leaves `n' at 0 and makes the division fail.
-(define (mean x)
- "Differentiable mean on arrays."
- (differentiable-wrapper
- (list
- ;; every element contributes 1/n to the mean
- (lambda (xs i j one-over-n)
- one-over-n))
- ;; `n' is fresh for every call to `mean' and shared by the two closures
- (let ((n 0))
- (list
- (lambda (x)
- (let ((sum 0))
- (array-for-each
- (lambda (x)
- (set! sum (+ sum x))
- (set! n (1+ n)))
- x)
- (/ sum n)))
- (lambda _ (/ n))))
- x))
-
-;; The result only depends on the element of `x' addressed by `indices': the
-;; generator yields 1 when the flat input index `j' is the flat offset of
-;; that element and 0 otherwise.
-(define (i:array-ref x . indices)
- "Differentiable array-ref w.r.t `x'."
- (apply
- differentiable-wrapper
- (cons
- (lambda (xs i j abs-index)
- (if (= j abs-index)
- 1
- 0))
- ;; the indices are not differentiable: each gets `#f' as its generator
- ;; (assuming the indices are numbers, for which `not' returns `#f')
- (map not indices))
- (list
- array-ref
- ;; extra data for the generator: flat offset of the referenced element
- (lambda (x . indices)
- (rel->abs indices (array-dimensions x))))
- x indices))
-
-;; Differentiable `array-cell-ref' w.r.t. `x': selects the cell of `x'
-;; addressed by the leading `indices'.  Output element `i' depends only on
-;; the input element that lies at position `i' within the selected cell.
-(define (i:array-cell-ref x . indices)
- (apply
- differentiable-wrapper
- (cons
- (lambda (xs i j abs-index n-rst-dims)
- ;; split the flat input index into the cell number and the position
- ;; within the cell
- (receive (j-ref j-rst) (euclidean/ j n-rst-dims)
- (if (and (= j-ref abs-index)
- (= j-rst i))
- 1
- 0)))
- ;; the indices are not differentiable: each gets `#f' as its generator
- ;; (assuming the indices are numbers, for which `not' returns `#f')
- (map not indices))
- (list
- array-cell-ref
- ;; extra data: flat offset of the cell among the leading dimensions
- (lambda (x . indices)
- (rel->abs indices (take (array-dimensions x)
- (length indices))))
- ;; extra data: number of elements in one cell
- (lambda (x . indices)
- (apply * (drop (array-dimensions x)
- (length indices)))))
- x indices))
-
-;; Differentiable stacking of `elem' and `more' into one array whose first
-;; dimension is the number of elements given and whose remaining dimensions
-;; are those of `elem'.
-;; NOTE(review): the dimensions are taken from `elem' only, so all the
-;; elements presumably have to share them -- confirm with the callers.
-(define (make-batch elem . more)
- (let ((batch-size (1+ (length more))))
- (apply
- differentiable-wrapper
- ;; one generator per stacked element: element `b' only influences the
- ;; outputs of slot `b', at the same position within the slot
- (list-tabulate
- batch-size
- (lambda (b)
- (lambda (xs i j n-rest-dims)
- (receive (i-batch i-rest) (euclidean/ i n-rest-dims)
- (if (and (= i-batch b)
- (= i-rest j))
- 1
- 0
- )))))
- (list
- ;; forward: copy every element into its slot of a fresh array
- (lambda (elem . more)
- (let ((a (apply make-typed-array *atype* *unspecified* batch-size
- (dims-of elem))))
- (for-each
- (lambda (x b)
- (array-cell-set! a x b))
- (cons elem more)
- (list-tabulate batch-size identity))
- a))
- ;; extra data: number of values in one element
- (lambda (elem . more)
- (apply * (dims-of elem))))
- elem more)))
-
-;; The forward procedure records in `max-index' the position, in traversal
-;; order, of the first largest element; the second procedure then hands it to
-;; the generator, so it has to run after the forward one (as
-;; `differentiable-wrapper' does).
-;; NOTE(review): `max-index' keeps the placeholder `TBD' when no element is
-;; greater than minus infinity (e.g. an empty array) -- the generator would
-;; then compare a number with a symbol; confirm such inputs cannot occur.
-(define (maximum x)
- "Differentiable maximum on arrays."
- (differentiable-wrapper
- (list
- ;; the result only depends on the element holding the maximum
- (lambda (xs i j max-index)
- (if (= j max-index)
- 1
- 0)))
- (let ((max-index 'TBD))
- (list
- (lambda (x)
- (let ((m (- (inf)))
- (i 0))
- (array-for-each
- (lambda (x)
- ;; strict comparison: ties keep the earliest position
- (when (< m x)
- (set! m x)
- (set! max-index i))
- (set! i (1+ i)))
- x)
- m))
- (lambda _ max-index)))
- x))
-
-(define (sum x)
-  "Differentiable sum of all the elements of an array; every element
-contributes with a derivative of 1."
-  (define (total-of a)
-    (let ((total 0))
-      (array-for-each
-       (lambda (element)
-         (set! total (+ total element)))
-       a)
-      total))
-  (differentiable-wrapper
-   (list (lambda args 1))
-   total-of
-   x))
-
-;; Differentiable contraction of the last `n' dimensions of the array `x'
-;; with the first `n' dimensions of the array `y' (see `contract-arrays').
-;; `n' itself is not differentiable.
-(define (adot x y n)
- (differentiable-wrapper
- (list
- ;; w.r.t. `x': output (i-x, i-y) depends on input (j-free, j-bound) only
- ;; when both address the same free position of `x'; the derivative is
- ;; then the element of `y' at (j-bound, i-y)
- (lambda (xs i j n-free-dims-y n-bound-dims)
- (receive (i-x i-y) (euclidean/ i n-free-dims-y)
- (receive (j-free j-bound) (euclidean/ j n-bound-dims)
- (ifn (= i-x j-free)
- 0
- (array-ref (array-contents (cadr xs))
- (+ i-y (* n-free-dims-y j-bound)))))))
- ;; w.r.t. `y': output (i-x, i-y) depends on input (j-bound, j-free) only
- ;; when both address the same free position of `y'; the derivative is
- ;; then the element of `x' at (i-x, j-bound)
- (lambda (xs i j n-free-dims-y n-bound-dims)
- (receive (i-x i-y) (euclidean/ i n-free-dims-y)
- (receive (j-bound j-free) (euclidean/ j n-free-dims-y)
- (ifn (= i-y j-free)
- 0
- (array-ref (array-contents (car xs))
- (+ j-bound (* n-bound-dims i-x)))))))
- ;; no generator for `n'
- #f)
- (list
- contract-arrays
- ;; extra data: number of elements spanned by the free dimensions of `y'
- (lambda (x y n)
- (apply * (drop (array-dimensions y)
- n)))
- ;; extra data: number of elements spanned by the contracted dimensions
- (lambda (x y n)
- (apply * (take (array-dimensions y)
- n))))
- x y n))
-
-(define (amap2 f x y)
-  "Differentiable mapping of `f' over the cells of the arrays `x' and `y' taken
-along their first dimension; the results are stacked by `make-batch'.  The
-number of cells is the first dimension of `x'."
-  (let ((batch-size (car (dims-of (unwrap-fwd x)))))
-    (apply make-batch
-           (list-tabulate batch-size
-                          (lambda (b)
-                            (f (i:array-cell-ref x b)
-                               (i:array-cell-ref y b)))))))