From f84c38e99cb33e1ef6aa36963c408ca686b0a322 Mon Sep 17 00:00:00 2001 From: admin Date: Thu, 9 Nov 2023 13:18:54 +0900 Subject: [PATCH] Precompute values for jacobian generators --- grad.scm | 142 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 77 insertions(+), 65 deletions(-) diff --git a/grad.scm b/grad.scm index 46b1a99..8f526fd 100644 --- a/grad.scm +++ b/grad.scm @@ -40,6 +40,19 @@ ;;;; array utilities +(define (abs->rel index dimensions) + (let rec ((ds (cdr dimensions)) + (p (apply * (cdr dimensions))) + (r index) + (is '())) + (let ((i (quotient r p))) + (if (null? ds) + (reverse (cons i is)) + (rec (cdr ds) + (quotient p (car ds)) + (- r (* p i)) + (cons i is)))))) + (define (contracted-dims a b n) (let ((dims-a (array-dimensions a)) (dims-b (array-dimensions b))) @@ -208,7 +221,14 @@ element-wise. All arrays must have the same dimension." ;; corresponding generator should be `#f'. (let* ((inputs (cons input more)) (naked-inputs (map unbox-fwd inputs)) - (output (apply function naked-inputs))) + (output (apply (if (procedure? function) + function + (car function)) + naked-inputs)) + (data (if (procedure? function) + '() + (map (lambda (f) (apply f naked-inputs)) + (cdr function))))) (ifn (*grad*) output (make-internal @@ -222,18 +242,22 @@ element-wise. All arrays must have the same dimension." ((ifn prev identity (lambda (x) ((extend +) prev x))) (ifn (internal-jacobian input) (if (number? output) - (jacobian-generator naked-inputs js) + (apply jacobian-generator naked-inputs js data) (apply produce-typed-array (lambda is - (jacobian-generator naked-inputs (append is js))) + (apply jacobian-generator + naked-inputs + (append is js) + data)) *atype* (array-dimensions output))) (let ((b ((internal-jacobian input) js)) (fwd (internal-forward input))) (cond ((and (number? output) (number? fwd)) - (* (jacobian-generator naked-inputs '()) + (* (apply jacobian-generator + naked-inputs '() data) b)) ((and (number? output) (array? fwd)) (array-ref @@ -241,7 +265,8 @@ element-wise. All arrays must have the same dimension." (apply produce-typed-array (lambda ks - (jacobian-generator naked-inputs ks)) + (apply jacobian-generator + naked-inputs ks data)) *atype* (array-dimensions fwd)) b (array-rank fwd)))) @@ -250,7 +275,8 @@ element-wise. All arrays must have the same dimension." produce-typed-array (lambda is ((extend *) - (jacobian-generator naked-inputs is) + (apply jacobian-generator + naked-inputs is data) b)) *atype* (array-dimensions output))) @@ -263,7 +289,10 @@ element-wise. All arrays must have the same dimension." (apply produce-typed-array (lambda ks - (jacobian-generator naked-inputs (append is ks))) + (apply jacobian-generator + naked-inputs + (append is ks) + data)) *atype* (array-dimensions fwd)) b (array-rank fwd)))) @@ -457,17 +486,19 @@ element-wise. All arrays must have the same dimension." "Differentiable mean on arrays." (differentiable-wrapper (list - (lambda (xs indices) - (/ 1 (apply * (array-dimensions (car xs)))))) - (lambda (x) - (let ((sum 0) - (count 0)) - (array-for-each - (lambda (x) - (set! sum (+ sum x)) - (set! count (1+ count))) - x) - (/ sum count))) + (lambda (xs indices one-over-n) + one-over-n)) + (let ((n 0)) + (list + (lambda (x) + (let ((sum 0)) + (array-for-each + (lambda (x) + (set! sum (+ sum x)) + (set! n (1+ n))) + x) + (/ sum n))) + (lambda _ (/ n)))) x)) (define (i:array-ref x . indices) @@ -483,48 +514,35 @@ element-wise. All arrays must have the same dimension." array-ref x indices)) -(define (for-flat-index proc a) - (let* ((b (array-contents a)) - (n (array-length b))) - (let rec ((i 0)) - (unless (= i n) - (proc b i) - (rec (1+ i)))))) - (define (maximum x) "Differentiable maximum on arrays." (differentiable-wrapper (list - (lambda (xs indices) - (let ((x (car xs)) - (m (- (inf))) - (i 'TBD)) - (for-indices-in-range - (lambda indices - (let ((xi (apply array-ref x indices))) - (when (< m xi) - (set! m xi) - (set! i indices)))) - (list-zeros (array-rank x)) - (array-dimensions x)) - (if (equal? i indices) - 1 - 0)))) - (lambda (x) - (let ((m (- (inf)))) - (array-for-each - (lambda (x) - (set! m (max m x))) - x) - m)) + (lambda (xs indices max-indices) + (if (equal? indices max-indices) + 1 + 0))) + (let ((max-index 'TBD)) + (list + (lambda (x) + (let ((m (- (inf))) + (i 0)) + (array-for-each + (lambda (x) + (when (< m x) + (set! m x) + (set! max-index i)) + (set! i (1+ i))) + x) + m)) + (lambda (x) + (abs->rel max-index (array-dimensions x))))) x)) (define (sum x) "Differentiable sum on arrays." (differentiable-wrapper - (list - (lambda (xs indices) - 1)) + (list (lambda _ 1)) (lambda (x) (let ((sum 0)) (array-for-each @@ -616,14 +634,9 @@ element-wise. All arrays must have the same dimension." (define (adot x y n) (differentiable-wrapper (list - (lambda (xs indices) - (let* ((x (car xs)) - (y (cadr xs)) + (lambda (xs indices free-rank-x free-rank-y) + (let* ((y (cadr xs)) (n (caddr xs)) - (free-rank-x (- (array-rank x) - n)) - (free-rank-y (- (array-rank y) - n)) (out (take indices (+ free-rank-x free-rank-y))) (in (drop indices (+ free-rank-x free-rank-y))) (is-free1 (take out free-rank-x)) @@ -633,14 +646,9 @@ element-wise. All arrays must have the same dimension." (ifn (equal? is-free1 is-free2) 0 (apply array-ref y (append is-bound js-free))))) - (lambda (xs indices) + (lambda (xs indices free-rank-x free-rank-y) (let* ((x (car xs)) - (y (cadr xs)) (n (caddr xs)) - (free-rank-x (- (array-rank x) - n)) - (free-rank-y (- (array-rank y) - n)) (out (take indices (+ free-rank-x free-rank-y))) (in (drop indices (+ free-rank-x free-rank-y))) (is-free (take out free-rank-x)) @@ -651,4 +659,8 @@ element-wise. All arrays must have the same dimension." 0 (apply array-ref x (append is-free js-bound))))) #f) - contract-arrays x y n)) + (list + contract-arrays + (lambda (x y n) (- (array-rank x) n)) + (lambda (x y n) (- (array-rank y) n))) + x y n)) -- 2.39.5