]> git.vouivredigital.com Git - vouivre.git/commitdiff
Optimize the gradient algorithm
authoradmin <admin@vouivredigital.com>
Sun, 5 Nov 2023 13:17:58 +0000 (22:17 +0900)
committeradmin <admin@vouivredigital.com>
Sun, 5 Nov 2023 13:17:58 +0000 (22:17 +0900)
With this optimization the jacobians are computed column by
column improving both space and time complexity.

grad.scm

index 694aa558ebed162fa90098eb4ab09cfa8123d000..81484a08400dc38b9e59884fa6adee3e19dadd0f 100644 (file)
--- a/grad.scm
+++ b/grad.scm
@@ -1,4 +1,5 @@
 (define-module (vouivre grad)
+  #:use-module (ice-9 receive)
   #:use-module (srfi srfi-1)
   #:use-module (srfi srfi-9)
   #:use-module (vouivre misc)
@@ -11,8 +12,6 @@
    grad
    internal-jacobian
    mean
-   mirror
-   one
    rank-of)
   #:replace
   ((i:* . *)
    (i:identity . identity)
    (i:max . max)))
 
-;;;; misc utilities
-
-(define (one . args)
-  1)
-
 ;;;; array utilities
 
 (define (contracted-dims a b n)
@@ -104,28 +98,6 @@ element-wise. All arrays must have the same dimension."
     ((extend *) x y))
    (else (error "can't dot because of invalid types or ranks" x y n))))
 
