]> git.vouivredigital.com Git - vouivre.git/commitdiff
Optimize garbage collection
authoradmin <admin@vouivredigital.com>
Mon, 20 Nov 2023 07:56:37 +0000 (16:56 +0900)
committeradmin <admin@vouivredigital.com>
Mon, 20 Nov 2023 07:56:37 +0000 (16:56 +0900)
Instead of allocating memory on every index of every
differentiable function call we do it once per call
and use the buffer for all indices.

grad.scm
misc.scm

index f0aa3c24afc4ce1d4780140b801e074f3d9802e5..9e4ec764797d1dfa167956097cee48a921135161 100644 (file)
--- a/grad.scm
+++ b/grad.scm
@@ -3,6 +3,7 @@
   #:use-module (srfi srfi-1)
   #:use-module (srfi srfi-9)
   #:use-module (vouivre misc)
+  #:use-module (vouivre promises)
   #:export
   (*atype*
    adot
    contract-arrays
    differentiable-wrapper
    dot
+   do-times
+   ewise1
+   ewise2
    extend
    grad
-   internal-jacobian
+   make-batch
+   make-internal
    maximum
    mean
    rank-of
-   sum
-   )
+   sum)
   #:replace
   ((i:sqrt . sqrt)
    (i:exp . exp)
    (i:abs . abs)
    (i:identity . identity)
    (i:array-ref . array-ref)
-   )
+   (i:array-cell-ref . array-cell-ref))
   #:re-export
   (fold
    reduce))
 
 ;;;; array utilities
 
