From 1abe20b1db146b7bdee2e7b0209819064d415e51 Mon Sep 17 00:00:00 2001 From: admin Date: Mon, 6 Nov 2023 11:16:45 +0900 Subject: [PATCH] Change generic for typed and unspecified arrays --- grad-tests.scm | 16 +++--- grad.scm | 139 ++++++++++++++++++++++++++----------------------- misc.scm | 10 ++-- 3 files changed, 88 insertions(+), 77 deletions(-) diff --git a/grad-tests.scm b/grad-tests.scm index fd26c0a..77b3ff0 100644 --- a/grad-tests.scm +++ b/grad-tests.scm @@ -59,9 +59,9 @@ (random-array-shape))) (define* (random-array #:optional shape) - (let ((a (apply make-array 0 (or shape (random-array-shape))))) - (array-map! a random:uniform) - a)) + (apply produce-typed-array + (lambda _ (random:uniform)) + v:*atype* (or shape (random-array-shape)))) (define (random-non-empty-array) "Random array of at least one element." @@ -135,7 +135,7 @@ x y) #t)))) -(define* (~ x y #:optional (error 1e-4)) +(define* (~ x y #:optional (error 5e-2)) (cond ((and (number? x) (number? y)) (n~ x y error)) @@ -143,7 +143,7 @@ (a~ x y error)) (else #f))) -(define* (ngrad f #:optional (axis 0) (step 1e-6)) +(define* (ngrad f #:optional (axis 0) (step 1e-4)) "Gradient using a numerical centered difference approximation." (define (axis-add xs dh . indices) "Add `dh' to the number or array at the given `axis' of `xs', @@ -175,11 +175,12 @@ and, when it's an array, at the given index." ((and (number? fxs) (array? x)) (apply - produce-array + produce-typed-array (lambda indices (/ (- (apply f (apply axis-add xs step indices)) (apply f (apply axis-add xs (- step) indices))) (* 2 step))) + v:*atype* (array-dimensions x))) ((and (array? fxs) (number? x)) @@ -191,8 +192,7 @@ and, when it's an array, at the given index." ((and (array? fxs) (array? x)) (let ((a (apply - make-array - 0 + make-typed-array v:*atype* *unspecified* (append (array-dimensions fxs) (array-dimensions x))))) (for-indices-in-range diff --git a/grad.scm b/grad.scm index 81484a0..315f4c0 100644 --- a/grad.scm +++ b/grad.scm @@ -4,7 +4,8 @@ #:use-module (srfi srfi-9) #:use-module (vouivre misc) #:export - (adot + (*atype* + adot amap2 differentiable-wrapper dot @@ -39,7 +40,7 @@ (nb-dims-a (array-rank a)) (nb-dims-b (array-rank b)) (nb-fix-dims-a (- nb-dims-a n)) - (r (apply make-array 0 dims))) + (r (apply make-typed-array *atype* *unspecified* dims))) (for-indices-in-range (lambda r-indices (apply @@ -80,9 +81,10 @@ element-wise. All arrays must have the same dimension." (lambda xs (if-let (x (find array? xs)) (apply - produce-array + produce-typed-array (lambda is (apply-elemwise f is xs)) + *atype* (array-dimensions x)) (apply f xs)))) @@ -103,23 +105,23 @@ element-wise. All arrays must have the same dimension." 0 (array-rank x))) -(define (zeros-out:in x y) - "Zeros in the shape of [x]:[y]." - (cond - ((and (number? x) - (number? y)) - 0) - ((and (array? x) - (array? y)) - (apply make-array 0 (append (array-dimensions x) - (array-dimensions y)))) - ((and (number? x) - (array? y)) - (apply make-array 0 (array-dimensions y))) - ((and (array? x) - (number? y)) - (apply make-array 0 (array-dimensions x))) - (else (error "undefined." x y)))) +;; (define (zeros-out:in x y) +;; "Zeros in the shape of [x]:[y]." +;; (cond +;; ((and (number? x) +;; (number? y)) +;; 0) +;; ((and (array? x) +;; (array? y)) +;; (apply make-array 0 (append (array-dimensions x) +;; (array-dimensions y)))) +;; ((and (number? x) +;; (array? y)) +;; (apply make-array 0 (array-dimensions y))) +;; ((and (array? x) +;; (number? y)) +;; (apply make-array 0 (array-dimensions x))) +;; (else (error "undefined." x y)))) ;;;; differentiation @@ -129,6 +131,7 @@ element-wise. All arrays must have the same dimension." (forward internal-forward) (jacobian internal-jacobian)) +(define *atype* 'f32) (define *grad* (make-parameter #f)) (define* (grad f #:optional (axis 0)) @@ -148,15 +151,16 @@ element-wise. All arrays must have the same dimension." ((and (number? fx) (number? y)) (gfx '())) ((and (number? fx) (array? y)) - (apply produce-array + (apply produce-typed-array (lambda js (gfx js)) + *atype* (array-dimensions y))) ((and (array? fx) (number? y)) (gfx '())) (else (let* ((dims-out (array-dimensions fx)) (dims-in (array-dimensions y)) - (a (apply make-array 0 (append dims-out dims-in)))) + (a (apply make-typed-array *atype* *unspecified* (append dims-out dims-in)))) (for-indices-in-range (lambda js (let ((out (gfx js))) @@ -171,8 +175,7 @@ element-wise. All arrays must have the same dimension." dims-out))) (list-zeros (array-rank y)) dims-in) - a) - ))))) + a)))))) (parameterize (((@@ (vouivre grad) *grad*) y)) (apply f wrapped-xs)))))) @@ -223,10 +226,11 @@ element-wise. All arrays must have the same dimension." (apply jacobian-generator (append naked-inputs js)) (apply - produce-array + produce-typed-array (lambda is (apply jacobian-generator (append naked-inputs is js))) + *atype* (array-dimensions output))) (let ((b ((internal-jacobian input) js)) (fwd (internal-forward input))) @@ -238,39 +242,43 @@ element-wise. All arrays must have the same dimension." (array-ref (contract-arrays (apply - produce-array + produce-typed-array (lambda ks (apply jacobian-generator (append naked-inputs ks))) + *atype* (array-dimensions fwd)) b (array-rank fwd)))) ((and (array? output) (number? fwd)) (apply - produce-array + produce-typed-array (lambda is ((extend *) (apply jacobian-generator (append naked-inputs is)) b)) + *atype* (array-dimensions output))) (else (apply - produce-array + produce-typed-array (lambda is (array-ref (contract-arrays (apply - produce-array + produce-typed-array (lambda ks (apply jacobian-generator (append naked-inputs is ks))) + *atype* (array-dimensions fwd)) b (array-rank fwd)))) + *atype* (array-dimensions output))))))))) #f jacobian-generators inputs) (if (number? output) 0 - (apply make-array 0 (array-dimensions output))))))))) + (apply make-typed-array *atype* 0 (array-dimensions output))))))))) (define (ewise1 f) (lambda (x . indices) @@ -374,47 +382,41 @@ element-wise. All arrays must have the same dimension." (/ sum count))) x)) -;; ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30)) -;; (let ((x #2((1 2) (3 4) (5 6))) (y #2((1 2) (3 4) (5 6)))) ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))) (define (amap2 f x y) (define (unbox-with proc x) (ifn (internal? x) x (proc x))) -(define (dims-of x) + (define (dims-of x) (if (number? x) '() (array-dimensions x))) (define (boxed-ref x i) - (ifn (internal? x) - (array-cell-ref x i) - (make-internal (array-cell-ref (internal-forward x) - i) - (if (internal-jacobian x) - (lambda (js) - (array-cell-ref - ((internal-jacobian x) - js) - i)) - (lambda (js) - (let* ((x (internal-forward x)) - (xi (array-cell-ref x i))) - (if (number? xi) - (if (= i (car js)) - 1 - 0) - (let ((a (apply make-array 0 (dims-of xi)))) - (for-indices-in-range - (lambda out - (when (and (= i (car js)) - (equal? out (cdr js))) - (apply - array-set! - a 1 - out))) - (list-zeros (rank-of xi)) - (dims-of xi)) - a)))))))) + (let ((xi (array-cell-ref (unbox-with internal-forward x) + i))) + (ifn (internal? x) + xi + (make-internal + xi + (if (internal-jacobian x) + (lambda (js) + (array-cell-ref + ((internal-jacobian x) + js) + i)) + (lambda (js) + (if (number? xi) + (if (= i (car js)) + 1 + 0) + (apply + produce-typed-array + (lambda indices + (if (and (= i (car js)) + (equal? indices (cdr js))) + 1 + 0)) + *atype* (array-dimensions xi))))))))) (define (boxed-fi f i x y) (f (boxed-ref x i) (boxed-ref y i))) @@ -423,14 +425,17 @@ element-wise. All arrays must have the same dimension." (let ((bs (first (array-dimensions (unbox-with internal-forward x)))) (f0 (boxed-fi f 0 x y))) (ifn (internal? f0) - (let ((fwd (apply make-array 0 bs (dims-of f0)))) + (let ((fwd (apply make-typed-array + *atype* *unspecified* bs (dims-of f0)))) (for-indices-in-range (lambda (i) (array-cell-set! fwd (boxed-fi f i x y) i)) '(0) (list bs)) fwd) - (let ((jac (make-array 0 bs)) - (fwd (apply make-array 0 bs (dims-of (internal-forward f0))))) + (let ((jac (make-array *unspecified* bs)) + (fwd (apply make-typed-array + *atype* *unspecified* + bs (dims-of (internal-forward f0))))) (for-indices-in-range (lambda (i) (let* ((fi (boxed-fi f i x y)) @@ -441,7 +446,9 @@ element-wise. All arrays must have the same dimension." (make-internal fwd (lambda (js) - (let ((a (apply make-array 0 bs (dims-of (internal-forward f0))))) + (let ((a (apply make-typed-array + *atype* *unspecified* + bs (dims-of (internal-forward f0))))) (for-indices-in-range (lambda (batch) (array-cell-set! diff --git a/misc.scm b/misc.scm index 72c991b..0437d7c 100644 --- a/misc.scm +++ b/misc.scm @@ -10,7 +10,8 @@ ifn list-zeros map-indexed - produce-array)) + produce-array + produce-typed-array)) (define (flip f) "Returns a procedure behaving as `f', but with arguments taken in reverse @@ -57,11 +58,14 @@ order." ;;;; array utilities -(define (produce-array f . dims) - (let ((a (apply make-array 0 dims))) +(define (produce-typed-array f type . dims) + (let ((a (apply make-typed-array type *unspecified* dims))) (array-index-map! a f) a)) +(define (produce-array f . dims) + (apply produce-typed-array f #t dims)) + (define (array-map proc array . more) (let ((x (array-copy array))) (apply array-map! x proc array more) -- 2.39.5