From b620ae3caea865ebe92fa8f883eb636fe4af6609 Mon Sep 17 00:00:00 2001
From: admin
Date: Thu, 9 Nov 2023 12:46:12 +0900
Subject: [PATCH] Add more differentiable functions

Add the differentiable wrappers i:sqrt, i:expt, i:log, i:sin, i:cos,
i:tan, i:+, i:/, i:min, i:abs and i:array-ref, and the array reductions
`maximum' and `sum'.  The i:* procedures are exported through #:replace
under their core names; `fold' and `reduce' are now re-exported instead
of `fold' being replaced by i:fold.

Jacobian generators are now called with two arguments -- the list of
naked inputs and the list of indices -- instead of a single flattened
argument list.  ewise1, ewise2, mean and adot are adapted accordingly.

*atype* changes from 'f32 to #t, and the commented-out zeros-out:in
helper is removed.
---
Review note: the Jacobian generator of i:sin must be `cos'; the patch as
submitted passed `cons', which fails as soon as the derivative of sin is
evaluated because ewise1 calls it with a single argument.

 grad.scm | 320 +++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 241 insertions(+), 79 deletions(-)

diff --git a/grad.scm b/grad.scm
index 5e1ea7f..46b1a99 100644
--- a/grad.scm
+++ b/grad.scm
@@ -12,15 +12,31 @@
    extend
    grad
    internal-jacobian
+   maximum
    mean
-   rank-of)
+   rank-of
+   sum)
   #:replace
-  ((i:* . *)
-   (i:- . -)
+  ((i:sqrt . sqrt)
    (i:exp . exp)
-   (i:fold . fold)
+   (i:expt . expt)
+   (i:log . log)
+   (i:sin . sin)
+   (i:cos . cos)
+   (i:tan . tan)
+   (i:+ . +)
+   (i:- . -)
+   (i:* . *)
+   (i:/ . /)
+   (i:max . max)
+   (i:min . min)
+   (i:abs . abs)
    (i:identity . identity)
-   (i:max . max)))
+   (i:array-ref . array-ref)
+   )
+  #:re-export
+  (fold
+   reduce))
 
 ;;;; array utilities
 
@@ -105,24 +121,6 @@ element-wise. All arrays must have the same dimension."
       0
       (array-rank x)))
 
-;; (define (zeros-out:in x y)
-;;   "Zeros in the shape of [x]:[y]."
-;;   (cond
-;;    ((and (number? x)
-;;          (number? y))
-;;     0)
-;;    ((and (array? x)
-;;          (array? y))
-;;     (apply make-array 0 (append (array-dimensions x)
-;;                                 (array-dimensions y))))
-;;    ((and (number? x)
-;;          (array? y))
-;;     (apply make-array 0 (array-dimensions y)))
-;;    ((and (array? x)
-;;          (number? y))
-;;     (apply make-array 0 (array-dimensions x)))
-;;    (else (error "undefined." x y))))
-
 ;;;; differentiation
 
 (define-record-type internal
@@ -131,7 +129,8 @@ element-wise. All arrays must have the same dimension."
   (forward internal-forward)
   (jacobian internal-jacobian))
 
-(define *atype* 'f32)
+;; (define *atype* 'f64)
+(define *atype* #t)
 (define *grad* (make-parameter #f))
 
 (define* (grad f #:optional (axis 0))
@@ -223,20 +222,18 @@ element-wise. All arrays must have the same dimension."
               ((ifn prev identity (lambda (x) ((extend +) prev x)))
                (ifn (internal-jacobian input)
                     (if (number? output)
-                        (apply jacobian-generator
-                               (append naked-inputs js))
+                        (jacobian-generator naked-inputs js)
                         (apply
                          produce-typed-array
                          (lambda is
-                           (apply jacobian-generator
-                                  (append naked-inputs is js)))
+                           (jacobian-generator naked-inputs (append is js)))
                          *atype*
                          (array-dimensions output)))
                     (let ((b ((internal-jacobian input) js))
                           (fwd (internal-forward input)))
                       (cond
                        ((and (number? output) (number? fwd))
-                        (* (apply jacobian-generator naked-inputs)
+                        (* (jacobian-generator naked-inputs '())
                            b))
                        ((and (number? output) (array? fwd))
                         (array-ref
@@ -244,8 +241,7 @@ element-wise. All arrays must have the same dimension."
                           (apply
                            produce-typed-array
                            (lambda ks
-                             (apply jacobian-generator
-                                    (append naked-inputs ks)))
+                             (jacobian-generator naked-inputs ks))
                            *atype*
                            (array-dimensions fwd))
                           b (array-rank fwd))))
@@ -254,8 +250,7 @@ element-wise. All arrays must have the same dimension."
                          produce-typed-array
                          (lambda is
                            ((extend *)
-                            (apply jacobian-generator
-                                   (append naked-inputs is))
+                            (jacobian-generator naked-inputs is)
                             b))
                          *atype*
                          (array-dimensions output)))
@@ -268,8 +263,7 @@ element-wise. All arrays must have the same dimension."
                              (apply
                               produce-typed-array
                               (lambda ks
-                                (apply jacobian-generator
-                                       (append naked-inputs is ks)))
+                                (jacobian-generator naked-inputs (append is ks)))
                               *atype*
                               (array-dimensions fwd))
                              b (array-rank fwd))))
@@ -281,40 +275,43 @@ element-wise. All arrays must have the same dimension."
                     (apply make-typed-array *atype* 0 (array-dimensions output)))))))))
 
 (define (ewise1 f)
-  (lambda (x . indices)
-    (if (number? x)
-        (f x)
-        (receive (is js) (split-at indices (array-rank x))
-          (if (equal? is js)
-              (f (apply array-ref x is))
-              0)))))
+  (lambda (xs indices)
+    (let ((x (car xs)))
+      (if (number? x)
+          (f x)
+          (receive (is js) (split-at indices (array-rank x))
+            (if (equal? is js)
+                (f (apply array-ref x is))
+                0))))))
 
 (define (ewise2 proc axis)
-  (lambda (x y . indices)
-    (cond
-     ((and (number? x) (number? y))
-      (proc x y))
-     ((and (number? x) (array? y))
-      (if (= axis 0)
-          (proc x (apply array-ref y indices))
-          (receive (is js) (split-at indices (array-rank y))
-            (ifn (equal? is js)
-                 0
-                 (proc x (apply array-ref y is))))))
-     ((and (array? x) (number? y))
-      (if (= axis 1)
-          (proc (apply array-ref x indices) y)
-          (receive (is js) (split-at indices (array-rank x))
-            (ifn (equal? is js)
-                 0
-                 (proc (apply array-ref x is)
-                       y)))))
-     (else
-      (receive (is js) (split-at indices (array-rank x))
-        (ifn (equal? is js)
-             0
-             (proc (apply array-ref x is)
-                   (apply array-ref y is))))))))
+  (lambda (xs indices)
+    (let ((x (car xs))
+          (y (cadr xs)))
+      (cond
+       ((and (number? x) (number? y))
+        (proc x y))
+       ((and (number? x) (array? y))
+        (if (= axis 0)
+            (proc x (apply array-ref y indices))
+            (receive (is js) (split-at indices (array-rank y))
+              (ifn (equal? is js)
+                   0
+                   (proc x (apply array-ref y is))))))
+       ((and (array? x) (number? y))
+        (if (= axis 1)
+            (proc (apply array-ref x indices) y)
+            (receive (is js) (split-at indices (array-rank x))
+              (ifn (equal? is js)
+                   0
+                   (proc (apply array-ref x is)
+                         y)))))
+       (else
+        (receive (is js) (split-at indices (array-rank x))
+          (ifn (equal? is js)
+               0
+               (proc (apply array-ref x is)
+                     (apply array-ref y is)))))))))
 
 (define (i:identity x)
   "Differentiable identity."
@@ -323,6 +320,13 @@ element-wise. All arrays must have the same dimension."
    identity
    x))
 
+(define (i:sqrt x)
+  "Differentiable square root."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ 1 2 (sqrt x)))))
+   (extend sqrt)
+   x))
+
 (define (i:exp x)
   "Differentiable exponential."
   (differentiable-wrapper
@@ -330,13 +334,49 @@ element-wise. All arrays must have the same dimension."
    (extend exp)
    x))
 
-(define (i:* x y)
-  "Differentiable element-wise multiplication."
+(define (i:expt x y)
+  "Differentiable power."
+  (differentiable-wrapper
+   (list (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0)
+         (ewise2 (lambda (x y) (* (expt x y) (log x))) 1))
+   (extend expt)
+   x y))
+
+(define (i:log x)
+  "Differentiable logarithm."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ x))))
+   (extend log)
+   x))
+
+(define (i:sin x)
+  "Differentiable sine."
+  (differentiable-wrapper
+   (list (ewise1 cos))
+   (extend sin)
+   x))
+
+(define (i:cos x)
+  "Differentiable cosine."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (- (sin x)))))
+   (extend cos)
+   x))
+
+(define (i:tan x)
+  "Differentiable tangent."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ (expt (cos x) 2)))))
+   (extend tan)
+   x))
+
+(define (i:+ x y)
+  "Differentiable element-wise addition."
   (differentiable-wrapper
    (list
-    (ewise2 (lambda (x y) y) 0)
-    (ewise2 (lambda (x y) x) 1))
-   (extend *)
+    (ewise2 (lambda _ +1) 0)
+    (ewise2 (lambda _ +1) 1))
+   (extend +)
    x y))
 
 (define (i:- x y)
@@ -348,6 +388,24 @@ element-wise. All arrays must have the same dimension."
    (extend -)
    x y))
 
+(define (i:* x y)
+  "Differentiable element-wise multiplication."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda (x y) y) 0)
+    (ewise2 (lambda (x y) x) 1))
+   (extend *)
+   x y))
+
+(define (i:/ x y)
+  "Differentiable element-wise division."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda (x y) (/ y)) 0)
+    (ewise2 (lambda (x y) (- (/ x y y))) 1))
+   (extend /)
+   x y))
+
 (define (i:max x y)
   "Differentiable element-wise maximum."
   (define (dmax x y)
@@ -365,12 +423,42 @@ element-wise. All arrays must have the same dimension."
    (extend max)
    x y))
 
+(define (i:min x y)
+  "Differentiable element-wise minimum."
+  (define (dmin x y)
+    (cond
+     ((< x y)
+      1)
+     ((= x y)
+      1/2)
+     (else
+      0)))
+  (differentiable-wrapper
+   (list
+    (ewise2 dmin 0)
+    (ewise2 (flip dmin) 1))
+   (extend min)
+   x y))
+
+(define (i:abs x)
+  "Differentiable absolute."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x)
+                   (cond ((> x 0)
+                          +1)
+                         ((= x 0)
+                          1/2)
+                         ((< x 0)
+                          -1)))))
+   (extend abs)
+   x))
+
 (define (mean x)
   "Differentiable mean on arrays."
   (differentiable-wrapper
    (list
-    (lambda (x . indices)
-      (/ 1 (apply * (array-dimensions x)))))
+    (lambda (xs indices)
+      (/ 1 (apply * (array-dimensions (car xs))))))
    (lambda (x)
      (let ((sum 0)
            (count 0))
@@ -382,6 +470,70 @@ element-wise. All arrays must have the same dimension."
        (/ sum count)))
    x))
 
+(define (i:array-ref x . indices)
+  "Differentiable array-ref w.r.t `x'."
+  (apply
+   differentiable-wrapper
+   (cons
+    (lambda (xs js)
+      (if (equal? js indices)
+          1
+          0))
+    (list-tabulate (length indices) not))
+   array-ref
+   x indices))
+
+(define (for-flat-index proc a)
+  (let* ((b (array-contents a))
+         (n (array-length b)))
+    (let rec ((i 0))
+      (unless (= i n)
+        (proc b i)
+        (rec (1+ i))))))
+
+(define (maximum x)
+  "Differentiable maximum on arrays."
+  (differentiable-wrapper
+   (list
+    (lambda (xs indices)
+      (let ((x (car xs))
+            (m (- (inf)))
+            (i 'TBD))
+        (for-indices-in-range
+         (lambda indices
+           (let ((xi (apply array-ref x indices)))
+             (when (< m xi)
+               (set! m xi)
+               (set! i indices))))
+         (list-zeros (array-rank x))
+         (array-dimensions x))
+        (if (equal? i indices)
+            1
+            0))))
+   (lambda (x)
+     (let ((m (- (inf))))
+       (array-for-each
+        (lambda (x)
+          (set! m (max m x)))
+        x)
+       m))
+   x))
+
+(define (sum x)
+  "Differentiable sum on arrays."
+  (differentiable-wrapper
+   (list
+    (lambda (xs indices)
+      1))
+   (lambda (x)
+     (let ((sum 0))
+       (array-for-each
+        (lambda (x)
+          (set! sum (+ sum x)))
+        x)
+       sum))
+   x))
+
 (define (amap2 f x y)
   (define (unbox-with proc x)
     (ifn (internal? x)
@@ -416,7 +568,8 @@ element-wise. All arrays must have the same dimension."
                           (equal? indices (cdr js)))
                      1
                      0))
-               *atype* (array-dimensions xi)))))))))
+               (array-type xi)
+               (array-dimensions xi)))))))))
 
 (define (boxed-fi f i x y)
   (f (boxed-ref x i) (boxed-ref y i)))
