]> git.vouivredigital.com Git - vouivre.git/commitdiff
Precompute values for jacobian generators
authoradmin <admin@vouivredigital.com>
Thu, 9 Nov 2023 04:18:54 +0000 (13:18 +0900)
committeradmin <admin@vouivredigital.com>
Thu, 9 Nov 2023 04:18:54 +0000 (13:18 +0900)
grad.scm

index 46b1a999b20c747417753bb807292cc774c930f6..8f526fd4a3fb6f5bdb302409e64e567abc85b635 100644 (file)
--- a/grad.scm
+++ b/grad.scm
 
 ;;;; array utilities
 
+(define (abs->rel index dimensions)
+  (let rec ((ds (cdr dimensions))
+           (p (apply * (cdr dimensions)))
+           (r index)
+           (is '()))
+    (let ((i (quotient r p)))
+      (if (null? ds)
+         (reverse (cons i is))
+         (rec (cdr ds)
+              (quotient p (car ds))
+              (- r (* p i))
+              (cons i is))))))
+
 (define (contracted-dims a b n)
   (let ((dims-a (array-dimensions a))
         (dims-b (array-dimensions b)))
@@ -208,7 +221,14 @@ element-wise. All arrays must have the same dimension."
   ;;       corresponding generator should be `#f'.
   (let* ((inputs (cons input more))
         (naked-inputs (map unbox-fwd inputs))
-        (output (apply function naked-inputs)))
+        (output (apply (if (procedure? function)
+                           function
+                           (car function))
+                       naked-inputs))
+        (data (if (procedure? function)
+                  '()
+                  (map (lambda (f) (apply f naked-inputs))
+                       (cdr function)))))
     (ifn (*grad*)
         output
         (make-internal
@@ -222,18 +242,22 @@ element-wise. All arrays must have the same dimension."
                     ((ifn prev identity (lambda (x) ((extend +) prev x)))
                      (ifn (internal-jacobian input)
                           (if (number? output)
-                              (jacobian-generator naked-inputs js)
+                              (apply jacobian-generator naked-inputs js data)
                               (apply
                                produce-typed-array
                                (lambda is
-                                 (jacobian-generator naked-inputs (append is js)))
+                                 (apply jacobian-generator
+                                        naked-inputs
+                                        (append is js)
+                                        data))
                                *atype*
                                (array-dimensions output)))
                           (let ((b ((internal-jacobian input) js))
                                 (fwd (internal-forward input)))
                             (cond
                              ((and (number? output) (number? fwd))
-                              (* (jacobian-generator naked-inputs '())
+                              (* (apply jacobian-generator
+                                        naked-inputs '() data)
                                  b))
                              ((and (number? output) (array? fwd))
                               (array-ref
@@ -241,7 +265,8 @@ element-wise. All arrays must have the same dimension."
                                 (apply
                                  produce-typed-array
                                  (lambda ks
-                                   (jacobian-generator naked-inputs ks))
+                                   (apply jacobian-generator
+                                          naked-inputs ks data))
                                  *atype*
                                  (array-dimensions fwd))
                                 b (array-rank fwd))))
@@ -250,7 +275,8 @@ element-wise. All arrays must have the same dimension."
                                produce-typed-array
                                (lambda is
                                  ((extend *)
-                                  (jacobian-generator naked-inputs is)
+                                  (apply jacobian-generator
+                                         naked-inputs is data)
                                   b))
                                *atype*
                                (array-dimensions output)))
@@ -263,7 +289,10 @@ element-wise. All arrays must have the same dimension."
                                    (apply
                                     produce-typed-array
                                     (lambda ks
-                                      (jacobian-generator naked-inputs (append is ks)))
+                                      (apply jacobian-generator
+                                             naked-inputs
+                                             (append is ks)
+                                             data))
                                     *atype*
                                     (array-dimensions fwd))
                                    b (array-rank fwd))))
@@ -457,17 +486,19 @@ element-wise. All arrays must have the same dimension."
   "Differentiable mean on arrays."
   (differentiable-wrapper
    (list
-    (lambda (xs indices)
-      (/ 1 (apply * (array-dimensions (car xs))))))
-   (lambda (x)
-     (let ((sum 0)
-          (count 0))
-       (array-for-each
-       (lambda (x)
-         (set! sum (+ sum x))
-         (set! count (1+ count)))
-       x)
-       (/ sum count)))
+    (lambda (xs indices one-over-n)
+      one-over-n))
+   (let ((n 0))
+    (list
+     (lambda (x)
+       (let ((sum 0))
+        (array-for-each
+         (lambda (x)
+           (set! sum (+ sum x))
+           (set! n (1+ n)))
+         x)
+        (/ sum n)))
+     (lambda _ (/ n))))
    x))
 
 (define (i:array-ref x . indices)
@@ -483,48 +514,35 @@ element-wise. All arrays must have the same dimension."
    array-ref
    x indices))
 
-(define (for-flat-index proc a)
-  (let* ((b (array-contents a))
-        (n (array-length b)))
-    (let rec ((i 0))
-      (unless (= i n)
-       (proc b i)
-       (rec (1+ i))))))
-
 (define (maximum x)
   "Differentiable maximum on arrays."
   (differentiable-wrapper
    (list
-    (lambda (xs indices)
-      (let ((x (car xs))
-           (m (- (inf)))
-           (i 'TBD))
-       (for-indices-in-range
-        (lambda indices
-          (let ((xi (apply array-ref x indices)))
-            (when (< m xi)
-              (set! m xi)
-              (set! i indices))))
-        (list-zeros (array-rank x))
-        (array-dimensions x))
-       (if (equal? i indices)
-           1
-           0))))
-   (lambda (x)
-     (let ((m (- (inf))))
-       (array-for-each
-       (lambda (x)
-         (set! m (max m x)))
-       x)
-       m))
+    (lambda (xs indices max-indices)
+      (if (equal? indices max-indices)
+         1
+         0)))
+   (let ((max-index 'TBD))
+     (list
+      (lambda (x)
+       (let ((m (- (inf)))
+             (i 0))
+         (array-for-each
+          (lambda (x)
+            (when (< m x)
+              (set! m x)
+              (set! max-index i))
+            (set! i (1+ i)))
+          x)
+         m))
+      (lambda (x)
+       (abs->rel max-index (array-dimensions x)))))
    x))
 
 (define (sum x)
   "Differentiable sum on arrays."
   (differentiable-wrapper
-   (list
-    (lambda (xs indices)
-      1))
+   (list (lambda _ 1))
    (lambda (x)
      (let ((sum 0))
        (array-for-each
@@ -616,14 +634,9 @@ element-wise. All arrays must have the same dimension."
 (define (adot x y n)
   (differentiable-wrapper
    (list
-    (lambda (xs indices)
-      (let* ((x (car xs))
-            (y (cadr xs))
+    (lambda (xs indices free-rank-x free-rank-y)
+      (let* ((y (cadr xs))
             (n (caddr xs))
-            (free-rank-x (- (array-rank x)
-                            n))
-            (free-rank-y (- (array-rank y)
-                            n))
             (out (take indices (+ free-rank-x free-rank-y)))
             (in (drop indices (+ free-rank-x free-rank-y)))
             (is-free1 (take out free-rank-x))
@@ -633,14 +646,9 @@ element-wise. All arrays must have the same dimension."
        (ifn (equal? is-free1 is-free2)
             0
             (apply array-ref y (append is-bound js-free)))))
-    (lambda (xs indices)
+    (lambda (xs indices free-rank-x free-rank-y)
       (let* ((x (car xs))
-            (y (cadr xs))
             (n (caddr xs))
-            (free-rank-x (- (array-rank x)
-                            n))
-            (free-rank-y (- (array-rank y)
-                            n))
             (out (take indices (+ free-rank-x free-rank-y)))
             (in (drop indices (+ free-rank-x free-rank-y)))
             (is-free (take out free-rank-x))
@@ -651,4 +659,8 @@ element-wise. All arrays must have the same dimension."
             0
             (apply array-ref x (append is-free js-bound)))))
     #f)
-   contract-arrays x y n))
+   (list
+    contract-arrays
+    (lambda (x y n) (- (array-rank x) n))
+    (lambda (x y n) (- (array-rank y) n)))
+   x y n))