From 43e310a093120a6bd7124ca0f426392b08d4d451 Mon Sep 17 00:00:00 2001 From: admin Date: Fri, 10 Nov 2023 16:35:53 +0900 Subject: [PATCH] Optimize `contract-arrays' using absolute indexing --- grad.scm | 71 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/grad.scm b/grad.scm index e93a312..f0aa3c2 100644 --- a/grad.scm +++ b/grad.scm @@ -7,6 +7,7 @@ (*atype* adot amap2 + contract-arrays differentiable-wrapper dot extend @@ -15,7 +16,8 @@ maximum mean rank-of - sum) + sum + ) #:replace ((i:sqrt . sqrt) (i:exp . exp) @@ -63,37 +65,42 @@ n)) (drop dims-b n))))) +(define(do-times n proc) + (let rec ((i 0)) + (unless (= i n) + (proc i) + (rec (1+ i))))) + (define (contract-arrays a b n) - (let* ((dims (contracted-dims a b n)) - (dims-b (array-dimensions b)) - (nb-dims-a (array-rank a)) - (nb-dims-b (array-rank b)) - (nb-fix-dims-a (- nb-dims-a n)) - (r (apply make-typed-array *atype* *unspecified* dims))) - (for-indices-in-range - (lambda r-indices - (apply - array-set! - r - (let ((s 0)) - (for-indices-in-range - (lambda free-indices - (set! s (+ s (* (apply - array-ref - a - (append (take r-indices nb-fix-dims-a) - free-indices)) - (apply - array-ref - b - (append free-indices - (drop r-indices nb-fix-dims-a))))))) - (list-zeros n) - (take dims-b n)) - s) - r-indices)) - (list-zeros (length dims)) - dims) + (let* ((dims-a (array-dimensions a)) + (dims-b (array-dimensions b)) + (free-dims-a (take dims-a (- (array-rank a) n))) + (free-dims-b (drop dims-b n)) + (bound-dims (take dims-b n)) + (n-free-dims-a (apply * free-dims-a)) + (n-free-dims-b (apply * free-dims-b)) + (n-bound-dims (apply * bound-dims)) + (s 0) + (r (apply make-typed-array *atype* *unspecified* (append free-dims-a + free-dims-b))) + (ac (array-contents a)) + (bc (array-contents b)) + (rc (array-contents r))) + (do-times + n-free-dims-a + (lambda (i) + (let ((i-k (* n-bound-dims i)) + (i-j (* n-free-dims-b i))) + (do-times + n-free-dims-b + (lambda (j) + (set! s 0) + (do-times + n-bound-dims + (lambda (k) + (set! s (+ s (* (array-ref ac (+ i-k k)) + (array-ref bc (+ (* n-free-dims-b k) j))))))) + (array-set! rc s (+ i-j j))))))) r)) ;;;; utilities that work on both numbers and arrays @@ -142,7 +149,7 @@ element-wise. All arrays must have the same dimension." (forward internal-forward) (jacobian internal-jacobian)) -;; (define *atype* 'f64) +;;(define *atype* 'f32) (define *atype* #t) (define *grad* (make-parameter #f)) -- 2.39.5