@@ -426,6 +579,7 @@ element-wise. All arrays must have the same dimension."
          (f0 (boxed-fi f 0 x y)))
     (ifn (internal? f0)
          (let ((fwd (apply make-typed-array
+                           ;; TODO: use the correct type based on f0.
                            *atype* *unspecified* bs (dims-of f0))))
            (for-indices-in-range
             (lambda (i)
@@ -434,6 +588,7 @@ element-wise. All arrays must have the same dimension."
            fwd)
          (let ((jac (make-array *unspecified* bs))
                (fwd (apply make-typed-array
+                           ;; TODO: use the correct type based on f0.
                            *atype* *unspecified* bs
                            (dims-of (internal-forward f0)))))
            (for-indices-in-range
@@ -447,6 +602,7 @@ element-wise. All arrays must have the same dimension."
             fwd
             (lambda (js)
               (let ((a (apply make-typed-array
+                              ;; TODO: use the correct type based on f0.
                               *atype* *unspecified* bs
                               (dims-of (internal-forward f0)))))
                 (for-indices-in-range
@@ -460,8 +616,11 @@ element-wise. All arrays must have the same dimension."
 (define (adot x y n)
   (differentiable-wrapper
    (list
-    (lambda (x y n . indices)
-      (let* ((free-rank-x (- (array-rank x)
+    (lambda (xs indices)
+      (let* ((x (car xs))
+             (y (cadr xs))
+             (n (caddr xs))
+             (free-rank-x (- (array-rank x)
                              n))
              (free-rank-y (- (array-rank y)
                              n))
@@ -474,8 +633,11 @@ element-wise. All arrays must have the same dimension."
         (ifn (equal? is-free1 is-free2)
              0
              (apply array-ref y (append is-bound js-free)))))
-    (lambda (x y n . indices)
-      (let* ((free-rank-x (- (array-rank x)
+    (lambda (xs indices)
+      (let* ((x (car xs))
+             (y (cadr xs))
+             (n (caddr xs))
+             (free-rank-x (- (array-rank x)
                              n))
              (free-rank-y (- (array-rank y)
                              n))
-- 
2.39.2