From 6913e383ffb721e5b387628c99797a5ea7423c84 Mon Sep 17 00:00:00 2001 From: admin Date: Sun, 5 Nov 2023 22:17:58 +0900 Subject: [PATCH] Optimize the gradient algorithm With this optimization the jacobians are computed column by column improving both space and time complexity. --- grad.scm | 450 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 261 insertions(+), 189 deletions(-) diff --git a/grad.scm b/grad.scm index 694aa55..81484a0 100644 --- a/grad.scm +++ b/grad.scm @@ -1,4 +1,5 @@ (define-module (vouivre grad) + #:use-module (ice-9 receive) #:use-module (srfi srfi-1) #:use-module (srfi srfi-9) #:use-module (vouivre misc) @@ -11,8 +12,6 @@ grad internal-jacobian mean - mirror - one rank-of) #:replace ((i:* . *) @@ -22,11 +21,6 @@ (i:identity . identity) (i:max . max))) -;;;; misc utilities - -(define (one . args) - 1) - ;;;; array utilities (define (contracted-dims a b n) @@ -104,28 +98,6 @@ element-wise. All arrays must have the same dimension." ((extend *) x y)) (else (error "can't dot because of invalid types or ranks" x y n)))) -(define (mirror f x . args) - "Create a scalar/array of shape [x]:[x] where element with index `is:is' has -the value of `f' evaluated at the `is' elements of `args' and all other elements -being zero." - (cond - ((number? x) - (apply f args)) - ((array? x) - (let ((n (array-rank x)) - (dims (array-dimensions x))) - (apply - produce-array - (lambda indices - (if (equal? (take indices n) - (drop indices n)) - (apply f (map (lambda (arg) - (apply array-ref arg (drop indices n))) - args)) - 0)) - (append dims dims)))) - (else (error "expected array or number, got " x)))) - (define (rank-of x) (if (number? x) 0 @@ -152,126 +124,219 @@ being zero." ;;;; differentiation (define-record-type internal - (make-internal jacobian forward) + (make-internal forward jacobian) internal? - (jacobian internal-jacobian set-internal-jacobian!) - (forward internal-forward)) + (forward internal-forward) + (jacobian internal-jacobian)) (define *grad* (make-parameter #f)) (define* (grad f #:optional (axis 0)) (define (wrap x i) (if (= i axis) - (make-internal (mirror one x) x) + (make-internal x #f) x)) (lambda xs - ((if (*grad*) identity internal-jacobian) - (let ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity)))) - (parameterize (((@@ (vouivre grad) *grad*) (list-ref wrapped-xs axis))) + (let* ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity))) + (y (internal-forward (list-ref wrapped-xs axis)))) + ((if (*grad*) + identity + (lambda (boxed) + (let ((fx (internal-forward boxed)) + (gfx (internal-jacobian boxed))) + (cond + ((and (number? fx) (number? y)) + (gfx '())) + ((and (number? fx) (array? y)) + (apply produce-array + (lambda js (gfx js)) + (array-dimensions y))) + ((and (array? fx) (number? y)) + (gfx '())) + (else + (let* ((dims-out (array-dimensions fx)) + (dims-in (array-dimensions y)) + (a (apply make-array 0 (append dims-out dims-in)))) + (for-indices-in-range + (lambda js + (let ((out (gfx js))) + (for-indices-in-range + (lambda is + (apply + array-set! + a + (apply array-ref out is) + (append is js))) + (list-zeros (array-rank fx)) + dims-out))) + (list-zeros (array-rank y)) + dims-in) + a) + ))))) + (parameterize (((@@ (vouivre grad) *grad*) y)) (apply f wrapped-xs)))))) -(define (grad-input) - (internal-forward (*grad*))) +(define (unbox-fwd x) + (if (internal? x) + (internal-forward x) + x)) -(define (differentiable-wrapper jacobian-generator function input . more) - ;; NOTE: Both the jacobian generator and the function act on naked inputs - ;; (numbers or arrays not inside an internal object). The generator - ;; returns a list of jacobians, each one expressing the change of the - ;; function's output when changing the corresponding input. In cases - ;; where an argument isn't meant to be differentiable its corresponding - ;; element in the list should be `#f'. +;; `n' is the number of arguments to `function'. +;; `jacobian-generators is not a `Vec' but a `List' we only use the former to +;; show its length. +;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension. +;; `I' is the type of multi-indices indexing the output of `function'. +;; `J' is the type of multi-indices indexing the input array being differentiated. +;; `Array I' is the type of arrays indexed by multi-indices of `I'. +;; `[X]' means that `X' is boxed in an internal as when returned by +;; `differentiable-wrapper' with the array being `X' and the function +;; taking a multi-index of `J' to return an array of the same shape as `X'. +;; (∷ (→ (Vec n (→ X1 ... Xn I J Number)) +;; (→ X1 ... Xn (Array I)) +;; [X1] ... [Xn] +;; (Internal (Array I) (→ J (Array I))))) +(define (differentiable-wrapper jacobian-generators function input . more) + ;; NOTE: Both the jacobian generators and the function act on naked inputs + ;; (numbers or arrays not inside an internal object). The generators + ;; additionaly take indices -- one for each dimension of what we are + ;; differentiating with respect to (the content of `(*grad*)'). A + ;; generator returns a jacobian column expressing the change of the + ;; function's output when changing `(*grad*)' along the given indices. + ;; In cases where an argument isn't meant to be differentiable its + ;; corresponding generator should be `#f'. (let* ((inputs (cons input more)) - (naked-inputs - (map - (lambda (x) - (if (internal? x) - (internal-forward x) - x)) - inputs)) + (naked-inputs (map unbox-fwd inputs)) (output (apply function naked-inputs))) - (if (*grad*) - (make-internal - (or - (fold - (lambda (Jx x prev) - (if (not (internal? x)) - prev - ((if (not prev) - identity - (lambda (x) ((extend +) prev x))) - (dot Jx - (internal-jacobian x) - (rank-of (internal-forward x)))))) - #f - (apply jacobian-generator naked-inputs) - inputs) - (zeros-out:in output (grad-input))) - output) - output))) + (ifn (*grad*) + output + (make-internal + output + (lambda (js) + (or + (fold + (lambda (jacobian-generator input prev) + (ifn (internal? input) + prev + ((ifn prev identity (lambda (x) ((extend +) prev x))) + (ifn (internal-jacobian input) + (if (number? output) + (apply jacobian-generator + (append naked-inputs js)) + (apply + produce-array + (lambda is + (apply jacobian-generator + (append naked-inputs is js))) + (array-dimensions output))) + (let ((b ((internal-jacobian input) js)) + (fwd (internal-forward input))) + (cond + ((and (number? output) (number? fwd)) + (* (apply jacobian-generator naked-inputs) + b)) + ((and (number? output) (array? fwd)) + (array-ref + (contract-arrays + (apply + produce-array + (lambda ks + (apply jacobian-generator + (append naked-inputs ks))) + (array-dimensions fwd)) + b (array-rank fwd)))) + ((and (array? output) (number? fwd)) + (apply + produce-array + (lambda is + ((extend *) + (apply jacobian-generator + (append naked-inputs is)) + b)) + (array-dimensions output))) + (else + (apply + produce-array + (lambda is + (array-ref + (contract-arrays + (apply + produce-array + (lambda ks + (apply jacobian-generator + (append naked-inputs is ks))) + (array-dimensions fwd)) + b (array-rank fwd)))) + (array-dimensions output))))))))) + #f jacobian-generators inputs) + (if (number? output) + 0 + (apply make-array 0 (array-dimensions output))))))))) + +(define (ewise1 f) + (lambda (x . indices) + (if (number? x) + (f x) + (receive (is js) (split-at indices (array-rank x)) + (if (equal? is js) + (f (apply array-ref x is)) + 0))))) + +(define (ewise2 proc axis) + (lambda (x y . indices) + (cond + ((and (number? x) (number? y)) + (proc x y)) + ((and (number? x) (array? y)) + (if (= axis 0) + (proc x (apply array-ref y indices)) + (receive (is js) (split-at indices (array-rank y)) + (ifn (equal? is js) + 0 + (proc x (apply array-ref y is)))))) + ((and (array? x) (number? y)) + (if (= axis 1) + (proc (apply array-ref x indices) y) + (receive (is js) (split-at indices (array-rank x)) + (ifn (equal? is js) + 0 + (proc (apply array-ref x is) + y))))) + (else + (receive (is js) (split-at indices (array-rank x)) + (ifn (equal? is js) + 0 + (proc (apply array-ref x is) + (apply array-ref y is)))))))) (define (i:identity x) "Differentiable identity." (differentiable-wrapper - (lambda (x) (list (mirror one x))) + (list (ewise1 (lambda _ 1))) identity x)) (define (i:exp x) "Differentiable exponential." (differentiable-wrapper - (lambda (x) (list (mirror exp x x))) + (list (ewise1 exp)) (extend exp) x)) (define (i:* x y) "Differentiable element-wise multiplication." (differentiable-wrapper - (lambda (x y) - (list - (cond - ((and (number? y) (number? x)) - y) - ((and (number? y) (array? x)) - (mirror (lambda _ y) x)) - ((and (array? y) (number? x)) - y) - (else - (mirror identity x y))) - (cond - ((and (number? x) (number? y)) - x) - ((and (number? x) (array? y)) - (mirror (lambda _ x) y)) - ((and (array? x) (number? y)) - x) - (else - (mirror identity y x))))) + (list + (ewise2 (lambda (x y) y) 0) + (ewise2 (lambda (x y) x) 1)) (extend *) x y)) (define (i:- x y) "Differentiable element-wise subtraction." (differentiable-wrapper - (lambda (x y) - (list - (cond - ((and (number? y) (number? x)) - +1) - ((and (number? y) (array? x)) - (mirror one x)) - ((and (array? y) (number? x)) - (apply make-array 1 (array-dimensions y))) - (else - (mirror one x))) - (cond - ((and (number? x) (number? y)) - -1) - ((and (number? x) (array? y)) - (mirror (lambda _ -1) y)) - ((and (array? x) (number? y)) - (apply make-array -1 (array-dimensions x))) - (else - (mirror (lambda _ -1) y))))) + (list + (ewise2 (lambda _ +1) 0) + (ewise2 (lambda _ -1) 1)) (extend -) x y)) @@ -286,39 +351,18 @@ being zero." (else 0))) (differentiable-wrapper - (lambda (x y) - (list - (cond - ((and (number? y) (number? x)) - (dmax x y)) - ((and (number? y) (array? x)) - (mirror (lambda (x) (dmax x y)) x x)) - ((and (array? y) (number? x)) - (array-map (lambda (y) (dmax x y)) y)) - (else - (mirror (lambda (x y) (dmax x y)) x x y))) - (cond - ((and (number? x) (number? y)) - (dmax y x)) - ((and (number? x) (array? y)) - (mirror (lambda (y) (dmax y x)) y y)) - ((and (array? x) (number? y)) - (array-map (lambda (x) (dmax y x)) x)) - (else - (mirror (lambda (y x) (dmax y x)) y y x))))) + (list + (ewise2 dmax 0) + (ewise2 (flip dmax) 1)) (extend max) x y)) (define (mean x) "Differentiable mean on arrays." (differentiable-wrapper - (lambda (x) - (let ((dims (array-dimensions x))) - (list - (apply - make-array - (/ 1 (apply * dims)) - dims)))) + (list + (lambda (x . indices) + (/ 1 (apply * (array-dimensions x))))) (lambda (x) (let ((sum 0) (count 0)) @@ -330,25 +374,50 @@ being zero." (/ sum count))) x)) +;; ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30)) +;; (let ((x #2((1 2) (3 4) (5 6))) (y #2((1 2) (3 4) (5 6)))) ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))) (define (amap2 f x y) (define (unbox-with proc x) (ifn (internal? x) x (proc x))) +(define (dims-of x) + (if (number? x) + '() + (array-dimensions x))) (define (boxed-ref x i) (ifn (internal? x) (array-cell-ref x i) - (make-internal (array-cell-ref (internal-jacobian x) + (make-internal (array-cell-ref (internal-forward x) i) - (array-cell-ref (internal-forward x) - i)))) + (if (internal-jacobian x) + (lambda (js) + (array-cell-ref + ((internal-jacobian x) + js) + i)) + (lambda (js) + (let* ((x (internal-forward x)) + (xi (array-cell-ref x i))) + (if (number? xi) + (if (= i (car js)) + 1 + 0) + (let ((a (apply make-array 0 (dims-of xi)))) + (for-indices-in-range + (lambda out + (when (and (= i (car js)) + (equal? out (cdr js))) + (apply + array-set! + a 1 + out))) + (list-zeros (rank-of xi)) + (dims-of xi)) + a)))))))) (define (boxed-fi f i x y) (f (boxed-ref x i) (boxed-ref y i))) - (define (dims-of x) - (if (number? x) - '() - (array-dimensions x))) (if (internal? f) (error "unsupported application of `amap2' to internal object" f) (let ((bs (first (array-dimensions (unbox-with internal-forward x)))) @@ -360,54 +429,57 @@ being zero." (array-cell-set! fwd (boxed-fi f i x y) i)) '(0) (list bs)) fwd) - (let ((jac (apply make-array 0 bs (dims-of (internal-jacobian f0)))) + (let ((jac (make-array 0 bs)) (fwd (apply make-array 0 bs (dims-of (internal-forward f0))))) (for-indices-in-range (lambda (i) - (let ((fi (boxed-fi f i x y))) - (array-cell-set! jac (internal-jacobian fi) i) + (let* ((fi (boxed-fi f i x y)) + (Jfi (internal-jacobian fi))) + (array-set! jac (internal-jacobian fi) i) (array-cell-set! fwd (internal-forward fi) i))) '(0) (list bs)) - (make-internal jac fwd)))))) + (make-internal + fwd + (lambda (js) + (let ((a (apply make-array 0 bs (dims-of (internal-forward f0))))) + (for-indices-in-range + (lambda (batch) + (array-cell-set! + a ((array-ref jac batch) js) + batch)) + '(0) (list bs)) + a)))))))) (define (adot x y n) (differentiable-wrapper - (lambda (x y n) - (let ((bound-dims (contracted-dims x y n))) - (list - (let ((r (apply make-array 0 (append bound-dims (array-dimensions x))))) - (for-indices-in-range - (lambda free-indices - (let ((free-indices-x (take free-indices (- (array-rank x) n))) - (free-indices-y (drop free-indices (- (array-rank x) n)))) - (for-indices-in-range - (lambda bound-indices - (apply - array-set! - r - (apply array-ref y (append bound-indices free-indices-y)) - (append free-indices free-indices-x bound-indices))) - (list-zeros n) - (take (array-dimensions y) n)))) - (list-zeros (length bound-dims)) - bound-dims) - r) - (let ((r (apply make-array 0 (append bound-dims (array-dimensions y))))) - (for-indices-in-range - (lambda free-indices - (let ((free-indices-x (take free-indices (- (array-rank x) n))) - (free-indices-y (drop free-indices (- (array-rank x) n)))) - (for-indices-in-range - (lambda bound-indices - (apply - array-set! - r - (apply array-ref x (append free-indices-x bound-indices)) - (append free-indices bound-indices free-indices-y))) - (list-zeros n) - (take (array-dimensions y) n)))) - (list-zeros (length bound-dims)) - bound-dims) - r) - #f))) + (list + (lambda (x y n . indices) + (let* ((free-rank-x (- (array-rank x) + n)) + (free-rank-y (- (array-rank y) + n)) + (out (take indices (+ free-rank-x free-rank-y))) + (in (drop indices (+ free-rank-x free-rank-y))) + (is-free1 (take out free-rank-x)) + (js-free (drop out free-rank-x)) + (is-free2 (take in free-rank-x)) + (is-bound (drop in free-rank-x))) + (ifn (equal? is-free1 is-free2) + 0 + (apply array-ref y (append is-bound js-free))))) + (lambda (x y n . indices) + (let* ((free-rank-x (- (array-rank x) + n)) + (free-rank-y (- (array-rank y) + n)) + (out (take indices (+ free-rank-x free-rank-y))) + (in (drop indices (+ free-rank-x free-rank-y))) + (is-free (take out free-rank-x)) + (js-free1 (drop out free-rank-x)) + (js-free2 (drop in n)) + (js-bound (take in n))) + (ifn (equal? js-free1 js-free2) + 0 + (apply array-ref x (append is-free js-bound))))) + #f) contract-arrays x y n)) -- 2.39.2