]> git.vouivredigital.com Git - vouivre.git/commitdiff
Implement reverse mode automatic differentiation
authoradmin <admin@vouivredigital.com>
Thu, 23 Nov 2023 07:17:16 +0000 (16:17 +0900)
committeradmin <admin@vouivredigital.com>
Thu, 23 Nov 2023 07:17:16 +0000 (16:17 +0900)
autodiff-tests.scm [new file with mode: 0644]
autodiff.scm [new file with mode: 0644]
grad-tests.scm [deleted file]
grad.scm [deleted file]

diff --git a/autodiff-tests.scm b/autodiff-tests.scm
new file mode 100644 (file)
index 0000000..f1fc168
--- /dev/null
@@ -0,0 +1,351 @@
+(define-module (vouivre autodiff tests)
+  #:use-module ((vouivre autodiff) #:prefix v:)
+  #:use-module (ice-9 receive)
+  #:use-module (srfi srfi-1)
+  #:use-module (srfi srfi-64)
+  #:use-module (vouivre misc)
+  #:export
+  (apply-diff
+   a~
+   const-generator
+   differentiable-func-generator
+   lambda-const-call
+   ndiff
+   n~
+   random-array
+   random-array-shape
+   random-func1
+   random-func2
+   random-func2-rank&dims>0
+   random-input
+   random-list-element
+   random-non-empty-array
+   random-shape
+   random-shared
+   random-shared-array-rank&dims>0
+   random-shared-contractible
+   with-generators
+   ~))
+
+(define f1s (list v:abs v:cos v:exp v:identity v:sin))
+(define f2s (list v:+ v:- v:* v:max v:min))
+
+(define (with-generators% generators equal proc1 proc2 . more)
+  "Check that all procedures return the same value according to `equal' when
+evaluated on arguments produced by the generators (the number of generators
+being the number of arguments to each procedure."
+  (let ((times 100)
+       (procs (cons proc1 (cons proc2 more))))
+    (call/cc
+     (lambda (break)
+       (do ((i 0 (1+ i)))
+          ((= i times) #t)
+        (let ((zs (map-in-order (lambda (g) (g)) generators)))
+          (with-exception-handler
+              (lambda (e)
+                (break #f zs))
+            (lambda ()
+              (let* ((rs (map (lambda (f) (apply f zs)) procs))
+                     (head (car rs)))
+                (unless (every (lambda (x) (equal x head))
+                               (cdr rs))
+                  (break #f zs rs))))
+            #:unwind? #t)))))))
+
+(define-syntax-rule (with-generators (g1 g2 ...) equal expected given more ...)
+  (with-generators% (list g1 g2 ...) equal expected given more ...))
+
+(define (lambda-const-call f . consts)
+  (lambda _
+    (apply f consts)))
+
+(define* (random-array-shape
+         #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5))
+  (list-tabulate (+ min-rank (random (- max-rank min-rank)))
+                (lambda _ (+ min-dim (random (- max-dim min-dim))))))
+
+(define (random-shape)
+  (if (= 0 (random 2))
+      0
+      (random-array-shape)))
+
+(define* (random-array #:optional shape)
+  (apply produce-typed-array
+        (lambda _ (random:uniform))
+        v:*atype* (or shape (random-array-shape))))
+
+(define (random-non-empty-array)
+  "Random array of at least one element."
+  (random-array (random-array-shape 0 5 1 5)))
+
+(define* (random-input #:optional shape)
+  (let ((shape (or shape (random-shape))))
+    (if (eq? 0 shape)
+       (random:uniform)
+       (random-array shape))))
+
+(define (random-shared)
+  (let ((shape (random-shape)))
+    (values
+     (lambda ()
+       (random-input shape))
+     (lambda ()
+       (let ((x (random-input
+                (random-list-element
+                 (list 0 (if (list? shape)
+                             shape
+                             (random-shape)))))))
+        (set! shape (random-shape))
+        x)))))
+
+(define (random-shared-array-rank&dims>0)
+  (let ((shape (random-array-shape 1 5 1 5)))
+    (values
+     (lambda ()
+       (random-array shape))
+     (lambda ()
+       (let ((x (random-array shape)))
+        (set! shape (random-array-shape 1 5 1 5))
+        x)))))
+
+(define (random-list-element lst)
+  (list-ref lst (random (length lst))))
+
+(define (const-generator generator)
+  (lambda ()
+    generator))
+
+(define (differentiable-func-generator lst . input-generators)
+  (lambda ()
+    (random-list-element
+     (cons
+      (apply
+       lambda-const-call
+       (random-list-element lst)
+       (map (lambda (g) (g))
+           input-generators))
+      lst))))
+
+(define random-func1
+  (differentiable-func-generator f1s random-input))
+(define random-func2
+  (receive (gx gy) (random-shared)
+    (differentiable-func-generator f2s gx gy)))
+
+(define* (n~ x y #:optional (error 1e-4))
+  (and
+   (>= y (- x error))
+   (<= y (+ x error))))
+
+(define* (a~ x y #:optional (error 1e-4))
+  (and
+   (equal? (array-dimensions x)
+          (array-dimensions y))
+   (call/cc
+    (lambda (break)
+      (array-for-each
+       (lambda (x y)
+        (unless (~ x y error)
+          (break #f)))
+       x y)
+      #t))))
+
+(define* (~ x y #:optional (error 1e-4))
+  (cond
+   ((and (number? x) (number? y))
+    (n~ x y error))
+   ((and (array? x) (array? y))
+    (a~ x y error))
+   (else #f)))
+
+(define* (ndiff f #:optional (axis 0) (step 1e-6))
+  "Differentiation using a numerical centered difference approximation."
+  (define (axis-add xs dh . indices)
+    "Add `dh' to the number or array at the given `axis' of `xs',
+and, when it's an array, at the given index."
+    (map-indexed
+     (lambda (x i)
+       (ifn (= i axis)
+           x
+           (if (number? x)
+               (+ x dh)
+               (array-map-indexed
+                (lambda (x . indices_)
+                  (ifn (equal? indices indices_)
+                       x
+                       (+ x dh)))
+                x))))
+     xs))
+  (lambda xs
+    ;; We need the output shape and the input shape along the
+    ;; differentiated axis.
+    (let ((fxs (apply f xs))
+         (x (list-ref xs axis)))
+      (cond
+       ((and (number? fxs)
+            (number? x))
+       (/ (- (apply f (axis-add xs step))
+             (apply f (axis-add xs (- step))))
+          (* 2 step)))
+       ((and (number? fxs)
+            (array? x))
+       (apply
+        produce-typed-array
+        (lambda indices
+          (/ (- (apply f (apply axis-add xs step indices))
+                (apply f (apply axis-add xs (- step) indices)))
+             (* 2 step)))
+        v:*atype*
+        (array-dimensions x)))
+       ((and (array? fxs)
+            (number? x))
+       ((v:extend /)
+        ((v:extend -)
+         (apply f (axis-add xs step))
+         (apply f (axis-add xs (- step))))
+        (* 2 step)))
+       ((and (array? fxs)
+            (array? x))
+       (let ((a (apply
+                 make-typed-array v:*atype* *unspecified*
+                 (append (array-dimensions fxs)
+                         (array-dimensions x)))))
+         (for-indices-in-range
+          (lambda indices-in
+            (let ((dfxs ((v:extend /)
+                         ((v:extend -)
+                          (apply f (apply axis-add xs step indices-in))
+                          (apply f (apply axis-add xs (- step) indices-in)))
+                         (* 2 step))))
+              (for-indices-in-range
+               (lambda indices-out
+                 (apply
+                  array-set!
+                  a
+                  (apply array-ref dfxs indices-out)
+                  (append indices-out indices-in)))
+               (list-zeros (array-rank fxs))
+               (array-dimensions fxs))))
+          (list-zeros (array-rank x))
+          (array-dimensions x))
+         a))))))
+
+(define* (apply-diff differentiator #:optional (axis 0))
+  "Apply a differentiator (`ndiff', `fdiff', `rdiff') to a function and its
+arguments (this is a convenience function)."
+  (lambda (f . args)
+    (apply (differentiator f axis) args)))
+
+(test-begin "autodiff")
+
+;; not differentiating
+(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity))
+(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp))
+(test-assert
+    (receive (gx gy) (random-shared)
+      (with-generators (gx gy) ~ (v:extend *) v:*)))
+
+;; differentiation in one variable
+(test-assert
+    (with-generators
+     (random-func1 random-input)
+     ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff)))
+
+;; `v:mean' only takes non-empty arrays so we treat it separately
+(test-assert
+    (with-generators
+     ((differentiable-func-generator (list v:mean) random-non-empty-array)
+      random-non-empty-array)
+     ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff)))
+
+;; differentiation in two variables
+(test-assert
+    (receive (gx gy) (random-shared)
+      (with-generators
+       (random-func2 gx gy)
+       ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0))))
+(test-assert
+    (receive (gx gy) (random-shared)
+      (with-generators
+       (random-func2 gx gy)
+       ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1))))
+
+;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it
+;; separately
+(define random-func2-rank&dims>0
+  (receive (gx gy) (random-shared-array-rank&dims>0)
+    (differentiable-func-generator f2s gx gy)))
+(test-assert
+    (receive (gx gy) (random-shared-array-rank&dims>0)
+      (with-generators
+       ((const-generator v:amap2) random-func2-rank&dims>0 gx gy)
+       ;; NOTE: for `v:amap2' the differentiable axes are 1 and 2.
+       ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1))))
+(test-assert
+    (receive (gx gy) (random-shared-array-rank&dims>0)
+      (with-generators
+       ((const-generator v:amap2) random-func2-rank&dims>0 gx gy)
+       ~ (apply-diff ndiff 2) (apply-diff v:fdiff 2) (apply-diff v:rdiff 2))))
+(let* ((z #(1 2 3))
+       (f (lambda (a)
+           (v:amap2 (lambda (x y)
+                      (v:* a a))
+                    #(10 20 30)
+                    #(40 50 60))))
+       (e ((ndiff f) z)))
+  (test-assert (~ e ((v:fdiff f) z)))
+  (test-assert (~ e ((v:rdiff f) z))))
+
+;; `v:adot'
+(define (random-shared-contractible)
+  "Returns three generators: the first two generate arrays that are contractible
+according to the number generated by the third one."
+  (let* ((n (random 5))
+        (sa (random-array-shape n))
+        (sb (append (reverse (take (reverse sa)
+                                   n))
+                    (random-array-shape 0 (- 5 n)))))
+    (values
+     (lambda ()
+       (random-array sa))
+     (lambda ()
+       (random-array sb))
+     (lambda ()
+       (let ((tmp n))
+        (set! n (random 5))
+        (set! sa (random-array-shape n))
+        (set! sb (append (reverse (take (reverse sa)
+                                        n))
+                         (random-array-shape 0 (- 5 n))))
+        tmp)))))
+(test-assert
+    (receive (gx gy gz) (random-shared-contractible)
+      (with-generators
+       ((const-generator v:adot) gx gy gz)
+       ~ (apply-diff ndiff 0) (apply-diff v:fdiff 0) (apply-diff v:rdiff 0))))
+(test-assert
+    (receive (gx gy gz) (random-shared-contractible)
+      (with-generators
+       ((const-generator v:adot) gx gy gz)
+       ~ (apply-diff ndiff 1) (apply-diff v:fdiff 1) (apply-diff v:rdiff 1))))
+
+;; let binding re-entry
+(test-assert
+    (with-generators
+     ((const-generator
+       (lambda (x)
+        (let ((c (v:maximum x)))
+          (v:+ c (v:- x c)))))
+      random-non-empty-array)
+     ~ (apply-diff ndiff) (apply-diff v:fdiff) (apply-diff v:rdiff)))
+
+;; chain rule
+(test-assert
+    (with-generators
+     (random-func1 random-func1 random-input)
+     ~
+     (lambda (f g x) ((ndiff (compose f g)) x))
+     (lambda (f g x) ((v:fdiff (compose f g)) x))
+     (lambda (f g x) ((v:rdiff (compose f g)) x))))
+
+(test-end "autodiff")
diff --git a/autodiff.scm b/autodiff.scm
new file mode 100644 (file)
index 0000000..a6666ad
--- /dev/null
@@ -0,0 +1,859 @@
+(define-module (vouivre autodiff)
+  #:use-module (ice-9 receive)
+  #:use-module (srfi srfi-1)
+  #:use-module (srfi srfi-9)
+  #:use-module (vouivre misc)
+  #:use-module (vouivre promises)
+  #:export
+  (*atype*
+   adot
+   amap2
+   contract-arrays
+   differentiable-wrapper
+   dot
+   do-times
+   ewise1
+   ewise2
+   extend
+   fdiff
+   rdiff
+   make-batch
+   make-internal
+   maximum
+   mean
+   rank-of
+   sum)
+  #:replace
+  ((i:sqrt . sqrt)
+   (i:exp . exp)
+   (i:expt . expt)
+   (i:log . log)
+   (i:sin . sin)
+   (i:cos . cos)
+   (i:tan . tan)
+   (i:+ . +)
+   (i:- . -)
+   (i:* . *)
+   (i:/ . /)
+   (i:max . max)
+   (i:min . min)
+   (i:abs . abs)
+   (i:identity . identity)
+   (i:array-ref . array-ref)
+   (i:array-cell-ref . array-cell-ref))
+  #:re-export
+  (fold
+   reduce))
+
+;;;; array utilities
+
+(define (rel->abs indices dimensions)
+  (let rec ((s 0)
+           (p 1)
+           (is (reverse indices))
+           (ds (reverse dimensions)))
+    (if (null? is)
+       s
+       (rec (+ s (* p (car is)))
+            (* p (car ds))
+            (cdr is)
+            (cdr ds)))))
+
+(define(do-times n proc)
+  (let rec ((i 0))
+    (unless (= i n)
+      (proc i)
+      (rec (1+ i)))))
+
+(define (contract-arrays a b n)
+  (let* ((dims-a (array-dimensions a))
+        (dims-b (array-dimensions b))
+        (free-dims-a (take dims-a (- (array-rank a) n)))
+        (free-dims-b (drop dims-b n))
+        (bound-dims (take dims-b n))
+        (n-free-dims-a (apply * free-dims-a))
+        (n-free-dims-b (apply * free-dims-b))
+        (n-bound-dims (apply * bound-dims))
+        (s 0)
+        (r (apply make-typed-array *atype* *unspecified* (append free-dims-a
+                                                                 free-dims-b)))
+        (ac (array-contents a))
+        (bc (array-contents b))
+        (rc (array-contents r)))
+    (do-times
+     n-free-dims-a
+     (lambda (i)
+       (let ((i-k (* n-bound-dims i))
+            (i-j (* n-free-dims-b i)))
+        (do-times
+         n-free-dims-b
+         (lambda (j)
+           (set! s 0)
+           (do-times
+            n-bound-dims
+            (lambda (k)
+              (set! s (+ s (* (array-ref ac (+ i-k k))
+                              (array-ref bc (+ (* n-free-dims-b k) j)))))))
+           (array-set! rc s (+ i-j j)))))))
+    r))
+
+;;;; utilities that work on both numbers and arrays
+
+(define (extend f)
+  "Extend a function of one or more scalars to apply to numbers/arrays
+element-wise. All arrays must have the same dimension."
+  (define (apply-elemwise f indices args)
+    (apply f (map (lambda (x)
+                   (if (number? x)
+                       x
+                       (apply array-ref x indices)))
+                 args)))
+  (lambda xs
+    (if-let (x (find array? xs))
+           (apply
+            produce-typed-array
+            (lambda is
+              (apply-elemwise f is xs))
+            *atype*
+            (array-dimensions x))
+           (apply f xs))))
+
+(define (dot x y n)
+  (cond
+   ((and (number? x) (number? y))
+    (* x y))
+   ((and (array? x) (array? y))
+    (contract-arrays x y n))
+   ((and (array? x) (number? y))
+    ((extend *) x y))
+   ((and (number? x) (array? y))
+    ((extend *) x y))
+   (else (error "can't dot because of invalid types or ranks" x y n))))
+
+(define (rank-of x)
+  (if (number? x)
+      0
+      (array-rank x)))
+
+;;;; differentiation
+
+(define-record-type internal
+  (make-internal forward jacobian)
+  internal?
+  (forward internal-forward)
+  (jacobian internal-jacobian))
+
+;;(define *atype* 'f32)
+(define *atype* #t)
+(define *differentiation-mode* (make-parameter #f))
+(define *n-y-dims* (make-parameter #f))
+(define *j* (make-parameter #f))
+
+(define-syntax-rule (w/j val body ...)
+  (parameterize ((*j* val))
+    body ...))
+
+(define (wrap axis)
+  (lambda (x i)
+    (if (= i axis)
+       (make-internal x 'input)
+       x)))
+
+(define (unwrap-fwd x)
+  (if (internal? x)
+      (internal-forward x)
+      x))
+
+(define (unwrap-jac x)
+  (if (internal? x)
+      (internal-jacobian x)
+      x))
+
+(define (dims-of x)
+  (if (number? x)
+      '()
+      (array-dimensions x)))
+
+(define (add dst-buf src-buf n-dims)
+  (do-times
+   n-dims
+   (lambda (i)
+     (array-set!
+      dst-buf
+      (+ (array-ref dst-buf i)
+        (array-ref src-buf i))
+      i)))
+  dst-buf)
+
+(define (movg dst-buf n-dst-dims generator naked-inputs data j)
+  (do-times
+   n-dst-dims
+   (lambda (i)
+     (array-set!
+      dst-buf
+      (apply generator naked-inputs i j data)
+      i)))
+  dst-buf)
+
+(define (addg dst-buf n-dst-dims generator naked-inputs data j)
+  (do-times
+   n-dst-dims
+   (lambda (i)
+     (array-set!
+      dst-buf
+      (+ (array-ref dst-buf i)
+        (apply generator naked-inputs i j data))
+      i)))
+  dst-buf)
+
+(define (movc dst-buf n-dst-dims src-buf n-src-dims
+             generator naked-inputs data)
+  "Contract the Jacobian column produced by the generator with the source buffer
+storing the result in the destination buffer."
+  (let ((s 0))
+    (do-times
+     n-dst-dims
+     (lambda (i)
+       (set! s 0)
+       (do-times
+       n-src-dims
+       (lambda (k)
+         (set! s (+ s (* (apply generator naked-inputs i k data)
+                         (array-ref src-buf k))))))
+       (array-set! dst-buf s i))))
+  dst-buf)
+
+(define (addc dst-buf n-dst-dims src-buf n-src-dims
+             generator naked-inputs data)
+  "Contract the Jacobian column produced by the generator with the source buffer
+adding the result to the destination buffer."
+  (let ((s 0))
+    (do-times
+     n-dst-dims
+     (lambda (i)
+       (set! s (array-ref dst-buf i))
+       (do-times
+       n-src-dims
+       (lambda (k)
+         (set! s (+ s (* (apply generator naked-inputs i k data)
+                         (array-ref src-buf k))))))
+       (array-set! dst-buf s i))))
+  dst-buf)
+
+(define (transpose-generator generator)
+  (lambda (xs i j . data)
+    (apply generator xs j i data)))
+
+(define* (fdiff f #:optional (axis 0))
+  (lambda xs
+    (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'fwd)
+                  ((@@ (vouivre autodiff) *promises*) (cons '() #f)))
+      (let* ((internal (apply f (map-indexed (wrap axis) xs)))
+            (fx (internal-forward internal))
+            (y (list-ref xs axis))  ; variable to differentiate w.r.t
+            (pre-Jx (internal-jacobian internal))
+            (Jx (cond
+                 ;; TODO: implement 'input case and test 'zero and 'input
+                 ((eq? pre-Jx 'zero)
+                  (lambda (j)
+                    (lambda (i)
+                      0)))
+                 ((eq? pre-Jx 'input)
+                  (error "TBD."))
+                 (else
+                  (lambda (j)
+                    (reset-promises (car (*promises*)))
+                    (let ((column-jac (w/j j (force pre-Jx))))
+                      (lambda (i)
+                        (array-ref column-jac i))))))))
+       (cond
+        ((and (number? fx) (number? y))
+         ((Jx 0) 0))
+        ((and (number? fx) (array? y))
+         (let* ((y-dims (array-dimensions y))
+                (a (apply make-array *unspecified* y-dims))
+                (ac (array-contents a)))
+           (do-times
+            (apply * y-dims)
+            (lambda (j)
+              (array-set! ac ((Jx j) 0)
+                          j)))
+           a))
+        ((and (array? fx) (number? y))
+         (let* ((fx-dims (array-dimensions fx))
+                (a (apply make-array *unspecified* fx-dims))
+                (ac (array-contents a))
+                (Jx (Jx 0)))
+           (do-times
+            (apply * fx-dims)
+            (lambda (i)
+              (array-set! ac (Jx i)
+                          i)))
+           a))
+        (else
+         (let* ((fx-dims (array-dimensions fx))
+                (y-dims (array-dimensions y))
+                (n-fx-dims (apply * fx-dims))
+                (n-y-dims (apply * y-dims))
+                (a (apply make-array *unspecified* (append fx-dims y-dims)))
+                (ac (array-contents a)))
+           (do-times
+            n-y-dims
+            (lambda (j)
+              (let ((Jx (Jx j)))
+                (do-times
+                 n-fx-dims
+                 (lambda (i)
+                   (array-set! ac (Jx i)
+                               (+ j (* n-y-dims i))))))))
+           a)))))))
+
+(define* (rdiff f #:optional (axis 0))
+  (lambda xs
+    (parameterize (((@@ (vouivre autodiff) *differentiation-mode*) 'rev)
+                  ((@@ (vouivre autodiff) *promises*) (cons '() #f)))
+      (let* ((internal (apply f (map-indexed (wrap axis) xs)))
+            (fx (internal-forward internal))
+            (y (list-ref xs axis))  ; variable to differentiate w.r.t
+            (y-dims (dims-of y))
+            (pre-Jx (internal-jacobian internal))
+            (Jx (cond
+                 ;; TODO: implement 'input case and test 'zero and 'input
+                 ((eq? pre-Jx 'zero)
+                  (lambda (i)
+                    (lambda (j)
+                      0)))
+                 ((eq? pre-Jx 'input)
+                  (error "TBD."))
+                 (else
+                  (let ((pre-Jx (pre-Jx #f)))
+                    (lambda (i)
+                      (let ((row-jac (pre-Jx i)))
+                        (lambda (j)
+                          (array-ref row-jac j)))))))))
+       (parameterize ((*n-y-dims* (apply * y-dims)))
+         (cond
+          ((and (number? fx) (number? y))
+           ((Jx 0) 0))
+          ((and (number? fx) (array? y))
+           (let* ((a (apply make-array *unspecified* y-dims))
+                  (ac (array-contents a))
+                  (Jx (Jx 0)))
+             (do-times
+              (*n-y-dims*)
+              (lambda (j)
+                (array-set! ac (Jx j)
+                            j)))
+             a))
+          ((and (array? fx) (number? y))
+           (let* ((fx-dims (array-dimensions fx))
+                  (a (apply make-array *unspecified* fx-dims))
+                  (ac (array-contents a)))
+             (do-times
+              (apply * fx-dims)
+              (lambda (i)
+                (array-set! ac ((Jx i) 0)
+                            i)))
+             a))
+          (else
+           (let* ((fx-dims (array-dimensions fx))
+                  (n-fx-dims (apply * fx-dims))
+                  (a (apply make-array *unspecified* (append fx-dims y-dims)))
+                  (ac (array-contents a)))
+             (do-times
+              n-fx-dims
+              (lambda (i)
+                (let ((Jx (Jx i)))
+                  (do-times
+                   (*n-y-dims*)
+                   (lambda (j)
+                     (array-set! ac (Jx j)
+                                 (+ j (* (*n-y-dims*) i))))))))
+             a))))))))
+
+;; In the comment that follows:
+;;
+;; `n' is the number of arguments to `proc'.
+;; `generators is not a `Vec' but a `List' we only use the former to illustrate
+;; its length.
+;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
+;; `I' is the type of multi-indices indexing the output of `function'.
+;; `J' is the type of multi-indices indexing the input array being differentiated.
+;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J').
+;; `Array I' is the type of arrays indexed by multi-indices of `I'.
+;; `[X]' means that `X' is boxed in an internal as when returned by
+;;       `differentiable-wrapper' with the array being `X' and the promise that
+;;       given a |J| we will get the change of `X' with a change of the
+;;       the differentiated argument at multi-index `J'.
+;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number))
+;;       (→ X1 ... Xn (Array I))*
+;;       [X1] ... [Xn]
+;;       (Internal (Array I) (Promise |J| (Array |I|)))))
+;;
+;; (*) We extend this definition to allow `proc' to be a list of procedures
+;;     the head of which is as described above and the remaining elements
+;;     are procedures of the same arguments but returning values that are
+;;     then fed as extra data to the generators.
+;;
+;; NOTE: In cases where an argument isn't meant to be differentiable its
+;;       corresponding generator should be `#f'.
+(define (differentiable-wrapper generators proc* arg . more)
+  (define (precompute-data naked-args)
+    (if (procedure? proc*)
+       '()
+       (map (lambda (g)
+              (apply g naked-args))
+            (cdr proc*))))
+  (let* ((args (cons arg more))
+        (proc (if (procedure? proc*)
+                  proc*
+                  (car proc*)))
+        (naked-args (map unwrap-fwd args))
+        (out (apply proc naked-args)))
+    (case (*differentiation-mode*)
+      ((#f)
+       out)
+      ((fwd)
+       (let* ((data (precompute-data naked-args))
+             (n-out-dims (apply * (dims-of out)))
+             (buf (make-array *unspecified* n-out-dims)))
+        (make-internal
+         out
+         (fold
+          (lambda (generator arg prev)
+            (if (or (not (internal? arg))
+                    (eq? 'zero (internal-jacobian arg)))
+                prev
+                (let ((Jx (internal-jacobian arg))
+                      (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
+                  (if (eq? Jx 'input)
+                      (if (eq? prev 'zero)
+                          (delay
+                            (movg buf n-out-dims
+                                  generator naked-args data (*j*)))
+                          (delay
+                            (addg (force prev) n-out-dims
+                                  generator naked-args data (*j*))))
+                      (if (eq? prev 'zero)
+                          (delay
+                            (movc buf n-out-dims (force Jx) n-fwd-dims
+                                  generator naked-args data))
+                          (delay
+                            (addc (force prev) n-out-dims
+                                  (force Jx) n-fwd-dims
+                                  generator naked-args data)))))))
+          'zero generators args))))
+      ((rev)
+       (let ((data (precompute-data naked-args))
+            (n-out-dims (apply * (dims-of out))))
+        (make-internal
+         out
+         (fold
+          (lambda (generator arg prev)
+            (let ((generator (transpose-generator generator)))
+              (if (or (not (internal? arg))
+                      (eq? 'zero (internal-jacobian arg)))
+                  prev
+                  (let* ((Jx (internal-jacobian arg))
+                         (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
+                    (if (eq? Jx 'input)
+                        (if (eq? prev 'zero)
+                            (lambda (buf?)
+                              (let ((dst-buf (make-array *unspecified*
+                                                         n-fwd-dims)))
+                                (if buf?
+                                    (lambda (buf)
+                                      (movc dst-buf n-fwd-dims
+                                            buf n-out-dims
+                                            generator naked-args data))
+                                    (lambda (i)
+                                      (movg dst-buf n-fwd-dims
+                                            generator naked-args data
+                                            i)))))
+                            (lambda (buf?)
+                              (let ((prev (prev buf?)))
+                                (if buf?
+                                    (lambda (buf)
+                                      (addc (prev buf) n-fwd-dims
+                                            buf n-out-dims
+                                            generator naked-args data))
+                                    (lambda (i)
+                                      (addg (prev i) n-fwd-dims
+                                            generator naked-args data
+                                            i))))))
+                        (if (eq? prev 'zero)
+                            (lambda (buf?)
+                              (let ((Jx (Jx #t))
+                                    (dst-buf (make-array *unspecified*
+                                                         n-fwd-dims)))
+                                (if buf?
+                                    (lambda (buf)
+                                      (Jx
+                                       (movc dst-buf n-fwd-dims buf
+                                             n-out-dims
+                                             generator naked-args data)))
+                                    (lambda (i)
+                                      (Jx
+                                       (movg dst-buf n-fwd-dims
+                                             generator naked-args data
+                                             i))))))
+                            (lambda (buf?)
+                              (let ((prev (prev buf?))
+                                    (Jx (Jx #t))
+                                    (dst-buf (make-array *unspecified*
+                                                         n-fwd-dims)))
+                                (if buf?
+                                    (lambda (buf)
+                                      (add (prev buf)
+                                           (Jx
+                                            (movc dst-buf n-fwd-dims
+                                                  buf n-out-dims
+                                                  generator naked-args data))
+                                           (*n-y-dims*)))
+                                    (lambda (i)
+                                      (add (prev i)
+                                           (Jx
+                                            (movg dst-buf n-fwd-dims
+                                                  generator naked-args data
+                                                  i))
+                                           (*n-y-dims*))))))))))))
+          'zero generators args)))))))
+
+(define (ewise1 f)
+  (lambda (xs i j)
+    (let ((x (car xs)))
+      (if (number? x)
+         (f x)
+         (ifn (= i j)
+              0
+              (f (array-ref (array-contents x)
+                            j)))))))
+
+(define (ewise2 proc axis)
+  (lambda (xs i j)
+    (let ((x (car xs))
+         (y (cadr xs)))
+      (cond
+       ((and (number? x) (number? y))
+       (proc x y))
+       ((and (number? x) (array? y))
+       (if (= axis 0)
+           (proc x (array-ref (array-contents y)
+                              i))
+           (ifn (= i j)
+                0
+                (proc x (array-ref (array-contents y)
+                                   j)))))
+       ((and (array? x) (number? y))
+       (if (= axis 1)
+           (proc (array-ref (array-contents x)
+                            i)
+                 y)
+           (ifn (= i j)
+                0
+                (proc (array-ref (array-contents x)
+                                 j)
+                      y))))
+       (else
+       (ifn (= i j)
+            0
+            (proc (array-ref (array-contents x)
+                             j)
+                  (array-ref (array-contents y)
+                             j))))))))
+
+(define (i:identity x)
+  "Differentiable identity."
+  (differentiable-wrapper
+   (list (ewise1 (lambda _ 1)))
+   identity
+   x))
+
+(define (i:sqrt x)
+  "Differentiable square root."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ 1 2 (sqrt x)))))
+   (extend sqrt)
+   x))
+
+(define (i:exp x)
+  "Differentiable exponential."
+  (differentiable-wrapper
+   (list (ewise1 exp))
+   (extend exp)
+   x))
+
+(define (i:expt x y)
+  "Differentiable power."
+  (differentiable-wrapper
+   (list (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0)
+        (ewise2 (lambda (x y) (* (expt x y) (log x))) 1))
+   (extend expt)
+   x y))
+
+(define (i:log x)
+  "Differentiable logarithm."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ x))))
+   (extend log)
+   x))
+
+(define (i:sin x)
+  "Differentiable sine."
+  (differentiable-wrapper
+   (list (ewise1 cos))
+   (extend sin)
+   x))
+
+(define (i:cos x)
+  "Differentiable cosine."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (- (sin x)))))
+   (extend cos)
+   x))
+
+(define (i:tan x)
+  "Differentiable tangent."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ (expt (cos x) 2)))))
+   (extend tan)
+   x))
+
+(define (i:+ x y)
+  "Differentiable element-wise addition."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda _ +1) 0)
+    (ewise2 (lambda _ +1) 1))
+   (extend +)
+   x y))
+
+(define (i:- x y)
+  "Differentiable element-wise subtraction."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda _ +1) 0)
+    (ewise2 (lambda _ -1) 1))
+   (extend -)
+   x y))
+
+(define (i:* x y)
+  "Differentiable element-wise multiplication."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda (x y) y) 0)
+    (ewise2 (lambda (x y) x) 1))
+   (extend *)
+   x y))
+
+(define (i:/ x y)
+  "Differentiable element-wise division."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda (x y) (/ y)) 0)
+    (ewise2 (lambda (x y) (- (/ x y y))) 1))
+   (extend /)
+   x y))
+
+(define (i:max x y)
+  "Differentiable element-wise maximum."
+  (define (dmax x y)
+    (cond
+     ((> x y)
+      1)
+     ((= x y)
+      1/2)
+     (else
+      0)))
+  (differentiable-wrapper
+   (list
+    (ewise2 dmax 0)
+    (ewise2 (flip dmax) 1))
+   (extend max)
+   x y))
+
+(define (i:min x y)
+  "Differentiable element-wise minimum."
+  (define (dmin x y)
+    (cond
+     ((< x y)
+      1)
+     ((= x y)
+      1/2)
+     (else
+      0)))
+  (differentiable-wrapper
+   (list
+    (ewise2 dmin 0)
+    (ewise2 (flip dmin) 1))
+   (extend min)
+   x y))
+
+(define (i:abs x)
+  "Differentiable absolute."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x)
+                  (cond ((> x 0)
+                         +1)
+                        ((= x 0)
+                         1/2)
+                        ((< x 0)
+                         -1)))))
+   (extend abs)
+   x))
+
+(define (mean x)
+  "Differentiable mean on arrays."
+  (differentiable-wrapper
+   (list
+    (lambda (xs i j one-over-n)
+      one-over-n))
+   (let ((n 0))
+     (list
+      (lambda (x)
+       (let ((sum 0))
+         (array-for-each
+          (lambda (x)
+            (set! sum (+ sum x))
+            (set! n (1+ n)))
+          x)
+         (/ sum n)))
+      (lambda _ (/ n))))
+   x))
+
+(define (i:array-ref x . indices)
+  "Differentiable array-ref w.r.t `x'."
+  (apply
+   differentiable-wrapper
+   (cons
+    (lambda (xs i j abs-index)
+      (if (= j abs-index)
+         1
+         0))
+    (map not indices))
+   (list
+    array-ref
+    (lambda (x . indices)
+      (rel->abs indices (array-dimensions x))))
+   x indices))
+
+(define (i:array-cell-ref x . indices)
+  (apply
+   differentiable-wrapper
+   (cons
+    (lambda (xs i j abs-index n-rst-dims)
+      (receive (j-ref j-rst) (euclidean/ j n-rst-dims)
+       (if (and (= j-ref abs-index)
+                (= j-rst i))
+           1
+           0)))
+    (map not indices))
+   (list
+    array-cell-ref
+    (lambda (x . indices)
+      (rel->abs indices (take (array-dimensions x)
+                             (length indices))))
+    (lambda (x . indices)
+      (apply * (drop (array-dimensions x)
+                    (length indices)))))
+   x indices))
+
+(define (make-batch elem . more)
+  (let ((batch-size (1+ (length more))))
+    (apply
+     differentiable-wrapper
+     (list-tabulate
+      batch-size
+      (lambda (b)
+       (lambda (xs i j n-rest-dims)
+         (receive (i-batch i-rest) (euclidean/ i n-rest-dims)
+           (if (and (= i-batch b)
+                    (= i-rest j))
+               1
+               0
+               )))))
+     (list
+      (lambda (elem . more)
+       (let ((a (apply make-typed-array *atype* *unspecified* batch-size
+                       (dims-of elem))))
+         (for-each
+          (lambda (x b)
+            (array-cell-set! a x b))
+          (cons elem more)
+          (list-tabulate batch-size identity))
+         a))
+      (lambda (elem . more)
+       (apply * (dims-of elem))))
+     elem more)))
+
+(define (maximum x)
+  "Differentiable maximum on arrays."
+  (differentiable-wrapper
+   (list
+    (lambda (xs i j max-index)
+      (if (= j max-index)
+         1
+         0)))
+   (let ((max-index 'TBD))
+     (list
+      (lambda (x)
+       (let ((m (- (inf)))
+             (i 0))
+         (array-for-each
+          (lambda (x)
+            (when (< m x)
+              (set! m x)
+              (set! max-index i))
+            (set! i (1+ i)))
+          x)
+         m))
+      (lambda _ max-index)))
+   x))
+
+(define (sum x)
+  "Differentiable sum on arrays."
+  (differentiable-wrapper
+   (list (lambda _ 1))
+   (lambda (x)
+     (let ((sum 0))
+       (array-for-each
+       (lambda (x)
+         (set! sum (+ sum x)))
+       x)
+       sum))
+   x))
+
+(define (adot x y n)
+  (differentiable-wrapper
+   (list
+    (lambda (xs i j n-free-dims-y n-bound-dims)
+      (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+       (receive (j-free j-bound) (euclidean/ j n-bound-dims)
+         (ifn (= i-x j-free)
+              0
+              (array-ref (array-contents (cadr xs))
+                         (+ i-y (* n-free-dims-y j-bound)))))))
+    (lambda (xs i j n-free-dims-y n-bound-dims)
+      (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+       (receive (j-bound j-free) (euclidean/ j n-free-dims-y)
+         (ifn (= i-y j-free)
+              0
+              (array-ref (array-contents (car xs))
+                         (+ j-bound (* n-bound-dims i-x)))))))
+    #f)
+   (list
+    contract-arrays
+    (lambda (x y n)
+      (apply * (drop (array-dimensions y)
+                    n)))
+    (lambda (x y n)
+      (apply * (take (array-dimensions y)
+                    n))))
+   x y n))
+
+(define (amap2 f x y)
+  (apply make-batch
+        (list-tabulate (car (dims-of (unwrap-fwd x)))
+                       (lambda (b)
+                         (f (i:array-cell-ref x b)
+                            (i:array-cell-ref y b))))))
diff --git a/grad-tests.scm b/grad-tests.scm
deleted file mode 100644 (file)
index bdf8fd4..0000000
+++ /dev/null
@@ -1,347 +0,0 @@
-(define-module (vouivre grad tests)
-  #:use-module ((vouivre grad) #:prefix v:)
-  #:use-module (ice-9 receive)
-  #:use-module (srfi srfi-1)
-  #:use-module (srfi srfi-64)
-  #:use-module (vouivre misc)
-  #:export
-  (apply-grad
-   apply-grad-amap2
-   a~
-   differentiable-func-generator
-   lambda-const-call
-   ngrad
-   n~
-   random-array
-   random-array-shape
-   random-func1
-   random-func2
-   random-func2-rank&dims>0
-   random-input
-   random-list-element
-   random-non-empty-array
-   random-shape
-   random-shared
-   random-shared-array-rank&dims>0
-   random-shared-contractible
-   with-generators
-   ~))
-
-(define f1s (list v:exp v:identity))
-(define f2s (list v:+ v:- v:* v:max v:min))
-
-(define-syntax-rule (with-generators (g1 g2 ...) equal expected given)
-  (let ((times 100)
-       (fx expected)
-       (fy given)
-       (generators (list g1 g2 ...)))
-    (call/cc
-     (lambda (break)
-       (do ((i 0 (1+ i)))
-          ((= i times) #t)
-        (let ((zs (map-in-order (lambda (g) (g)) generators)))
-          (with-exception-handler
-              (lambda (e)
-                (break #f zs))
-            (lambda ()
-              (let ((r1 (apply fx zs))
-                    (r2 (apply fy zs)))
-                (unless (equal r1 r2)
-                  (break #f zs r1 r2))))
-            #:unwind? #t)))))))
-
-(define (lambda-const-call f . consts)
-  (lambda _
-    (apply f consts)))
-
-(define* (random-array-shape
-         #:optional (min-rank 0) (max-rank 5) (min-dim 0) (max-dim 5))
-  (list-tabulate (+ min-rank (random (- max-rank min-rank)))
-                (lambda _ (+ min-dim (random (- max-dim min-dim))))))
-
-(define (random-shape)
-  (if (= 0 (random 2))
-      0
-      (random-array-shape)))
-
-(define* (random-array #:optional shape)
-  (apply produce-typed-array
-        (lambda _ (random:uniform))
-        v:*atype* (or shape (random-array-shape))))
-
-(define (random-non-empty-array)
-  "Random array of at least one element."
-  (random-array (random-array-shape 0 5 1 5)))
-
-(define* (random-input #:optional shape)
-  (let ((shape (or shape (random-shape))))
-    (if (eq? 0 shape)
-       (random:uniform)
-       (random-array shape))))
-
-(define (random-shared)
-  (let ((shape (random-shape)))
-    (values
-     (lambda ()
-       (random-input shape))
-     (lambda ()
-       (let ((x (random-input
-                (random-list-element
-                 (list 0 (if (list? shape)
-                             shape
-                             (random-shape)))))))
-        (set! shape (random-shape))
-        x)))))
-
-(define (random-shared-array-rank&dims>0)
-  (let ((shape (random-array-shape 1 5 1 5)))
-    (values
-     (lambda ()
-       (random-array shape))
-     (lambda ()
-       (let ((x (random-array shape)))
-        (set! shape (random-array-shape 1 5 1 5))
-        x)))))
-
-(define (random-list-element lst)
-  (list-ref lst (random (length lst))))
-
-(define (differentiable-func-generator lst . input-generators)
-  (lambda ()
-    (random-list-element
-     (cons
-      (apply
-       lambda-const-call
-       (random-list-element lst)
-       (map (lambda (g) (g))
-           input-generators))
-      lst))))
-
-(define random-func1
-  (differentiable-func-generator f1s random-input))
-(define random-func2
-  (receive (gx gy) (random-shared)
-    (differentiable-func-generator f2s gx gy)))
-
-(define* (n~ x y #:optional (error 1e-4))
-  (and
-   (>= y (- x error))
-   (<= y (+ x error))))
-
-(define* (a~ x y #:optional (error 1e-4))
- (and
-  (equal? (array-dimensions x)
-         (array-dimensions y))
-  (call/cc
-   (lambda (break)
-     (array-for-each
-      (lambda (x y)
-       (unless (~ x y error)
-         (break #f)))
-      x y)
-     #t))))
-
-(define* (~ x y #:optional (error 1e-4))
-  (cond
-   ((and (number? x) (number? y))
-    (n~ x y error))
-   ((and (array? x) (array? y))
-    (a~ x y error))
-   (else #f)))
-
-(define* (ngrad f #:optional (axis 0) (step 1e-6))
-  "Gradient using a numerical centered difference approximation."
-  (define (axis-add xs dh . indices)
-    "Add `dh' to the number or array at the given `axis' of `xs',
-and, when it's an array, at the given index."
-    (map-indexed
-     (lambda (x i)
-       (ifn (= i axis)
-           x
-           (if (number? x)
-               (+ x dh)
-               (array-map-indexed
-                (lambda (x . indices_)
-                  (ifn (equal? indices indices_)
-                       x
-                       (+ x dh)))
-                x))))
-     xs))
-  (lambda xs
-    ;; We need the output shape and the input shape along the
-    ;; differentiated axis.
-    (let ((fxs (apply f xs))
-         (x (list-ref xs axis)))
-      (cond
-       ((and (number? fxs)
-            (number? x))
-       (/ (- (apply f (axis-add xs step))
-             (apply f (axis-add xs (- step))))
-          (* 2 step)))
-       ((and (number? fxs)
-            (array? x))
-       (apply
-        produce-typed-array
-        (lambda indices
-          (/ (- (apply f (apply axis-add xs step indices))
-                (apply f (apply axis-add xs (- step) indices)))
-             (* 2 step)))
-        v:*atype*
-        (array-dimensions x)))
-       ((and (array? fxs)
-            (number? x))
-       ((v:extend /)
-        ((v:extend -)
-         (apply f (axis-add xs step))
-         (apply f (axis-add xs (- step))))
-        (* 2 step)))
-       ((and (array? fxs)
-            (array? x))
-       (let ((a (apply
-                 make-typed-array v:*atype* *unspecified*
-                 (append (array-dimensions fxs)
-                         (array-dimensions x)))))
-         (for-indices-in-range
-          (lambda indices-in
-            (let ((dfxs ((v:extend /)
-                         ((v:extend -)
-                          (apply f (apply axis-add xs step indices-in))
-                          (apply f (apply axis-add xs (- step) indices-in)))
-                         (* 2 step))))
-              (for-indices-in-range
-               (lambda indices-out
-                 (apply
-                  array-set!
-                  a
-                  (apply array-ref dfxs indices-out)
-                  (append indices-out indices-in)))
-               (list-zeros (array-rank fxs))
-               (array-dimensions fxs))))
-          (list-zeros (array-rank x))
-          (array-dimensions x))
-         a))))))
-
-(test-begin "grad")
-
-;; extended operations
-(test-assert (with-generators (random-input) ~ (v:extend identity) v:identity))
-(test-assert (with-generators (random-input) ~ (v:extend exp) v:exp))
-(test-assert
-    (receive (gx gy) (random-shared)
-      (with-generators (gx gy) ~ (v:extend *) v:*)))
-
-(define* (apply-grad grad-func #:optional (axis 0))
-  (lambda (f . args)
-    (apply (grad-func f axis) args)))
-
-(test-assert
-    (with-generators
-     (random-func1 random-input)
-     ~ (apply-grad ngrad) (apply-grad v:grad)))
-
-;; `v:mean' only takes non-empty arrays so we treat it separately
-(test-assert
-    (with-generators
-     ((differentiable-func-generator (list v:mean) random-non-empty-array)
-      random-non-empty-array)
-     ~ (apply-grad ngrad) (apply-grad v:grad)))
-
-(test-assert
-    (receive (gx gy) (random-shared)
-      (with-generators
-       (random-func2 gx gy)
-       ~ (apply-grad ngrad 0) (apply-grad v:grad 0))))
-(test-assert
-    (receive (gx gy) (random-shared)
-      (with-generators
-       (random-func2 gx gy)
-       ~ (apply-grad ngrad 1) (apply-grad v:grad 1))))
-
-;; `v:amap2' only takes arrays of rank > 0 and batch-size > 0 so we treat it
-;; separately
-(define* (apply-grad-amap2 grad-func #:optional (axis 0))
-  (lambda (f x y)
-    ((grad-func
-      (lambda (x y)
-       (v:amap2 f x y))
-      axis)
-     x y)))
-(define random-func2-rank&dims>0
-  (receive (gx gy) (random-shared-array-rank&dims>0)
-    (differentiable-func-generator f2s gx gy)))
-(test-assert
-    (receive (gx gy) (random-shared-array-rank&dims>0)
-      (with-generators
-       (random-func2-rank&dims>0 gx gy)
-       ~ (apply-grad-amap2 ngrad 0) (apply-grad-amap2 v:grad 0))))
-(test-assert
-    (receive (gx gy) (random-shared-array-rank&dims>0)
-      (with-generators
-       (random-func2-rank&dims>0 gx gy)
-       ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1))))
-(test-assert
-    (~ ((ngrad v:amap2 1) v:* #(1 2 3) #(10 20 30))
-       ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30))))
-(test-assert
-    (let ((x #(10 20 30))
-         (y #(10 20 30)))
-      (~ ((ngrad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))
-        ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)))))
-
-;; `v:adot'
-(define (random-shared-contractible)
-  "Returns three generators: the first two generate arrays that are contractible
-according to the number generated by the third one."
-  (let* ((n (random 5))
-        (sa (random-array-shape n))
-        (sb (append (reverse (take (reverse sa)
-                                   n))
-                    (random-array-shape 0 (- 5 n)))))
-    (values
-     (lambda ()
-       (random-array sa))
-     (lambda ()
-       (random-array sb))
-     (lambda ()
-       (let ((tmp n))
-        (set! n (random 5))
-        (set! sa (random-array-shape n))
-        (set! sb (append (reverse (take (reverse sa)
-                                         n))
-                         (random-array-shape 0 (- 5 n))))
-        tmp)))))
-(test-assert
-    (receive (gx gy gz) (random-shared-contractible)
-      (with-generators
-       (gx gy gz)
-       ~
-       (lambda (a b n) ((ngrad v:adot 0) a b n))
-       (lambda (a b n) ((v:grad v:adot 0) a b n)))))
-(test-assert
-    (receive (gx gy gz) (random-shared-contractible)
-      (with-generators
-       (gx gy gz)
-       ~
-       (lambda (a b n) ((ngrad v:adot 1) a b n))
-       (lambda (a b n) ((v:grad v:adot 1) a b n)))))
-
-;; chain rule
-(test-assert
-    (with-generators
-     (random-func1 random-func1 random-input)
-     ~
-     (lambda (f g x)
-       (let* ((gx (g x))
-             (r
-              (v:dot ((v:grad f) gx)
-                     ((v:grad g) x)
-                     (v:rank-of gx))))
-        (ifn (and (number? (f (g x)))
-                  (number? x)
-                  (array? r))
-             r
-             (array-ref r))))
-     (lambda (f g x)
-       ((v:grad (compose f g)) x))))
-
-(test-end "grad")
diff --git a/grad.scm b/grad.scm
deleted file mode 100644 (file)
index 9e4ec76..0000000
--- a/grad.scm
+++ /dev/null
@@ -1,699 +0,0 @@
-(define-module (vouivre grad)
-  #:use-module (ice-9 receive)
-  #:use-module (srfi srfi-1)
-  #:use-module (srfi srfi-9)
-  #:use-module (vouivre misc)
-  #:use-module (vouivre promises)
-  #:export
-  (*atype*
-   adot
-   amap2
-   contract-arrays
-   differentiable-wrapper
-   dot
-   do-times
-   ewise1
-   ewise2
-   extend
-   grad
-   make-batch
-   make-internal
-   maximum
-   mean
-   rank-of
-   sum)
-  #:replace
-  ((i:sqrt . sqrt)
-   (i:exp . exp)
-   (i:expt . expt)
-   (i:log . log)
-   (i:sin . sin)
-   (i:cos . cos)
-   (i:tan . tan)
-   (i:+ . +)
-   (i:- . -)
-   (i:* . *)
-   (i:/ . /)
-   (i:max . max)
-   (i:min . min)
-   (i:abs . abs)
-   (i:identity . identity)
-   (i:array-ref . array-ref)
-   (i:array-cell-ref . array-cell-ref))
-  #:re-export
-  (fold
-   reduce))
-
-;;;; array utilities
-
-(define (rel->abs indices dimensions)
-  (let rec ((s 0)
-           (p 1)
-           (is (reverse indices))
-           (ds (reverse dimensions)))
-    (if (null? is)
-       s
-       (rec (+ s (* p (car is)))
-            (* p (car ds))
-            (cdr is)
-            (cdr ds)))))
-
-(define(do-times n proc)
-  (let rec ((i 0))
-    (unless (= i n)
-      (proc i)
-      (rec (1+ i)))))
-
-(define (contract-arrays a b n)
-  (let* ((dims-a (array-dimensions a))
-        (dims-b (array-dimensions b))
-        (free-dims-a (take dims-a (- (array-rank a) n)))
-        (free-dims-b (drop dims-b n))
-        (bound-dims (take dims-b n))
-        (n-free-dims-a (apply * free-dims-a))
-        (n-free-dims-b (apply * free-dims-b))
-        (n-bound-dims (apply * bound-dims))
-        (s 0)
-        (r (apply make-typed-array *atype* *unspecified* (append free-dims-a
-                                                                 free-dims-b)))
-        (ac (array-contents a))
-        (bc (array-contents b))
-        (rc (array-contents r)))
-    (do-times
-     n-free-dims-a
-     (lambda (i)
-       (let ((i-k (* n-bound-dims i))
-            (i-j (* n-free-dims-b i)))
-        (do-times
-         n-free-dims-b
-         (lambda (j)
-           (set! s 0)
-           (do-times
-            n-bound-dims
-            (lambda (k)
-              (set! s (+ s (* (array-ref ac (+ i-k k))
-                              (array-ref bc (+ (* n-free-dims-b k) j)))))))
-           (array-set! rc s (+ i-j j)))))))
-    r))
-
-;;;; utilities that work on both numbers and arrays
-
-(define (extend f)
-  "Extend a function of one or more scalars to apply to numbers/arrays
-element-wise. All arrays must have the same dimension."
-  (define (apply-elemwise f indices args)
-    (apply f (map (lambda (x)
-                   (if (number? x)
-                       x
-                       (apply array-ref x indices)))
-                 args)))
-  (lambda xs
-    (if-let (x (find array? xs))
-           (apply
-            produce-typed-array
-            (lambda is
-              (apply-elemwise f is xs))
-            *atype*
-            (array-dimensions x))
-           (apply f xs))))
-
-(define (dot x y n)
-  (cond
-   ((and (number? x) (number? y))
-    (* x y))
-   ((and (array? x) (array? y))
-    (contract-arrays x y n))
-   ((and (array? x) (number? y))
-    ((extend *) x y))
-   ((and (number? x) (array? y))
-    ((extend *) x y))
-   (else (error "can't dot because of invalid types or ranks" x y n))))
-
-(define (rank-of x)
-  (if (number? x)
-      0
-      (array-rank x)))
-
-;;;; differentiation
-
-(define-record-type internal
-  (make-internal forward jacobian)
-  internal?
-  (forward internal-forward)
-  (jacobian internal-jacobian))
-
-;;(define *atype* 'f32)
-(define *atype* #t)
-(define *grad* (make-parameter #f))
-(define *j* (make-parameter #f))
-
-(define-syntax-rule (w/j val body ...)
-  (parameterize ((*j* val))
-    body ...))
-
-(define (unwrap-fwd x)
-  (if (internal? x)
-      (internal-forward x)
-      x))
-
-(define (unwrap-jac x)
-  (if (internal? x)
-      (internal-jacobian x)
-      x))
-
-(define (dims-of x)
-  (if (number? x)
-      '()
-      (array-dimensions x)))
-
-(define (mov dst-buf n-dst-dims generator naked-inputs data j)
-  (do-times
-   n-dst-dims
-   (lambda (i)
-     (array-set!
-      dst-buf
-      (apply generator naked-inputs i j data)
-      i)))
-  dst-buf)
-
-(define (add dst-buf n-dst-dims generator naked-inputs data j)
-  (do-times
-   n-dst-dims
-   (lambda (i)
-     (array-set!
-      dst-buf
-      (+ (array-ref dst-buf i)
-        (apply generator naked-inputs i j data))
-      i)))
-  dst-buf)
-
-(define (movc dst-buf n-dst-dims src-buf n-src-dims
-             generator naked-inputs data)
-  "Contract the Jacobian column produced by the generator with the source buffer
-storing the result in the destination buffer."
-  (let ((s 0))
-    (do-times
-     n-dst-dims
-     (lambda (i)
-       (set! s 0)
-       (do-times
-       n-src-dims
-       (lambda (k)
-         (set! s (+ s (* (apply generator naked-inputs i k data)
-                         (array-ref src-buf k))))))
-       (array-set! dst-buf s i))))
-  dst-buf)
-
-(define (addc dst-buf n-dst-dims src-buf n-src-dims
-             generator naked-inputs data)
-  "Contract the Jacobian column produced by the generator with the source buffer
-adding the result to the destination buffer."
-  (let ((s 0))
-    (do-times
-     n-dst-dims
-     (lambda (i)
-       (set! s (array-ref dst-buf i))
-       (do-times
-       n-src-dims
-       (lambda (k)
-         (set! s (+ s (* (apply generator naked-inputs i k data)
-                         (array-ref src-buf k))))))
-       (array-set! dst-buf s i))))
-  dst-buf)
-
-(define* (grad f #:optional (axis 0))
-  (define (wrap x i)
-    (if (= i axis)
-       (make-internal x 'input)
-       x))
-  (lambda xs
-    (parameterize (((@@ (vouivre grad) *grad*) #t)
-                  ((@@ (vouivre grad) *promises*) (cons '() #f)))
-      (let* ((internal (apply f (map-indexed wrap xs)))
-            (fx (internal-forward internal))
-            (y (list-ref xs axis))  ; variable to differentiate w.r.t
-            (pre-Jx (internal-jacobian internal))
-            (Jx (cond
-                 ;; TODO: implement 'input case and test 'zero and 'input
-                 ((eq? pre-Jx 'zero)
-                  (lambda (j)
-                    (lambda (i)
-                      0)))
-                 ((eq? pre-Jx 'input)
-                  (error "TBD."))
-                 (else
-                  (lambda (j)
-                    (reset-promises (car (*promises*)))
-                    (let ((column-jac (w/j j (force pre-Jx))))
-                      (lambda (i)
-                        (array-ref column-jac i))))))))
-       (cond
-        ((and (number? fx) (number? y))
-         ((Jx 0) 0))
-        ((and (number? fx) (array? y))
-         (let* ((y-dims (array-dimensions y))
-                (a (apply make-array *unspecified* y-dims))
-                (ac (array-contents a)))
-           (do-times
-            (apply * y-dims)
-            (lambda (j)
-              (array-set! ac ((Jx j) 0)
-                          j)))
-           a))
-        ((and (array? fx) (number? y))
-         (let* ((fx-dims (array-dimensions fx))
-                (a (apply make-array *unspecified* fx-dims))
-                (ac (array-contents a))
-                (Jx (Jx 0)))
-           (do-times
-            (apply * fx-dims)
-            (lambda (i)
-              (array-set! ac (Jx i)
-                          i)))
-           a))
-        (else
-         (let* ((fx-dims (array-dimensions fx))
-                (y-dims (array-dimensions y))
-                (n-fx-dims (apply * fx-dims))
-                (n-y-dims (apply * y-dims))
-                (a (apply make-array *unspecified* (append fx-dims y-dims)))
-                (ac (array-contents a)))
-           (do-times
-            n-y-dims
-            (lambda (j)
-              (let ((Jx (Jx j)))
-                (do-times
-                 n-fx-dims
-                 (lambda (i)
-                   (array-set! ac (Jx i)
-                               (+ j (* n-y-dims i))))))))
-           a)))))))
-
-
-;; In the comment that follows:
-;;
-;; `n' is the number of arguments to `proc'.
-;; `generators is not a `Vec' but a `List' we only use the former to illustrate
-;; its length.
-;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
-;; `I' is the type of multi-indices indexing the output of `function'.
-;; `J' is the type of multi-indices indexing the input array being differentiated.
-;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J').
-;; `Array I' is the type of arrays indexed by multi-indices of `I'.
-;; `[X]' means that `X' is boxed in an internal as when returned by
-;;       `differentiable-wrapper' with the array being `X' and the promise that
-;;       given a |J| we will get the change of `X' with a change of the
-;;       the differentiated argument at multi-index `J'.
-;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number))
-;;       (→ X1 ... Xn (Array I))*
-;;       [X1] ... [Xn]
-;;       (Internal (Array I) (Promise |J| (Array |I|)))))
-;;
-;; (*) We extend this definition to allow `proc' to be a list of procedures
-;;     the head of which is as described above and the remaining elements
-;;     are procedures of the same arguments but returning values that are
-;;     then fed as extra data to the generators.
-;;
-;; NOTE: In cases where an argument isn't meant to be differentiable its
-;;       corresponding generator should be `#f'.
-(define (differentiable-wrapper generators proc* arg . more)
-  (let* ((args (cons arg more))
-        (proc (if (procedure? proc*)
-                  proc*
-                  (car proc*)))
-        (naked-args (map unwrap-fwd args))
-        (out (apply proc naked-args)))
-    (ifn (*grad*)
-        out
-        (let* ((data (if (procedure? proc*)
-                         '()
-                         (map (lambda (g)
-                                (apply g naked-args))
-                              (cdr proc*))))
-               (n-out-dims (apply * (dims-of out)))
-               (buf (make-array *unspecified* n-out-dims)))
-          (make-internal
-           out
-           (fold
-            (lambda (generator arg prev)
-              (if (or (not (internal? arg))
-                      (eq? 'zero (internal-jacobian arg)))
-                  prev
-                  (let ((Jx (internal-jacobian arg))
-                        (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
-                    (if (eq? Jx 'input)
-                        (if (eq? prev 'zero)
-                            (delay
-                              (mov buf n-out-dims
-                                   generator naked-args data (*j*)))
-                            (delay
-                              (add (force prev) n-out-dims
-                                   generator naked-args data (*j*))))
-                        (if (eq? prev 'zero)
-                            (delay
-                              (movc buf n-out-dims (force Jx) n-fwd-dims
-                                    generator naked-args data))
-                            (delay
-                              (addc (force prev) n-out-dims
-                                    (force Jx) n-fwd-dims
-                                    generator naked-args data)))))))
-            'zero generators args))))))
-
-(define (ewise1 f)
-  (lambda (xs i j)
-    (let ((x (car xs)))
-      (if (number? x)
-         (f x)
-         (ifn (= i j)
-              0
-              (f (array-ref (array-contents x)
-                            j)))))))
-
-(define (ewise2 proc axis)
-  (lambda (xs i j)
-    (let ((x (car xs))
-         (y (cadr xs)))
-      (cond
-       ((and (number? x) (number? y))
-       (proc x y))
-       ((and (number? x) (array? y))
-       (if (= axis 0)
-           (proc x (array-ref (array-contents y)
-                              i))
-           (ifn (= i j)
-                0
-                (proc x (array-ref (array-contents y)
-                                   j)))))
-       ((and (array? x) (number? y))
-       (if (= axis 1)
-           (proc (array-ref (array-contents x)
-                            i)
-                 y)
-           (ifn (= i j)
-                0
-                (proc (array-ref (array-contents x)
-                                 j)
-                      y))))
-       (else
-       (ifn (= i j)
-            0
-            (proc (array-ref (array-contents x)
-                             j)
-                  (array-ref (array-contents y)
-                             j))))))))
-
-(define (i:identity x)
-  "Differentiable identity."
-  (differentiable-wrapper
-   (list (ewise1 (lambda _ 1)))
-   identity
-   x))
-
-(define (i:sqrt x)
-  "Differentiable square root."
-  (differentiable-wrapper
-   (list (ewise1 (lambda (x) (/ 1 2 (sqrt x)))))
-   (extend sqrt)
-   x))
-
-(define (i:exp x)
-  "Differentiable exponential."
-  (differentiable-wrapper
-   (list (ewise1 exp))
-   (extend exp)
-   x))
-
-(define (i:expt x y)
-  "Differentiable power."
-  (differentiable-wrapper
-   (list (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0)
-        (ewise2 (lambda (x y) (* (expt x y) (log x))) 1))
-   (extend expt)
-   x y))
-
-(define (i:log x)
-  "Differentiable logarithm."
-  (differentiable-wrapper
-   (list (ewise1 (lambda (x) (/ x))))
-   (extend log)
-   x))
-
-(define (i:sin x)
-  "Differentiable sine."
-  (differentiable-wrapper
-   (list (ewise1 cos))
-   (extend sin)
-   x))
-
-(define (i:cos x)
-  "Differentiable cosine."
-  (differentiable-wrapper
-   (list (ewise1 (lambda (x) (- (sin x)))))
-   (extend cos)
-   x))
-
-(define (i:tan x)
-  "Differentiable tangent."
-  (differentiable-wrapper
-   (list (ewise1 (lambda (x) (/ (expt (cos x) 2)))))
-   (extend tan)
-   x))
-
-(define (i:+ x y)
-  "Differentiable element-wise addition."
-  (differentiable-wrapper
-   (list
-    (ewise2 (lambda _ +1) 0)
-    (ewise2 (lambda _ +1) 1))
-   (extend +)
-   x y))
-
-(define (i:- x y)
-  "Differentiable element-wise subtraction."
-  (differentiable-wrapper
-   (list
-    (ewise2 (lambda _ +1) 0)
-    (ewise2 (lambda _ -1) 1))
-   (extend -)
-   x y))
-
-(define (i:* x y)
-  "Differentiable element-wise multiplication."
-  (differentiable-wrapper
-   (list
-    (ewise2 (lambda (x y) y) 0)
-    (ewise2 (lambda (x y) x) 1))
-   (extend *)
-   x y))
-
-(define (i:/ x y)
-  "Differentiable element-wise division."
-  (differentiable-wrapper
-   (list
-    (ewise2 (lambda (x y) (/ y)) 0)
-    (ewise2 (lambda (x y) (- (/ x y y))) 1))
-   (extend /)
-   x y))
-
-(define (i:max x y)
-  "Differentiable element-wise maximum."
-  (define (dmax x y)
-    (cond
-     ((> x y)
-      1)
-     ((= x y)
-      1/2)
-     (else
-      0)))
-  (differentiable-wrapper
-   (list
-    (ewise2 dmax 0)
-    (ewise2 (flip dmax) 1))
-   (extend max)
-   x y))
-
-(define (i:min x y)
-  "Differentiable element-wise minimum."
-  (define (dmin x y)
-    (cond
-     ((< x y)
-      1)
-     ((= x y)
-      1/2)
-     (else
-      0)))
-  (differentiable-wrapper
-   (list
-    (ewise2 dmin 0)
-    (ewise2 (flip dmin) 1))
-   (extend min)
-   x y))
-
-(define (i:abs x)
-  "Differentiable absolute."
-  (differentiable-wrapper
-   (list (ewise1 (lambda (x)
-                  (cond ((> x 0)
-                         +1)
-                        ((= x 0)
-                         1/2)
-                        ((< x 0)
-                         -1)))))
-   (extend abs)
-   x))
-
-(define (mean x)
-  "Differentiable mean on arrays."
-  (differentiable-wrapper
-   (list
-    (lambda (xs i j one-over-n)
-      one-over-n))
-   (let ((n 0))
-     (list
-      (lambda (x)
-       (let ((sum 0))
-         (array-for-each
-          (lambda (x)
-            (set! sum (+ sum x))
-            (set! n (1+ n)))
-          x)
-         (/ sum n)))
-      (lambda _ (/ n))))
-   x))
-
-(define (i:array-ref x . indices)
-  "Differentiable array-ref w.r.t `x'."
-  (apply
-   differentiable-wrapper
-   (cons
-    (lambda (xs i j abs-index)
-      (if (= j abs-index)
-         1
-         0))
-    (map not indices))
-   (list
-    array-ref
-    (lambda (x . indices)
-      (rel->abs indices (array-dimensions x))))
-   x indices))
-
-(define (i:array-cell-ref x . indices)
-  (apply
-   differentiable-wrapper
-   (cons
-    (lambda (xs i j abs-index n-rst-dims)
-      (receive (j-ref j-rst) (euclidean/ j n-rst-dims)
-       (if (and (= j-ref abs-index)
-                (= j-rst i))
-           1
-           0)))
-    (map not indices))
-   (list
-    array-cell-ref
-    (lambda (x . indices)
-      (rel->abs indices (take (array-dimensions x)
-                             (length indices))))
-    (lambda (x . indices)
-      (apply * (drop (array-dimensions x)
-                    (length indices)))))
-   x indices))
-
-(define (make-batch elem . more)
-  (let ((batch-size (1+ (length more))))
-    (apply
-     differentiable-wrapper
-     (list-tabulate
-      batch-size
-      (lambda (b)
-       (lambda (xs i j n-rest-dims)
-         (receive (i-batch i-rest) (euclidean/ i n-rest-dims)
-           (if (and (= i-batch b)
-                    (= i-rest j))
-               1
-               0
-               )))))
-     (list
-      (lambda (elem . more)
-       (let ((a (apply make-typed-array *atype* *unspecified* batch-size
-                       (dims-of elem))))
-         (for-each
-          (lambda (x b)
-            (array-cell-set! a x b))
-          (cons elem more)
-          (list-tabulate batch-size identity))
-         a))
-      (lambda (elem . more)
-       (apply * (dims-of elem))))
-     elem more)))
-
-(define (maximum x)
-  "Differentiable maximum on arrays."
-  (differentiable-wrapper
-   (list
-    (lambda (xs i j max-index)
-      (if (= j max-index)
-         1
-         0)))
-   (let ((max-index 'TBD))
-     (list
-      (lambda (x)
-       (let ((m (- (inf)))
-             (i 0))
-         (array-for-each
-          (lambda (x)
-            (when (< m x)
-              (set! m x)
-              (set! max-index i))
-            (set! i (1+ i)))
-          x)
-         m))
-      (lambda _ max-index)))
-   x))
-
-(define (sum x)
-  "Differentiable sum on arrays."
-  (differentiable-wrapper
-   (list (lambda _ 1))
-   (lambda (x)
-     (let ((sum 0))
-       (array-for-each
-       (lambda (x)
-         (set! sum (+ sum x)))
-       x)
-       sum))
-   x))
-
-(define (adot x y n)
-  (differentiable-wrapper
-   (list
-    (lambda (xs i j n-free-dims-y n-bound-dims)
-      (receive (i-x i-y) (euclidean/ i n-free-dims-y)
-       (receive (j-free j-bound) (euclidean/ j n-bound-dims)
-         (ifn (= i-x j-free)
-              0
-              (array-ref (array-contents (cadr xs))
-                         (+ i-y (* n-free-dims-y j-bound)))))))
-    (lambda (xs i j n-free-dims-y n-bound-dims)
-      (receive (i-x i-y) (euclidean/ i n-free-dims-y)
-       (receive (j-bound j-free) (euclidean/ j n-free-dims-y)
-         (ifn (= i-y j-free)
-              0
-              (array-ref (array-contents (car xs))
-                         (+ j-bound (* n-bound-dims i-x)))))))
-    #f)
-   (list
-    contract-arrays
-    (lambda (x y n)
-      (apply * (drop (array-dimensions y)
-                    n)))
-    (lambda (x y n)
-      (apply * (take (array-dimensions y)
-                    n))))
-   x y n))
-
-(define (amap2 f x y)
-  (apply make-batch
-        (list-tabulate (car (dims-of (unwrap-fwd x)))
-                       (lambda (b)
-                         (f (i:array-cell-ref x b)
-                            (i:array-cell-ref y b))))))