-(define (abs->rel index dimensions)
-  (let rec ((ds (cdr dimensions))
-           (p (apply * (cdr dimensions)))
-           (r index)
-           (is '()))
-    (let ((i (quotient r p)))
-      (if (null? ds)
-         (reverse (cons i is))
-         (rec (cdr ds)
-              (quotient p (car ds))
-              (- r (* p i))
-              (cons i is))))))
-
-(define (contracted-dims a b n)
-  (let ((dims-a (array-dimensions a))
-        (dims-b (array-dimensions b)))
-    (if (or (> n (array-rank a))
-            (> n (array-rank b)))
-      (error "can't contract arrays with size lower than" n)
-      (append (take dims-a (- (array-rank a)
-                             n))
-             (drop dims-b n)))))
+(define (rel->abs indices dimensions)
+  (let rec ((s 0)
+           (p 1)
+           (is (reverse indices))
+           (ds (reverse dimensions)))
+    (if (null? is)
+       s
+       (rec (+ s (* p (car is)))
+            (* p (car ds))
+            (cdr is)
+            (cdr ds)))))
 
 (define(do-times n proc)
   (let rec ((i 0))
@@ -152,176 +145,232 @@ element-wise. All arrays must have the same dimension."
 ;;(define *atype* 'f32)
 (define *atype* #t)
 (define *grad* (make-parameter #f))
+(define *j* (make-parameter #f))
+
+(define-syntax-rule (w/j val body ...)
+  (parameterize ((*j* val))
+    body ...))
+
+(define (unwrap-fwd x)
+  (if (internal? x)
+      (internal-forward x)
+      x))
+
+(define (unwrap-jac x)
+  (if (internal? x)
+      (internal-jacobian x)
+      x))
+
+(define (dims-of x)
+  (if (number? x)
+      '()
+      (array-dimensions x)))
+
+(define (mov dst-buf n-dst-dims generator naked-inputs data j)
+  (do-times
+   n-dst-dims
+   (lambda (i)
+     (array-set!
+      dst-buf
+      (apply generator naked-inputs i j data)
+      i)))
+  dst-buf)
+
+(define (add dst-buf n-dst-dims generator naked-inputs data j)
+  (do-times
+   n-dst-dims
+   (lambda (i)
+     (array-set!
+      dst-buf
+      (+ (array-ref dst-buf i)
+        (apply generator naked-inputs i j data))
+      i)))
+  dst-buf)
+
+(define (movc dst-buf n-dst-dims src-buf n-src-dims
+             generator naked-inputs data)
+  "Contract the Jacobian column produced by the generator with the source buffer
+storing the result in the destination buffer."
+  (let ((s 0))
+    (do-times
+     n-dst-dims
+     (lambda (i)
+       (set! s 0)
+       (do-times
+       n-src-dims
+       (lambda (k)
+         (set! s (+ s (* (apply generator naked-inputs i k data)
+                         (array-ref src-buf k))))))
+       (array-set! dst-buf s i))))
+  dst-buf)
+
+(define (addc dst-buf n-dst-dims src-buf n-src-dims
+             generator naked-inputs data)
+  "Contract the Jacobian column produced by the generator with the source buffer
+adding the result to the destination buffer."
+  (let ((s 0))
+    (do-times
+     n-dst-dims
+     (lambda (i)
+       (set! s (array-ref dst-buf i))
+       (do-times
+       n-src-dims
+       (lambda (k)
+         (set! s (+ s (* (apply generator naked-inputs i k data)
+                         (array-ref src-buf k))))))
+       (array-set! dst-buf s i))))
+  dst-buf)
 
 (define* (grad f #:optional (axis 0))
   (define (wrap x i)
     (if (= i axis)
-       (make-internal x #f)
+       (make-internal x 'input)
        x))
   (lambda xs
-    (let* ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity)))
-          (y (internal-forward (list-ref wrapped-xs axis))))
-      ((if (*grad*)
-          identity
-          (lambda (boxed)
-            (let ((fx (internal-forward boxed))
-                  (gfx (internal-jacobian boxed)))
-              (cond
-               ((and (number? fx) (number? y))
-                (gfx '()))
-               ((and (number? fx) (array? y))
-                (apply produce-typed-array
-                       (lambda js (gfx js))
-                       *atype*
-                       (array-dimensions y)))
-               ((and (array? fx) (number? y))
-                (gfx '()))
-               (else
-                (let* ((dims-out (array-dimensions fx))
-                       (dims-in (array-dimensions y))
-                       (a (apply make-typed-array *atype* *unspecified* (append dims-out dims-in))))
-                  (for-indices-in-range
-                   (lambda js
-                     (let ((out (gfx js)))
-                       (for-indices-in-range
-                        (lambda is
-                          (apply
-                           array-set!
-                           a
-                           (apply array-ref out is)
-                           (append is js)))
-                        (list-zeros (array-rank fx))
-                        dims-out)))
-                   (list-zeros (array-rank y))
-                   dims-in)
-                  a))))))
-       (parameterize (((@@ (vouivre grad) *grad*) y))
-        (apply f wrapped-xs))))))
-
-(define (unbox-fwd x)
-  (if (internal? x)
-      (internal-forward x)
-      x))
-
-;; `n' is the number of arguments to `function'.
-;; `jacobian-generators is not a `Vec' but a `List' we only use the former to
-;; show its length.
+    (parameterize (((@@ (vouivre grad) *grad*) #t)
+                  ((@@ (vouivre grad) *promises*) (cons '() #f)))
+      (let* ((internal (apply f (map-indexed wrap xs)))
+            (fx (internal-forward internal))
+            (y (list-ref xs axis))  ; variable to differentiate w.r.t
+            (pre-Jx (internal-jacobian internal))
+            (Jx (cond
+                 ;; TODO: implement 'input case and test 'zero and 'input
+                 ((eq? pre-Jx 'zero)
+                  (lambda (j)
+                    (lambda (i)
+                      0)))
+                 ((eq? pre-Jx 'input)
+                  (error "TBD."))
+                 (else
+                  (lambda (j)
+                    (reset-promises (car (*promises*)))
+                    (let ((column-jac (w/j j (force pre-Jx))))
+                      (lambda (i)
+                        (array-ref column-jac i))))))))
+       (cond
+        ((and (number? fx) (number? y))
+         ((Jx 0) 0))
+        ((and (number? fx) (array? y))
+         (let* ((y-dims (array-dimensions y))
+                (a (apply make-array *unspecified* y-dims))
+                (ac (array-contents a)))
+           (do-times
+            (apply * y-dims)
+            (lambda (j)
+              (array-set! ac ((Jx j) 0)
+                          j)))
+           a))
+        ((and (array? fx) (number? y))
+         (let* ((fx-dims (array-dimensions fx))
+                (a (apply make-array *unspecified* fx-dims))
+                (ac (array-contents a))
+                (Jx (Jx 0)))
+           (do-times
+            (apply * fx-dims)
+            (lambda (i)
+              (array-set! ac (Jx i)
+                          i)))
+           a))
+        (else
+         (let* ((fx-dims (array-dimensions fx))
+                (y-dims (array-dimensions y))
+                (n-fx-dims (apply * fx-dims))
+                (n-y-dims (apply * y-dims))
+                (a (apply make-array *unspecified* (append fx-dims y-dims)))
+                (ac (array-contents a)))
+           (do-times
+            n-y-dims
+            (lambda (j)
+              (let ((Jx (Jx j)))
+                (do-times
+                 n-fx-dims
+                 (lambda (i)
+                   (array-set! ac (Jx i)
+                               (+ j (* n-y-dims i))))))))
+           a)))))))
+
+
+;; In the comment that follows:
+;;
+;; `n' is the number of arguments to `proc'.
+;; `generators is not a `Vec' but a `List' we only use the former to illustrate
+;; its length.
 ;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
 ;; `I' is the type of multi-indices indexing the output of `function'.
 ;; `J' is the type of multi-indices indexing the input array being differentiated.
+;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J').
 ;; `Array I' is the type of arrays indexed by multi-indices of `I'.
 ;; `[X]' means that `X' is boxed in an internal as when returned by
-;;       `differentiable-wrapper' with the array being `X' and the function
-;;       taking a multi-index of `J' to return an array of the same shape as `X'.
-;; (∷ (→ (Vec n (→ X1 ... Xn I J Number))
-;;       (→ X1 ... Xn (Array I))
+;;       `differentiable-wrapper' with the array being `X' and the promise that
+;;       given a |J| we will get the change of `X' with a change of the
+;;       the differentiated argument at multi-index `J'.
+;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number))
+;;       (→ X1 ... Xn (Array I))*
 ;;       [X1] ... [Xn]
-;;       (Internal (Array I) (→ J (Array I)))))
-(define (differentiable-wrapper jacobian-generators function input . more)
-  ;; NOTE: Both the jacobian generators and the function act on naked inputs
-  ;;       (numbers or arrays not inside an internal object). The generators
-  ;;       additionaly take indices -- one for each dimension of what we are
-  ;;       differentiating with respect to (the content of `(*grad*)'). A
-  ;;       generator returns a jacobian column expressing the change of the
-  ;;       function's output when changing `(*grad*)' along the given indices.
-  ;;       In cases where an argument isn't meant to be differentiable its
-  ;;       corresponding generator should be `#f'.
-  (let* ((inputs (cons input more))
-        (naked-inputs (map unbox-fwd inputs))
-        (output (apply (if (procedure? function)
-                           function
-                           (car function))
-                       naked-inputs))
-        (data (if (procedure? function)
-                  '()
-                  (map (lambda (f) (apply f naked-inputs))
-                       (cdr function)))))
+;;       (Internal (Array I) (Promise |J| (Array |I|)))))
+;;
+;; (*) We extend this definition to allow `proc' to be a list of procedures
+;;     the head of which is as described above and the remaining elements
+;;     are procedures of the same arguments but returning values that are
+;;     then fed as extra data to the generators.
+;;
+;; NOTE: In cases where an argument isn't meant to be differentiable its
+;;       corresponding generator should be `#f'.
+(define (differentiable-wrapper generators proc* arg . more)
+  (let* ((args (cons arg more))
+        (proc (if (procedure? proc*)
+                  proc*
+                  (car proc*)))
+        (naked-args (map unwrap-fwd args))
+        (out (apply proc naked-args)))
     (ifn (*grad*)
-        output
-        (make-internal
-         output
-         (lambda (js)
-           (or
-            (fold
-             (lambda (jacobian-generator input prev)
-               (ifn (internal? input)
-                    prev
-                    ((ifn prev identity (lambda (x) ((extend +) prev x)))
-                     (ifn (internal-jacobian input)
-                          (if (number? output)
-                              (apply jacobian-generator naked-inputs js data)
-                              (apply
-                               produce-typed-array
-                               (lambda is
-                                 (apply jacobian-generator
-                                        naked-inputs
-                                        (append is js)
-                                        data))
-                               *atype*
-                               (array-dimensions output)))
-                          (let ((b ((internal-jacobian input) js))
-                                (fwd (internal-forward input)))
-                            (cond
-                             ((and (number? output) (number? fwd))
-                              (* (apply jacobian-generator
-                                        naked-inputs '() data)
-                                 b))
-                             ((and (number? output) (array? fwd))
-                              (array-ref
-                               (contract-arrays
-                                (apply
-                                 produce-typed-array
-                                 (lambda ks
-                                   (apply jacobian-generator
-                                          naked-inputs ks data))
-                                 *atype*
-                                 (array-dimensions fwd))
-                                b (array-rank fwd))))
-                             ((and (array? output) (number? fwd))
-                              (apply
-                               produce-typed-array
-                               (lambda is
-                                 ((extend *)
-                                  (apply jacobian-generator
-                                         naked-inputs is data)
-                                  b))
-                               *atype*
-                               (array-dimensions output)))
-                             (else
-                              (apply
-                               produce-typed-array
-                               (lambda is
-                                 (array-ref
-                                  (contract-arrays
-                                   (apply
-                                    produce-typed-array
-                                    (lambda ks
-                                      (apply jacobian-generator
-                                             naked-inputs
-                                             (append is ks)
-                                             data))
-                                    *atype*
-                                    (array-dimensions fwd))
-                                   b (array-rank fwd))))
-                               *atype*
-                               (array-dimensions output)))))))))
-             #f jacobian-generators inputs)
-            (if (number? output)
-                0
-                (apply make-typed-array *atype* 0 (array-dimensions output)))))))))
+        out
+        (let* ((data (if (procedure? proc*)
+                         '()
+                         (map (lambda (g)
+                                (apply g naked-args))
+                              (cdr proc*))))
+               (n-out-dims (apply * (dims-of out)))
+               (buf (make-array *unspecified* n-out-dims)))
+          (make-internal
+           out
+           (fold
+            (lambda (generator arg prev)
+              (if (or (not (internal? arg))
+                      (eq? 'zero (internal-jacobian arg)))
+                  prev
+                  (let ((Jx (internal-jacobian arg))
+                        (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
+                    (if (eq? Jx 'input)
+                        (if (eq? prev 'zero)
+                            (delay
+                              (mov buf n-out-dims
+                                   generator naked-args data (*j*)))
+                            (delay
+                              (add (force prev) n-out-dims
+                                   generator naked-args data (*j*))))
+                        (if (eq? prev 'zero)
+                            (delay
+                              (movc buf n-out-dims (force Jx) n-fwd-dims
+                                    generator naked-args data))
+                            (delay
+                              (addc (force prev) n-out-dims
+                                    (force Jx) n-fwd-dims
+                                    generator naked-args data)))))))
+            'zero generators args))))))
 
 (define (ewise1 f)
-  (lambda (xs indices)
+  (lambda (xs i j)
     (let ((x (car xs)))
       (if (number? x)
          (f x)
-         (receive (is js) (split-at indices (array-rank x))
-           (if (equal? is js)
-               (f (apply array-ref x is))
-               0))))))
+         (ifn (= i j)
+              0
+              (f (array-ref (array-contents x)
+                            j)))))))
 
 (define (ewise2 proc axis)
-  (lambda (xs indices)
+  (lambda (xs i j)
     (let ((x (car xs))
          (y (cadr xs)))
       (cond
@@ -329,25 +378,29 @@ element-wise. All arrays must have the same dimension."
        (proc x y))
        ((and (number? x) (array? y))
        (if (= axis 0)
-           (proc x (apply array-ref y indices))
-           (receive (is js) (split-at indices (array-rank y))
-             (ifn (equal? is js)
-                  0
-                  (proc x (apply array-ref y is))))))
+           (proc x (array-ref (array-contents y)
+                              i))
+           (ifn (= i j)
+                0
+                (proc x (array-ref (array-contents y)
+                                   j)))))
        ((and (array? x) (number? y))
        (if (= axis 1)
-           (proc (apply array-ref x indices) y)
-           (receive (is js) (split-at indices (array-rank x))
-             (ifn (equal? is js)
-                  0
-                  (proc (apply array-ref x is)
-                        y)))))
+           (proc (array-ref (array-contents x)
+                            i)
+                 y)
+           (ifn (= i j)
+                0
+                (proc (array-ref (array-contents x)
+                                 j)
+                      y))))
        (else
-       (receive (is js) (split-at indices (array-rank x))
-         (ifn (equal? is js)
-              0
-              (proc (apply array-ref x is)
-                    (apply array-ref y is)))))))))
+       (ifn (= i j)
+            0
+            (proc (array-ref (array-contents x)
+                             j)
+                  (array-ref (array-contents y)
+                             j))))))))
 
 (define (i:identity x)
   "Differentiable identity."
@@ -493,19 +546,19 @@ element-wise. All arrays must have the same dimension."
   "Differentiable mean on arrays."
   (differentiable-wrapper
    (list
-    (lambda (xs indices one-over-n)
+    (lambda (xs i j one-over-n)
       one-over-n))
    (let ((n 0))
-    (list
-     (lambda (x)
-       (let ((sum 0))
-        (array-for-each
-         (lambda (x)
-           (set! sum (+ sum x))
-           (set! n (1+ n)))
-         x)
-        (/ sum n)))
-     (lambda _ (/ n))))
+     (list
+      (lambda (x)
+       (let ((sum 0))
+         (array-for-each
+          (lambda (x)
+            (set! sum (+ sum x))
+            (set! n (1+ n)))
+          x)
+         (/ sum n)))
+      (lambda _ (/ n))))
    x))
 
 (define (i:array-ref x . indices)
@@ -513,20 +566,72 @@ element-wise. All arrays must have the same dimension."
   (apply
    differentiable-wrapper
    (cons
-    (lambda (xs js)
-      (if (equal? js indices)
+    (lambda (xs i j abs-index)
+      (if (= j abs-index)
          1
          0))
-    (list-tabulate (length indices) not))
-   array-ref
+    (map not indices))
+   (list
+    array-ref
+    (lambda (x . indices)
+      (rel->abs indices (array-dimensions x))))
+   x indices))
+
+(define (i:array-cell-ref x . indices)
+  (apply
+   differentiable-wrapper
+   (cons
+    (lambda (xs i j abs-index n-rst-dims)
+      (receive (j-ref j-rst) (euclidean/ j n-rst-dims)
+       (if (and (= j-ref abs-index)
+                (= j-rst i))
+           1
+           0)))
+    (map not indices))
+   (list
+    array-cell-ref
+    (lambda (x . indices)
+      (rel->abs indices (take (array-dimensions x)
+                             (length indices))))
+    (lambda (x . indices)
+      (apply * (drop (array-dimensions x)
+                    (length indices)))))
    x indices))
 
+(define (make-batch elem . more)
+  (let ((batch-size (1+ (length more))))
+    (apply
+     differentiable-wrapper
+     (list-tabulate
+      batch-size
+      (lambda (b)
+       (lambda (xs i j n-rest-dims)
+         (receive (i-batch i-rest) (euclidean/ i n-rest-dims)
+           (if (and (= i-batch b)
+                    (= i-rest j))
+               1
+               0
+               )))))
+     (list
+      (lambda (elem . more)
+       (let ((a (apply make-typed-array *atype* *unspecified* batch-size
+                       (dims-of elem))))
+         (for-each
+          (lambda (x b)
+            (array-cell-set! a x b))
+          (cons elem more)
+          (list-tabulate batch-size identity))
+         a))
+      (lambda (elem . more)
+       (apply * (dims-of elem))))
+     elem more)))
+
 (define (maximum x)
   "Differentiable maximum on arrays."
   (differentiable-wrapper
    (list
-    (lambda (xs indices max-indices)
-      (if (equal? indices max-indices)
+    (lambda (xs i j max-index)
+      (if (= j max-index)
          1
          0)))
    (let ((max-index 'TBD))
@@ -542,8 +647,7 @@ element-wise. All arrays must have the same dimension."
             (set! i (1+ i)))
           x)
          m))
-      (lambda (x)
-       (abs->rel max-index (array-dimensions x)))))
+      (lambda _ max-index)))
    x))
 
 (define (sum x)
@@ -559,115 +663,37 @@ element-wise. All arrays must have the same dimension."
        sum))
    x))
 
-(define (amap2 f x y)
-  (define (unbox-with proc x)
-    (ifn (internal? x)
-        x
-        (proc x)))
-  (define (dims-of x)
-    (if (number? x)
-       '()
-       (array-dimensions x)))
-  (define (boxed-ref x i)
-    (let ((xi (array-cell-ref (unbox-with internal-forward x)
-                             i)))
-      (ifn (internal? x)
-          xi
-          (make-internal
-           xi
-           (if (internal-jacobian x)
-               (lambda (js)
-                 (array-cell-ref
-                  ((internal-jacobian x)
-                   js)
-                  i))
-               (lambda (js)
-                 (if (number? xi)
-                     (if (= i (car js))
-                         1
-                         0)
-                     (apply
-                      produce-typed-array
-                      (lambda indices
-                        (if (and (= i (car js))
-                                 (equal? indices (cdr js)))
-                            1
-                            0))
-                      (array-type xi)
-                      (array-dimensions xi)))))))))
-  (define (boxed-fi f i x y)
-    (f (boxed-ref x i)
-       (boxed-ref y i)))
-  (if (internal? f)
-      (error "unsupported application of `amap2' to internal object" f)
-      (let ((bs (first (array-dimensions (unbox-with internal-forward x))))
-           (f0 (boxed-fi f 0 x y)))
-       (ifn (internal? f0)
-            (let ((fwd (apply make-typed-array
-                              ;; TODO: use the correct type based on f0.
-                              *atype* *unspecified* bs (dims-of f0))))
-              (for-indices-in-range
-               (lambda (i)
-                 (array-cell-set! fwd (boxed-fi f i x y) i))
-               '(0) (list bs))
-              fwd)
-            (let ((jac (make-array *unspecified* bs))
-                  (fwd (apply make-typed-array
-                              ;; TODO: use the correct type based on f0.
-                              *atype* *unspecified*
-                              bs (dims-of (internal-forward f0)))))
-              (for-indices-in-range
-               (lambda (i)
-                 (let* ((fi (boxed-fi f i x y))
-                        (Jfi (internal-jacobian fi)))
-                   (array-set! jac (internal-jacobian fi) i)
-                   (array-cell-set! fwd (internal-forward fi) i)))
-               '(0) (list bs))
-              (make-internal
-               fwd
-               (lambda (js)
-                 (let ((a (apply make-typed-array
-                                 ;; TODO: use the correct type based on f0.
-                                 *atype* *unspecified*
-                                 bs (dims-of (internal-forward f0)))))
-                   (for-indices-in-range
-                    (lambda (batch)
-                      (array-cell-set!
-                       a ((array-ref jac batch) js)
-                       batch))
-                    '(0) (list bs))
-                   a))))))))
-
 (define (adot x y n)
   (differentiable-wrapper
    (list
-    (lambda (xs indices free-rank-x free-rank-y)
-      (let* ((y (cadr xs))
-            (n (caddr xs))
-            (out (take indices (+ free-rank-x free-rank-y)))
-            (in (drop indices (+ free-rank-x free-rank-y)))
-            (is-free1 (take out free-rank-x))
-            (js-free (drop out free-rank-x))
-            (is-free2 (take in free-rank-x))
-            (is-bound (drop in free-rank-x)))
-       (ifn (equal? is-free1 is-free2)
-            0
-            (apply array-ref y (append is-bound js-free)))))
-    (lambda (xs indices free-rank-x free-rank-y)
-      (let* ((x (car xs))
-            (n (caddr xs))
-            (out (take indices (+ free-rank-x free-rank-y)))
-            (in (drop indices (+ free-rank-x free-rank-y)))
-            (is-free (take out free-rank-x))
-            (js-free1 (drop out free-rank-x))
-            (js-free2 (drop in n))
-            (js-bound (take in n)))
-       (ifn (equal? js-free1 js-free2)
-            0
-            (apply array-ref x (append is-free js-bound)))))
+    (lambda (xs i j n-free-dims-y n-bound-dims)
+      (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+       (receive (j-free j-bound) (euclidean/ j n-bound-dims)
+         (ifn (= i-x j-free)
+              0
+              (array-ref (array-contents (cadr xs))
+                         (+ i-y (* n-free-dims-y j-bound)))))))
+    (lambda (xs i j n-free-dims-y n-bound-dims)
+      (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+       (receive (j-bound j-free) (euclidean/ j n-free-dims-y)
+         (ifn (= i-y j-free)
+              0
+              (array-ref (array-contents (car xs))
+                         (+ j-bound (* n-bound-dims i-x)))))))
     #f)
    (list
     contract-arrays
-    (lambda (x y n) (- (array-rank x) n))
-    (lambda (x y n) (- (array-rank y) n)))
+    (lambda (x y n)
+      (apply * (drop (array-dimensions y)
+                    n)))
+    (lambda (x y n)
+      (apply * (take (array-dimensions y)
+                    n))))
    x y n))
+
+(define (amap2 f x y)
+  (apply make-batch
+        (list-tabulate (car (dims-of (unwrap-fwd x)))
+                       (lambda (b)
+                         (f (i:array-cell-ref x b)
+                            (i:array-cell-ref y b))))))
index 0437d7ca425cf40c092aec21d74a66d73bb31653..88f00731b8ca1254e65fe9b76c652a777bf0a2be 100644 (file)
--- a/misc.scm
+++ b/misc.scm
@@ -46,14 +46,14 @@ order."
 (define (for-indices-in-range f starts ends)
   (define (for-indices-in-range% f indices starts ends)
     (if (null? starts)
-      (apply f (reverse indices))
-      (do ((i (car starts) (1+ i)))
-        ((= i (car ends)))
-        (for-indices-in-range%
-          f
-          (cons i indices)
-          (cdr starts)
-          (cdr ends)))))
+       (apply f (reverse indices))
+       (do ((i (car starts) (1+ i)))
+           ((= i (car ends)))
+         (for-indices-in-range%
+          f
+           (cons i indices)
+           (cdr starts)
+           (cdr ends)))))
   (for-indices-in-range% f '() starts ends))
 
 ;;;; array utilities