From 95e75d6a78363b3e2cbd0239d2af5e83cbad4746 Mon Sep 17 00:00:00 2001 From: admin Date: Tue, 28 Nov 2023 03:31:14 +0900 Subject: [PATCH] Cleanup and document --- vouivre/autodiff.scm | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/vouivre/autodiff.scm b/vouivre/autodiff.scm index 4720c9f..a6ba784 100644 --- a/vouivre/autodiff.scm +++ b/vouivre/autodiff.scm @@ -24,20 +24,15 @@ #:export (adot amap2 - contract-arrays differentiable-wrapper - dot - do-times ewise1 ewise2 extend fdiff logsumexp make-batch - make-internal maximum mean - rank-of rdiff relu sum) @@ -135,23 +130,6 @@ element-wise. All arrays must have the same dimension." (array-dimensions x)) (apply f xs)))) -(define (dot x y n) - (cond - ((and (number? x) (number? y)) - (* x y)) - ((and (array? x) (array? y)) - (contract-arrays x y n)) - ((and (array? x) (number? y)) - ((extend *) x y)) - ((and (number? x) (array? y)) - ((extend *) x y)) - (else (error "can't dot because of invalid types or ranks" x y n)))) - -(define (rank-of x) - (if (number? x) - 0 - (array-rank x))) - ;;;; differentiation (define-record-type internal @@ -737,7 +715,7 @@ adding the result to the destination buffer." x)) (define (i:array-ref x . indices) - "Differentiable array-ref w.r.t `x'." + "Differentiable `array-ref' w.r.t the first argument." (apply differentiable-wrapper (cons @@ -753,6 +731,7 @@ adding the result to the destination buffer." x indices)) (define (i:array-cell-ref x . indices) + "Differentiable `array-cell-ref' w.r.t the first argument." (apply differentiable-wrapper (cons @@ -774,6 +753,7 @@ adding the result to the destination buffer." x indices)) (define (make-batch elem . more) + "Differentiable function to batch one or more arrays together." (let ((batch-size (1+ (length more)))) (apply differentiable-wrapper @@ -838,6 +818,7 @@ adding the result to the destination buffer." x)) (define (adot x y n) + "Differentiable array dot product." (differentiable-wrapper (list (lambda (xs i j n-free-dims-y n-bound-dims) @@ -866,6 +847,8 @@ adding the result to the destination buffer." x y n)) (define (amap2 f x y) + "Differentiable functional mapping on corresponding rows of two arrays with +rank > 0." (apply make-batch (list-tabulate (car (dims-of (unwrap-fwd x))) (lambda (b) @@ -873,8 +856,10 @@ adding the result to the destination buffer." (i:array-cell-ref y b)))))) (define (relu x) + "Differentiable rectified linear unit." (i:max 0 x)) (define (logsumexp x) + "Differentiable RealSoftMax using the log-sum-exp trick." (let ((c (maximum x))) (i:+ c (i:log (sum (i:exp (i:- x c))))))) -- 2.39.5