From 6f945ff1f9cae705f27d70ec18891df3ce3b50d4 Mon Sep 17 00:00:00 2001 From: admin Date: Mon, 20 Nov 2023 16:56:37 +0900 Subject: [PATCH] Optimize garbage collection Instead of allocating memory on every index of every differentiable function call we do it once per call and use the buffer for all indices. --- grad.scm | 658 +++++++++++++++++++++++++++++-------------------------- misc.scm | 16 +- 2 files changed, 350 insertions(+), 324 deletions(-) diff --git a/grad.scm b/grad.scm index f0aa3c2..9e4ec76 100644 --- a/grad.scm +++ b/grad.scm @@ -3,6 +3,7 @@ #:use-module (srfi srfi-1) #:use-module (srfi srfi-9) #:use-module (vouivre misc) + #:use-module (vouivre promises) #:export (*atype* adot @@ -10,14 +11,17 @@ contract-arrays differentiable-wrapper dot + do-times + ewise1 + ewise2 extend grad - internal-jacobian + make-batch + make-internal maximum mean rank-of - sum - ) + sum) #:replace ((i:sqrt . sqrt) (i:exp . exp) @@ -35,35 +39,24 @@ (i:abs . abs) (i:identity . identity) (i:array-ref . array-ref) - ) + (i:array-cell-ref . array-cell-ref)) #:re-export (fold reduce)) ;;;; array utilities -(define (abs->rel index dimensions) - (let rec ((ds (cdr dimensions)) - (p (apply * (cdr dimensions))) - (r index) - (is '())) - (let ((i (quotient r p))) - (if (null? ds) - (reverse (cons i is)) - (rec (cdr ds) - (quotient p (car ds)) - (- r (* p i)) - (cons i is)))))) - -(define (contracted-dims a b n) - (let ((dims-a (array-dimensions a)) - (dims-b (array-dimensions b))) - (if (or (> n (array-rank a)) - (> n (array-rank b))) - (error "can't contract arrays with size lower than" n) - (append (take dims-a (- (array-rank a) - n)) - (drop dims-b n))))) +(define (rel->abs indices dimensions) + (let rec ((s 0) + (p 1) + (is (reverse indices)) + (ds (reverse dimensions))) + (if (null? is) + s + (rec (+ s (* p (car is))) + (* p (car ds)) + (cdr is) + (cdr ds))))) (define(do-times n proc) (let rec ((i 0)) @@ -152,176 +145,232 @@ element-wise. All arrays must have the same dimension." ;;(define *atype* 'f32) (define *atype* #t) (define *grad* (make-parameter #f)) +(define *j* (make-parameter #f)) + +(define-syntax-rule (w/j val body ...) + (parameterize ((*j* val)) + body ...)) + +(define (unwrap-fwd x) + (if (internal? x) + (internal-forward x) + x)) + +(define (unwrap-jac x) + (if (internal? x) + (internal-jacobian x) + x)) + +(define (dims-of x) + (if (number? x) + '() + (array-dimensions x))) + +(define (mov dst-buf n-dst-dims generator naked-inputs data j) + (do-times + n-dst-dims + (lambda (i) + (array-set! + dst-buf + (apply generator naked-inputs i j data) + i))) + dst-buf) + +(define (add dst-buf n-dst-dims generator naked-inputs data j) + (do-times + n-dst-dims + (lambda (i) + (array-set! + dst-buf + (+ (array-ref dst-buf i) + (apply generator naked-inputs i j data)) + i))) + dst-buf) + +(define (movc dst-buf n-dst-dims src-buf n-src-dims + generator naked-inputs data) + "Contract the Jacobian column produced by the generator with the source buffer +storing the result in the destination buffer." + (let ((s 0)) + (do-times + n-dst-dims + (lambda (i) + (set! s 0) + (do-times + n-src-dims + (lambda (k) + (set! s (+ s (* (apply generator naked-inputs i k data) + (array-ref src-buf k)))))) + (array-set! dst-buf s i)))) + dst-buf) + +(define (addc dst-buf n-dst-dims src-buf n-src-dims + generator naked-inputs data) + "Contract the Jacobian column produced by the generator with the source buffer +adding the result to the destination buffer." + (let ((s 0)) + (do-times + n-dst-dims + (lambda (i) + (set! s (array-ref dst-buf i)) + (do-times + n-src-dims + (lambda (k) + (set! s (+ s (* (apply generator naked-inputs i k data) + (array-ref src-buf k)))))) + (array-set! dst-buf s i)))) + dst-buf) (define* (grad f #:optional (axis 0)) (define (wrap x i) (if (= i axis) - (make-internal x #f) + (make-internal x 'input) x)) (lambda xs - (let* ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity))) - (y (internal-forward (list-ref wrapped-xs axis)))) - ((if (*grad*) - identity - (lambda (boxed) - (let ((fx (internal-forward boxed)) - (gfx (internal-jacobian boxed))) - (cond - ((and (number? fx) (number? y)) - (gfx '())) - ((and (number? fx) (array? y)) - (apply produce-typed-array - (lambda js (gfx js)) - *atype* - (array-dimensions y))) - ((and (array? fx) (number? y)) - (gfx '())) - (else - (let* ((dims-out (array-dimensions fx)) - (dims-in (array-dimensions y)) - (a (apply make-typed-array *atype* *unspecified* (append dims-out dims-in)))) - (for-indices-in-range - (lambda js - (let ((out (gfx js))) - (for-indices-in-range - (lambda is - (apply - array-set! - a - (apply array-ref out is) - (append is js))) - (list-zeros (array-rank fx)) - dims-out))) - (list-zeros (array-rank y)) - dims-in) - a)))))) - (parameterize (((@@ (vouivre grad) *grad*) y)) - (apply f wrapped-xs)))))) - -(define (unbox-fwd x) - (if (internal? x) - (internal-forward x) - x)) - -;; `n' is the number of arguments to `function'. -;; `jacobian-generators is not a `Vec' but a `List' we only use the former to -;; show its length. + (parameterize (((@@ (vouivre grad) *grad*) #t) + ((@@ (vouivre grad) *promises*) (cons '() #f))) + (let* ((internal (apply f (map-indexed wrap xs))) + (fx (internal-forward internal)) + (y (list-ref xs axis)) ; variable to differentiate w.r.t + (pre-Jx (internal-jacobian internal)) + (Jx (cond + ;; TODO: implement 'input case and test 'zero and 'input + ((eq? pre-Jx 'zero) + (lambda (j) + (lambda (i) + 0))) + ((eq? pre-Jx 'input) + (error "TBD.")) + (else + (lambda (j) + (reset-promises (car (*promises*))) + (let ((column-jac (w/j j (force pre-Jx)))) + (lambda (i) + (array-ref column-jac i)))))))) + (cond + ((and (number? fx) (number? y)) + ((Jx 0) 0)) + ((and (number? fx) (array? y)) + (let* ((y-dims (array-dimensions y)) + (a (apply make-array *unspecified* y-dims)) + (ac (array-contents a))) + (do-times + (apply * y-dims) + (lambda (j) + (array-set! ac ((Jx j) 0) + j))) + a)) + ((and (array? fx) (number? y)) + (let* ((fx-dims (array-dimensions fx)) + (a (apply make-array *unspecified* fx-dims)) + (ac (array-contents a)) + (Jx (Jx 0))) + (do-times + (apply * fx-dims) + (lambda (i) + (array-set! ac (Jx i) + i))) + a)) + (else + (let* ((fx-dims (array-dimensions fx)) + (y-dims (array-dimensions y)) + (n-fx-dims (apply * fx-dims)) + (n-y-dims (apply * y-dims)) + (a (apply make-array *unspecified* (append fx-dims y-dims))) + (ac (array-contents a))) + (do-times + n-y-dims + (lambda (j) + (let ((Jx (Jx j))) + (do-times + n-fx-dims + (lambda (i) + (array-set! ac (Jx i) + (+ j (* n-y-dims i)))))))) + a))))))) + + +;; In the comment that follows: +;; +;; `n' is the number of arguments to `proc'. +;; `generators is not a `Vec' but a `List' we only use the former to illustrate +;; its length. ;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension. ;; `I' is the type of multi-indices indexing the output of `function'. ;; `J' is the type of multi-indices indexing the input array being differentiated. +;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J'). ;; `Array I' is the type of arrays indexed by multi-indices of `I'. ;; `[X]' means that `X' is boxed in an internal as when returned by -;; `differentiable-wrapper' with the array being `X' and the function -;; taking a multi-index of `J' to return an array of the same shape as `X'. -;; (∷ (→ (Vec n (→ X1 ... Xn I J Number)) -;; (→ X1 ... Xn (Array I)) +;; `differentiable-wrapper' with the array being `X' and the promise that +;; given a |J| we will get the change of `X' with a change of the +;; the differentiated argument at multi-index `J'. +;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number)) +;; (→ X1 ... Xn (Array I))* ;; [X1] ... [Xn] -;; (Internal (Array I) (→ J (Array I))))) -(define (differentiable-wrapper jacobian-generators function input . more) - ;; NOTE: Both the jacobian generators and the function act on naked inputs - ;; (numbers or arrays not inside an internal object). The generators - ;; additionaly take indices -- one for each dimension of what we are - ;; differentiating with respect to (the content of `(*grad*)'). A - ;; generator returns a jacobian column expressing the change of the - ;; function's output when changing `(*grad*)' along the given indices. - ;; In cases where an argument isn't meant to be differentiable its - ;; corresponding generator should be `#f'. - (let* ((inputs (cons input more)) - (naked-inputs (map unbox-fwd inputs)) - (output (apply (if (procedure? function) - function - (car function)) - naked-inputs)) - (data (if (procedure? function) - '() - (map (lambda (f) (apply f naked-inputs)) - (cdr function))))) +;; (Internal (Array I) (Promise |J| (Array |I|))))) +;; +;; (*) We extend this definition to allow `proc' to be a list of procedures +;; the head of which is as described above and the remaining elements +;; are procedures of the same arguments but returning values that are +;; then fed as extra data to the generators. +;; +;; NOTE: In cases where an argument isn't meant to be differentiable its +;; corresponding generator should be `#f'. +(define (differentiable-wrapper generators proc* arg . more) + (let* ((args (cons arg more)) + (proc (if (procedure? proc*) + proc* + (car proc*))) + (naked-args (map unwrap-fwd args)) + (out (apply proc naked-args))) (ifn (*grad*) - output - (make-internal - output - (lambda (js) - (or - (fold - (lambda (jacobian-generator input prev) - (ifn (internal? input) - prev - ((ifn prev identity (lambda (x) ((extend +) prev x))) - (ifn (internal-jacobian input) - (if (number? output) - (apply jacobian-generator naked-inputs js data) - (apply - produce-typed-array - (lambda is - (apply jacobian-generator - naked-inputs - (append is js) - data)) - *atype* - (array-dimensions output))) - (let ((b ((internal-jacobian input) js)) - (fwd (internal-forward input))) - (cond - ((and (number? output) (number? fwd)) - (* (apply jacobian-generator - naked-inputs '() data) - b)) - ((and (number? output) (array? fwd)) - (array-ref - (contract-arrays - (apply - produce-typed-array - (lambda ks - (apply jacobian-generator - naked-inputs ks data)) - *atype* - (array-dimensions fwd)) - b (array-rank fwd)))) - ((and (array? output) (number? fwd)) - (apply - produce-typed-array - (lambda is - ((extend *) - (apply jacobian-generator - naked-inputs is data) - b)) - *atype* - (array-dimensions output))) - (else - (apply - produce-typed-array - (lambda is - (array-ref - (contract-arrays - (apply - produce-typed-array - (lambda ks - (apply jacobian-generator - naked-inputs - (append is ks) - data)) - *atype* - (array-dimensions fwd)) - b (array-rank fwd)))) - *atype* - (array-dimensions output))))))))) - #f jacobian-generators inputs) - (if (number? output) - 0 - (apply make-typed-array *atype* 0 (array-dimensions output))))))))) + out + (let* ((data (if (procedure? proc*) + '() + (map (lambda (g) + (apply g naked-args)) + (cdr proc*)))) + (n-out-dims (apply * (dims-of out))) + (buf (make-array *unspecified* n-out-dims))) + (make-internal + out + (fold + (lambda (generator arg prev) + (if (or (not (internal? arg)) + (eq? 'zero (internal-jacobian arg))) + prev + (let ((Jx (internal-jacobian arg)) + (n-fwd-dims (apply * (dims-of (unwrap-fwd arg))))) + (if (eq? Jx 'input) + (if (eq? prev 'zero) + (delay + (mov buf n-out-dims + generator naked-args data (*j*))) + (delay + (add (force prev) n-out-dims + generator naked-args data (*j*)))) + (if (eq? prev 'zero) + (delay + (movc buf n-out-dims (force Jx) n-fwd-dims + generator naked-args data)) + (delay + (addc (force prev) n-out-dims + (force Jx) n-fwd-dims + generator naked-args data))))))) + 'zero generators args)))))) (define (ewise1 f) - (lambda (xs indices) + (lambda (xs i j) (let ((x (car xs))) (if (number? x) (f x) - (receive (is js) (split-at indices (array-rank x)) - (if (equal? is js) - (f (apply array-ref x is)) - 0)))))) + (ifn (= i j) + 0 + (f (array-ref (array-contents x) + j))))))) (define (ewise2 proc axis) - (lambda (xs indices) + (lambda (xs i j) (let ((x (car xs)) (y (cadr xs))) (cond @@ -329,25 +378,29 @@ element-wise. All arrays must have the same dimension." (proc x y)) ((and (number? x) (array? y)) (if (= axis 0) - (proc x (apply array-ref y indices)) - (receive (is js) (split-at indices (array-rank y)) - (ifn (equal? is js) - 0 - (proc x (apply array-ref y is)))))) + (proc x (array-ref (array-contents y) + i)) + (ifn (= i j) + 0 + (proc x (array-ref (array-contents y) + j))))) ((and (array? x) (number? y)) (if (= axis 1) - (proc (apply array-ref x indices) y) - (receive (is js) (split-at indices (array-rank x)) - (ifn (equal? is js) - 0 - (proc (apply array-ref x is) - y))))) + (proc (array-ref (array-contents x) + i) + y) + (ifn (= i j) + 0 + (proc (array-ref (array-contents x) + j) + y)))) (else - (receive (is js) (split-at indices (array-rank x)) - (ifn (equal? is js) - 0 - (proc (apply array-ref x is) - (apply array-ref y is))))))))) + (ifn (= i j) + 0 + (proc (array-ref (array-contents x) + j) + (array-ref (array-contents y) + j)))))))) (define (i:identity x) "Differentiable identity." @@ -493,19 +546,19 @@ element-wise. All arrays must have the same dimension." "Differentiable mean on arrays." (differentiable-wrapper (list - (lambda (xs indices one-over-n) + (lambda (xs i j one-over-n) one-over-n)) (let ((n 0)) - (list - (lambda (x) - (let ((sum 0)) - (array-for-each - (lambda (x) - (set! sum (+ sum x)) - (set! n (1+ n))) - x) - (/ sum n))) - (lambda _ (/ n)))) + (list + (lambda (x) + (let ((sum 0)) + (array-for-each + (lambda (x) + (set! sum (+ sum x)) + (set! n (1+ n))) + x) + (/ sum n))) + (lambda _ (/ n)))) x)) (define (i:array-ref x . indices) @@ -513,20 +566,72 @@ element-wise. All arrays must have the same dimension." (apply differentiable-wrapper (cons - (lambda (xs js) - (if (equal? js indices) + (lambda (xs i j abs-index) + (if (= j abs-index) 1 0)) - (list-tabulate (length indices) not)) - array-ref + (map not indices)) + (list + array-ref + (lambda (x . indices) + (rel->abs indices (array-dimensions x)))) + x indices)) + +(define (i:array-cell-ref x . indices) + (apply + differentiable-wrapper + (cons + (lambda (xs i j abs-index n-rst-dims) + (receive (j-ref j-rst) (euclidean/ j n-rst-dims) + (if (and (= j-ref abs-index) + (= j-rst i)) + 1 + 0))) + (map not indices)) + (list + array-cell-ref + (lambda (x . indices) + (rel->abs indices (take (array-dimensions x) + (length indices)))) + (lambda (x . indices) + (apply * (drop (array-dimensions x) + (length indices))))) x indices)) +(define (make-batch elem . more) + (let ((batch-size (1+ (length more)))) + (apply + differentiable-wrapper + (list-tabulate + batch-size + (lambda (b) + (lambda (xs i j n-rest-dims) + (receive (i-batch i-rest) (euclidean/ i n-rest-dims) + (if (and (= i-batch b) + (= i-rest j)) + 1 + 0 + ))))) + (list + (lambda (elem . more) + (let ((a (apply make-typed-array *atype* *unspecified* batch-size + (dims-of elem)))) + (for-each + (lambda (x b) + (array-cell-set! a x b)) + (cons elem more) + (list-tabulate batch-size identity)) + a)) + (lambda (elem . more) + (apply * (dims-of elem)))) + elem more))) + (define (maximum x) "Differentiable maximum on arrays." (differentiable-wrapper (list - (lambda (xs indices max-indices) - (if (equal? indices max-indices) + (lambda (xs i j max-index) + (if (= j max-index) 1 0))) (let ((max-index 'TBD)) @@ -542,8 +647,7 @@ element-wise. All arrays must have the same dimension." (set! i (1+ i))) x) m)) - (lambda (x) - (abs->rel max-index (array-dimensions x))))) + (lambda _ max-index))) x)) (define (sum x) @@ -559,115 +663,37 @@ element-wise. All arrays must have the same dimension." sum)) x)) -(define (amap2 f x y) - (define (unbox-with proc x) - (ifn (internal? x) - x - (proc x))) - (define (dims-of x) - (if (number? x) - '() - (array-dimensions x))) - (define (boxed-ref x i) - (let ((xi (array-cell-ref (unbox-with internal-forward x) - i))) - (ifn (internal? x) - xi - (make-internal - xi - (if (internal-jacobian x) - (lambda (js) - (array-cell-ref - ((internal-jacobian x) - js) - i)) - (lambda (js) - (if (number? xi) - (if (= i (car js)) - 1 - 0) - (apply - produce-typed-array - (lambda indices - (if (and (= i (car js)) - (equal? indices (cdr js))) - 1 - 0)) - (array-type xi) - (array-dimensions xi))))))))) - (define (boxed-fi f i x y) - (f (boxed-ref x i) - (boxed-ref y i))) - (if (internal? f) - (error "unsupported application of `amap2' to internal object" f) - (let ((bs (first (array-dimensions (unbox-with internal-forward x)))) - (f0 (boxed-fi f 0 x y))) - (ifn (internal? f0) - (let ((fwd (apply make-typed-array - ;; TODO: use the correct type based on f0. - *atype* *unspecified* bs (dims-of f0)))) - (for-indices-in-range - (lambda (i) - (array-cell-set! fwd (boxed-fi f i x y) i)) - '(0) (list bs)) - fwd) - (let ((jac (make-array *unspecified* bs)) - (fwd (apply make-typed-array - ;; TODO: use the correct type based on f0. - *atype* *unspecified* - bs (dims-of (internal-forward f0))))) - (for-indices-in-range - (lambda (i) - (let* ((fi (boxed-fi f i x y)) - (Jfi (internal-jacobian fi))) - (array-set! jac (internal-jacobian fi) i) - (array-cell-set! fwd (internal-forward fi) i))) - '(0) (list bs)) - (make-internal - fwd - (lambda (js) - (let ((a (apply make-typed-array - ;; TODO: use the correct type based on f0. - *atype* *unspecified* - bs (dims-of (internal-forward f0))))) - (for-indices-in-range - (lambda (batch) - (array-cell-set! - a ((array-ref jac batch) js) - batch)) - '(0) (list bs)) - a)))))))) - (define (adot x y n) (differentiable-wrapper (list - (lambda (xs indices free-rank-x free-rank-y) - (let* ((y (cadr xs)) - (n (caddr xs)) - (out (take indices (+ free-rank-x free-rank-y))) - (in (drop indices (+ free-rank-x free-rank-y))) - (is-free1 (take out free-rank-x)) - (js-free (drop out free-rank-x)) - (is-free2 (take in free-rank-x)) - (is-bound (drop in free-rank-x))) - (ifn (equal? is-free1 is-free2) - 0 - (apply array-ref y (append is-bound js-free))))) - (lambda (xs indices free-rank-x free-rank-y) - (let* ((x (car xs)) - (n (caddr xs)) - (out (take indices (+ free-rank-x free-rank-y))) - (in (drop indices (+ free-rank-x free-rank-y))) - (is-free (take out free-rank-x)) - (js-free1 (drop out free-rank-x)) - (js-free2 (drop in n)) - (js-bound (take in n))) - (ifn (equal? js-free1 js-free2) - 0 - (apply array-ref x (append is-free js-bound))))) + (lambda (xs i j n-free-dims-y n-bound-dims) + (receive (i-x i-y) (euclidean/ i n-free-dims-y) + (receive (j-free j-bound) (euclidean/ j n-bound-dims) + (ifn (= i-x j-free) + 0 + (array-ref (array-contents (cadr xs)) + (+ i-y (* n-free-dims-y j-bound))))))) + (lambda (xs i j n-free-dims-y n-bound-dims) + (receive (i-x i-y) (euclidean/ i n-free-dims-y) + (receive (j-bound j-free) (euclidean/ j n-free-dims-y) + (ifn (= i-y j-free) + 0 + (array-ref (array-contents (car xs)) + (+ j-bound (* n-bound-dims i-x))))))) #f) (list contract-arrays - (lambda (x y n) (- (array-rank x) n)) - (lambda (x y n) (- (array-rank y) n))) + (lambda (x y n) + (apply * (drop (array-dimensions y) + n))) + (lambda (x y n) + (apply * (take (array-dimensions y) + n)))) x y n)) + +(define (amap2 f x y) + (apply make-batch + (list-tabulate (car (dims-of (unwrap-fwd x))) + (lambda (b) + (f (i:array-cell-ref x b) + (i:array-cell-ref y b)))))) diff --git a/misc.scm b/misc.scm index 0437d7c..88f0073 100644 --- a/misc.scm +++ b/misc.scm @@ -46,14 +46,14 @@ order." (define (for-indices-in-range f starts ends) (define (for-indices-in-range% f indices starts ends) (if (null? starts) - (apply f (reverse indices)) - (do ((i (car starts) (1+ i))) - ((= i (car ends))) - (for-indices-in-range% - f - (cons i indices) - (cdr starts) - (cdr ends))))) + (apply f (reverse indices)) + (do ((i (car starts) (1+ i))) + ((= i (car ends))) + (for-indices-in-range% + f + (cons i indices) + (cdr starts) + (cdr ends))))) (for-indices-in-range% f '() starts ends)) ;;;; array utilities -- 2.39.2