]> git.vouivredigital.com Git - vouivre.git/commitdiff
Optimize `contract-arrays' using absolute indexing
authoradmin <admin@vouivredigital.com>
Fri, 10 Nov 2023 07:35:53 +0000 (16:35 +0900)
committeradmin <admin@vouivredigital.com>
Fri, 10 Nov 2023 07:35:53 +0000 (16:35 +0900)
grad.scm

index e93a3129f17b1aa5660999ec00b119842a5a4e3a..f0aa3c24afc4ce1d4780140b801e074f3d9802e5 100644 (file)
--- a/grad.scm
+++ b/grad.scm
@@ -7,6 +7,7 @@
   (*atype*
    adot
    amap2
+   contract-arrays
    differentiable-wrapper
    dot
    extend
@@ -15,7 +16,8 @@
    maximum
    mean
    rank-of
-   sum)
+   sum
+   )
   #:replace
   ((i:sqrt . sqrt)
    (i:exp . exp)
                              n))
              (drop dims-b n)))))
 
+(define(do-times n proc)
+  (let rec ((i 0))
+    (unless (= i n)
+      (proc i)
+      (rec (1+ i)))))
+
 (define (contract-arrays a b n)
-  (let* ((dims (contracted-dims a b n))
-         (dims-b (array-dimensions b))
-         (nb-dims-a (array-rank a))
-         (nb-dims-b (array-rank b))
-         (nb-fix-dims-a (- nb-dims-a n))
-         (r (apply make-typed-array *atype* *unspecified* dims)))
-    (for-indices-in-range
-      (lambda r-indices
-        (apply
-          array-set!
-          r
-          (let ((s 0))
-            (for-indices-in-range
-              (lambda free-indices
-                (set! s (+ s (* (apply
-                                  array-ref
-                                  a
-                                  (append (take r-indices nb-fix-dims-a)
-                                          free-indices))
-                                (apply
-                                  array-ref
-                                  b
-                                  (append free-indices
-                                          (drop r-indices nb-fix-dims-a)))))))
-              (list-zeros n)
-              (take dims-b n))
-            s)
-          r-indices))
-      (list-zeros (length dims))
-      dims)
+  (let* ((dims-a (array-dimensions a))
+        (dims-b (array-dimensions b))
+        (free-dims-a (take dims-a (- (array-rank a) n)))
+        (free-dims-b (drop dims-b n))
+        (bound-dims (take dims-b n))
+        (n-free-dims-a (apply * free-dims-a))
+        (n-free-dims-b (apply * free-dims-b))
+        (n-bound-dims (apply * bound-dims))
+        (s 0)
+        (r (apply make-typed-array *atype* *unspecified* (append free-dims-a
+                                                                 free-dims-b)))
+        (ac (array-contents a))
+        (bc (array-contents b))
+        (rc (array-contents r)))
+    (do-times
+     n-free-dims-a
+     (lambda (i)
+       (let ((i-k (* n-bound-dims i))
+            (i-j (* n-free-dims-b i)))
+        (do-times
+         n-free-dims-b
+         (lambda (j)
+           (set! s 0)
+           (do-times
+            n-bound-dims
+            (lambda (k)
+              (set! s (+ s (* (array-ref ac (+ i-k k))
+                              (array-ref bc (+ (* n-free-dims-b k) j)))))))
+           (array-set! rc s (+ i-j j)))))))
     r))
 
 ;;;; utilities that work on both numbers and arrays
@@ -142,7 +149,7 @@ element-wise. All arrays must have the same dimension."
   (forward internal-forward)
   (jacobian internal-jacobian))
 
-;; (define *atype* 'f64)
+;;(define *atype* 'f32)
 (define *atype* #t)
 (define *grad* (make-parameter #f))