-(define (mirror f x . args)
-  "Create a scalar/array of shape [x]:[x] where element with index `is:is' has
-the value of `f' evaluated at the `is' elements of `args' and all other elements
-being zero."
-  (cond
-   ((number? x)
-    (apply f args))
-   ((array? x)
-    (let ((n (array-rank x))
-         (dims (array-dimensions x)))
-      (apply
-       produce-array
-       (lambda indices
-        (if (equal? (take indices n)
-                    (drop indices n))
-            (apply f (map (lambda (arg)
-                            (apply array-ref arg (drop indices n)))
-                          args))
-            0))
-       (append dims dims))))
-   (else (error "expected array or number, got " x))))
-
 (define (rank-of x)
   (if (number? x)
       0
@@ -152,126 +124,219 @@ being zero."
 ;;;; differentiation
 
 (define-record-type internal
-  (make-internal jacobian forward)
+  (make-internal forward jacobian)
   internal?
-  (jacobian internal-jacobian set-internal-jacobian!)
-  (forward internal-forward))
+  (forward internal-forward)
+  (jacobian internal-jacobian))
 
 (define *grad* (make-parameter #f))
 
 (define* (grad f #:optional (axis 0))
   (define (wrap x i)
     (if (= i axis)
-       (make-internal (mirror one x) x)
+       (make-internal x #f)
        x))
   (lambda xs
-    ((if (*grad*) identity internal-jacobian)
-     (let ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity))))
-       (parameterize (((@@ (vouivre grad) *grad*) (list-ref wrapped-xs axis)))
+    (let* ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity)))
+          (y (internal-forward (list-ref wrapped-xs axis))))
+      ((if (*grad*)
+          identity
+          (lambda (boxed)
+            (let ((fx (internal-forward boxed))
+                  (gfx (internal-jacobian boxed)))
+              (cond
+               ((and (number? fx) (number? y))
+                (gfx '()))
+               ((and (number? fx) (array? y))
+                (apply produce-array
+                       (lambda js (gfx js))
+                       (array-dimensions y)))
+               ((and (array? fx) (number? y))
+                (gfx '()))
+               (else
+                (let* ((dims-out (array-dimensions fx))
+                       (dims-in (array-dimensions y))
+                       (a (apply make-array 0 (append dims-out dims-in))))
+                  (for-indices-in-range
+                   (lambda js
+                     (let ((out (gfx js)))
+                       (for-indices-in-range
+                        (lambda is
+                          (apply
+                           array-set!
+                           a
+                           (apply array-ref out is)
+                           (append is js)))
+                        (list-zeros (array-rank fx))
+                        dims-out)))
+                   (list-zeros (array-rank y))
+                   dims-in)
+                  a)
+                )))))
+       (parameterize (((@@ (vouivre grad) *grad*) y))
         (apply f wrapped-xs))))))
 
-(define (grad-input)
-  (internal-forward (*grad*)))
+(define (unbox-fwd x)
+  (if (internal? x)
+      (internal-forward x)
+      x))
 
-(define (differentiable-wrapper jacobian-generator function input . more)
-  ;; NOTE: Both the jacobian generator and the function act on naked inputs
-  ;;       (numbers or arrays not inside an internal object). The generator
-  ;;       returns a list of jacobians, each one expressing the change of the
-  ;;       function's output when changing the corresponding input. In cases
-  ;;       where an argument isn't meant to be differentiable its corresponding
-  ;;       element in the list should be `#f'.
+;; `n' is the number of arguments to `function'.
+;; `jacobian-generators is not a `Vec' but a `List' we only use the former to
+;; show its length.
+;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
+;; `I' is the type of multi-indices indexing the output of `function'.
+;; `J' is the type of multi-indices indexing the input array being differentiated.
+;; `Array I' is the type of arrays indexed by multi-indices of `I'.
+;; `[X]' means that `X' is boxed in an internal as when returned by
+;;       `differentiable-wrapper' with the array being `X' and the function
+;;       taking a multi-index of `J' to return an array of the same shape as `X'.
+;; (∷ (→ (Vec n (→ X1 ... Xn I J Number))
+;;       (→ X1 ... Xn (Array I))
+;;       [X1] ... [Xn]
+;;       (Internal (Array I) (→ J (Array I)))))
+(define (differentiable-wrapper jacobian-generators function input . more)
+  ;; NOTE: Both the jacobian generators and the function act on naked inputs
+  ;;       (numbers or arrays not inside an internal object). The generators
+  ;;       additionaly take indices -- one for each dimension of what we are
+  ;;       differentiating with respect to (the content of `(*grad*)'). A
+  ;;       generator returns a jacobian column expressing the change of the
+  ;;       function's output when changing `(*grad*)' along the given indices.
+  ;;       In cases where an argument isn't meant to be differentiable its
+  ;;       corresponding generator should be `#f'.
   (let* ((inputs (cons input more))
-        (naked-inputs
-         (map
-          (lambda (x)
-            (if (internal? x)
-                (internal-forward x)
-                x))
-          inputs))
+        (naked-inputs (map unbox-fwd inputs))
         (output (apply function naked-inputs)))
-    (if (*grad*)
-       (make-internal
-        (or
-         (fold
-          (lambda (Jx x prev)
-            (if (not (internal? x))
-                prev
-                ((if (not prev)
-                     identity
-                     (lambda (x) ((extend +) prev x)))
-                 (dot Jx
-                      (internal-jacobian x)
-                      (rank-of (internal-forward x))))))
-          #f
-          (apply jacobian-generator naked-inputs)
-          inputs)
-         (zeros-out:in output (grad-input)))
-        output)
-       output)))
+    (ifn (*grad*)
+        output
+        (make-internal
+         output
+         (lambda (js)
+           (or
+            (fold
+             (lambda (jacobian-generator input prev)
+               (ifn (internal? input)
+                    prev
+                    ((ifn prev identity (lambda (x) ((extend +) prev x)))
+                     (ifn (internal-jacobian input)
+                          (if (number? output)
+                              (apply jacobian-generator
+                                     (append naked-inputs js))
+                              (apply
+                               produce-array
+                               (lambda is
+                                 (apply jacobian-generator
+                                        (append naked-inputs is js)))
+                               (array-dimensions output)))
+                          (let ((b ((internal-jacobian input) js))
+                                (fwd (internal-forward input)))
+                            (cond
+                             ((and (number? output) (number? fwd))
+                              (* (apply jacobian-generator naked-inputs)
+                                 b))
+                             ((and (number? output) (array? fwd))
+                              (array-ref
+                               (contract-arrays
+                                (apply
+                                 produce-array
+                                 (lambda ks
+                                   (apply jacobian-generator
+                                          (append naked-inputs ks)))
+                                 (array-dimensions fwd))
+                                b (array-rank fwd))))
+                             ((and (array? output) (number? fwd))
+                              (apply
+                               produce-array
+                               (lambda is
+                                 ((extend *)
+                                  (apply jacobian-generator
+                                         (append naked-inputs is))
+                                  b))
+                               (array-dimensions output)))
+                             (else
+                              (apply
+                               produce-array
+                               (lambda is
+                                 (array-ref
+                                  (contract-arrays
+                                   (apply
+                                    produce-array
+                                    (lambda ks
+                                      (apply jacobian-generator
+                                             (append naked-inputs is ks)))
+                                    (array-dimensions fwd))
+                                   b (array-rank fwd))))
+                               (array-dimensions output)))))))))
+             #f jacobian-generators inputs)
+            (if (number? output)
+                0
+                (apply make-array 0 (array-dimensions output)))))))))
+
+(define (ewise1 f)
+  (lambda (x . indices)
+    (if (number? x)
+       (f x)
+       (receive (is js) (split-at indices (array-rank x))
+         (if (equal? is js)
+             (f (apply array-ref x is))
+             0)))))
+
+(define (ewise2 proc axis)
+  (lambda (x y . indices)
+    (cond
+     ((and (number? x) (number? y))
+      (proc x y))
+     ((and (number? x) (array? y))
+      (if (= axis 0)
+         (proc x (apply array-ref y indices))
+         (receive (is js) (split-at indices (array-rank y))
+           (ifn (equal? is js)
+                0
+                (proc x (apply array-ref y is))))))
+     ((and (array? x) (number? y))
+      (if (= axis 1)
+         (proc (apply array-ref x indices) y)
+         (receive (is js) (split-at indices (array-rank x))
+           (ifn (equal? is js)
+                0
+                (proc (apply array-ref x is)
+                      y)))))
+     (else
+      (receive (is js) (split-at indices (array-rank x))
+       (ifn (equal? is js)
+            0
+            (proc (apply array-ref x is)
+                  (apply array-ref y is))))))))
 
 (define (i:identity x)
   "Differentiable identity."
   (differentiable-wrapper
-   (lambda (x) (list (mirror one x)))
+   (list (ewise1 (lambda _ 1)))
    identity
    x))
 
 (define (i:exp x)
   "Differentiable exponential."
   (differentiable-wrapper
-   (lambda (x) (list (mirror exp x x)))
+   (list (ewise1 exp))
    (extend exp)
    x))
 
 (define (i:* x y)
   "Differentiable element-wise multiplication."
   (differentiable-wrapper
-   (lambda (x y)
-     (list
-      (cond
-       ((and (number? y) (number? x))
-       y)
-       ((and (number? y) (array? x))
-       (mirror (lambda _ y) x))
-       ((and (array? y) (number? x))
-       y)
-       (else
-       (mirror identity x y)))
-      (cond
-       ((and (number? x) (number? y))
-       x)
-       ((and (number? x) (array? y))
-       (mirror (lambda _ x) y))
-       ((and (array? x) (number? y))
-       x)
-       (else
-       (mirror identity y x)))))
+   (list
+    (ewise2 (lambda (x y) y) 0)
+    (ewise2 (lambda (x y) x) 1))
    (extend *)
    x y))
 
 (define (i:- x y)
   "Differentiable element-wise subtraction."
   (differentiable-wrapper
-   (lambda (x y)
-     (list
-      (cond
-       ((and (number? y) (number? x))
-       +1)
-       ((and (number? y) (array? x))
-       (mirror one x))
-       ((and (array? y) (number? x))
-       (apply make-array 1 (array-dimensions y)))
-       (else
-       (mirror one x)))
-      (cond
-       ((and (number? x) (number? y))
-       -1)
-       ((and (number? x) (array? y))
-       (mirror (lambda _ -1) y))
-       ((and (array? x) (number? y))
-       (apply make-array -1 (array-dimensions x)))
-       (else
-       (mirror (lambda _ -1) y)))))
+   (list
+    (ewise2 (lambda _ +1) 0)
+    (ewise2 (lambda _ -1) 1))
    (extend -)
    x y))
 
@@ -286,39 +351,18 @@ being zero."
      (else
       0)))
   (differentiable-wrapper
-   (lambda (x y)
-     (list
-      (cond
-       ((and (number? y) (number? x))
-       (dmax x y))
-       ((and (number? y) (array? x))
-       (mirror (lambda (x) (dmax x y)) x x))
-       ((and (array? y) (number? x))
-       (array-map (lambda (y) (dmax x y)) y))
-       (else
-       (mirror (lambda (x y) (dmax x y)) x x y)))
-      (cond
-       ((and (number? x) (number? y))
-       (dmax y x))
-       ((and (number? x) (array? y))
-       (mirror (lambda (y) (dmax y x)) y y))
-       ((and (array? x) (number? y))
-       (array-map (lambda (x) (dmax y x)) x))
-       (else
-       (mirror (lambda (y x) (dmax y x)) y y x)))))
+   (list
+    (ewise2 dmax 0)
+    (ewise2 (flip dmax) 1))
    (extend max)
    x y))
 
 (define (mean x)
   "Differentiable mean on arrays."
   (differentiable-wrapper
-   (lambda (x)
-     (let ((dims (array-dimensions x)))
-       (list
-       (apply
-        make-array
-        (/ 1 (apply * dims))
-        dims))))
+   (list
+    (lambda (x . indices)
+      (/ 1 (apply * (array-dimensions x)))))
    (lambda (x)
      (let ((sum 0)
           (count 0))
@@ -330,25 +374,50 @@ being zero."
        (/ sum count)))
    x))
 
+;;        ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30))
+;;        (let ((x #2((1 2) (3 4) (5 6))) (y #2((1 2) (3 4) (5 6)))) ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)))
 (define (amap2 f x y)
   (define (unbox-with proc x)
     (ifn (internal? x)
         x
         (proc x)))
+(define (dims-of x)
+    (if (number? x)
+       '()
+       (array-dimensions x)))
   (define (boxed-ref x i)
     (ifn (internal? x)
         (array-cell-ref x i)
-        (make-internal (array-cell-ref (internal-jacobian x)
+        (make-internal (array-cell-ref (internal-forward x)
                                        i)
-                       (array-cell-ref (internal-forward x)
-                                       i))))
+                       (if (internal-jacobian x)
+                           (lambda (js)
+                             (array-cell-ref
+                              ((internal-jacobian x)
+                               js)
+                              i))
+                           (lambda (js)
+                             (let* ((x (internal-forward x))
+                                    (xi (array-cell-ref x i)))
+                               (if (number? xi)
+                                   (if (= i (car js))
+                                       1
+                                       0)
+                                   (let ((a (apply make-array 0 (dims-of xi))))
+                                     (for-indices-in-range
+                                      (lambda out
+                                        (when (and (= i (car js))
+                                                   (equal? out (cdr js)))
+                                          (apply
+                                           array-set!
+                                           a 1
+                                           out)))
+                                      (list-zeros (rank-of xi))
+                                      (dims-of xi))
+                                     a))))))))
   (define (boxed-fi f i x y)
     (f (boxed-ref x i)
        (boxed-ref y i)))
-  (define (dims-of x)
-    (if (number? x)
-       '()
-       (array-dimensions x)))
   (if (internal? f)
       (error "unsupported application of `amap2' to internal object" f)
       (let ((bs (first (array-dimensions (unbox-with internal-forward x))))
@@ -360,54 +429,57 @@ being zero."
                  (array-cell-set! fwd (boxed-fi f i x y) i))
                '(0) (list bs))
               fwd)
-            (let ((jac (apply make-array 0 bs (dims-of (internal-jacobian f0))))
+            (let ((jac (make-array 0 bs))
                   (fwd (apply make-array 0 bs (dims-of (internal-forward f0)))))
               (for-indices-in-range
                (lambda (i)
-                 (let ((fi (boxed-fi f i x y)))
-                   (array-cell-set! jac (internal-jacobian fi) i)
+                 (let* ((fi (boxed-fi f i x y))
+                        (Jfi (internal-jacobian fi)))
+                   (array-set! jac (internal-jacobian fi) i)
                    (array-cell-set! fwd (internal-forward fi) i)))
                '(0) (list bs))
-              (make-internal jac fwd))))))
+              (make-internal
+               fwd
+               (lambda (js)
+                 (let ((a (apply make-array 0 bs (dims-of (internal-forward f0)))))
+                   (for-indices-in-range
+                    (lambda (batch)
+                      (array-cell-set!
+                       a ((array-ref jac batch) js)
+                       batch))
+                    '(0) (list bs))
+                   a))))))))
 
 (define (adot x y n)
   (differentiable-wrapper
-   (lambda (x y n)
-     (let ((bound-dims (contracted-dims x y n)))
-       (list
-       (let ((r (apply make-array 0 (append bound-dims (array-dimensions x)))))
-         (for-indices-in-range
-          (lambda free-indices
-            (let ((free-indices-x (take free-indices (- (array-rank x) n)))
-                  (free-indices-y (drop free-indices (- (array-rank x) n))))
-              (for-indices-in-range
-               (lambda bound-indices
-                 (apply
-                  array-set!
-                  r
-                  (apply array-ref y (append bound-indices free-indices-y))
-                  (append free-indices free-indices-x bound-indices)))
-               (list-zeros n)
-               (take (array-dimensions y) n))))
-          (list-zeros (length bound-dims))
-          bound-dims)
-         r)
-       (let ((r (apply make-array 0 (append bound-dims (array-dimensions y)))))
-         (for-indices-in-range
-          (lambda free-indices
-            (let ((free-indices-x (take free-indices (- (array-rank x) n)))
-                  (free-indices-y (drop free-indices (- (array-rank x) n))))
-              (for-indices-in-range
-               (lambda bound-indices
-                 (apply
-                  array-set!
-                  r
-                  (apply array-ref x (append free-indices-x bound-indices))
-                  (append free-indices bound-indices free-indices-y)))
-               (list-zeros n)
-               (take (array-dimensions y) n))))
-          (list-zeros (length bound-dims))
-          bound-dims)
-         r)
-       #f)))
+   (list
+    (lambda (x y n . indices)
+      (let* ((free-rank-x (- (array-rank x)
+                            n))
+            (free-rank-y (- (array-rank y)
+                            n))
+            (out (take indices (+ free-rank-x free-rank-y)))
+            (in (drop indices (+ free-rank-x free-rank-y)))
+            (is-free1 (take out free-rank-x))
+            (js-free (drop out free-rank-x))
+            (is-free2 (take in free-rank-x))
+            (is-bound (drop in free-rank-x)))
+       (ifn (equal? is-free1 is-free2)
+            0
+            (apply array-ref y (append is-bound js-free)))))
+    (lambda (x y n . indices)
+      (let* ((free-rank-x (- (array-rank x)
+                            n))
+            (free-rank-y (- (array-rank y)
+                            n))
+            (out (take indices (+ free-rank-x free-rank-y)))
+            (in (drop indices (+ free-rank-x free-rank-y)))
+            (is-free (take out free-rank-x))
+            (js-free1 (drop out free-rank-x))
+            (js-free2 (drop in n))
+            (js-bound (take in n)))
+       (ifn (equal? js-free1 js-free2)
+            0
+            (apply array-ref x (append is-free js-bound)))))
+    #f)
    contract-arrays x y n))