]> git.vouivredigital.com Git - vouivre.git/commitdiff
Cleanup and document
authoradmin <admin@vouivredigital.com>
Mon, 27 Nov 2023 18:31:14 +0000 (03:31 +0900)
committeradmin <admin@vouivredigital.com>
Mon, 27 Nov 2023 18:31:14 +0000 (03:31 +0900)
vouivre/autodiff.scm

index 4720c9fa39502de30f632eade127a19ffb2155ef..a6ba7842939909ebffe4cdc1af3c0d72110756f4 100644 (file)
   #:export
   (adot
    amap2
-   contract-arrays
    differentiable-wrapper
-   dot
-   do-times
    ewise1
    ewise2
    extend
    fdiff
    logsumexp
    make-batch
-   make-internal
    maximum
    mean
-   rank-of
    rdiff
    relu
    sum)
@@ -135,23 +130,6 @@ element-wise. All arrays must have the same dimension."
             (array-dimensions x))
            (apply f xs))))
 
-(define (dot x y n)
-  (cond
-   ((and (number? x) (number? y))
-    (* x y))
-   ((and (array? x) (array? y))
-    (contract-arrays x y n))
-   ((and (array? x) (number? y))
-    ((extend *) x y))
-   ((and (number? x) (array? y))
-    ((extend *) x y))
-   (else (error "can't dot because of invalid types or ranks" x y n))))
-
-(define (rank-of x)
-  (if (number? x)
-      0
-      (array-rank x)))
-
 ;;;; differentiation
 
 (define-record-type internal
@@ -737,7 +715,7 @@ adding the result to the destination buffer."
    x))
 
 (define (i:array-ref x . indices)
-  "Differentiable array-ref w.r.t `x'."
+  "Differentiable `array-ref' w.r.t the first argument."
   (apply
    differentiable-wrapper
    (cons
@@ -753,6 +731,7 @@ adding the result to the destination buffer."
    x indices))
 
 (define (i:array-cell-ref x . indices)
+  "Differentiable `array-cell-ref' w.r.t the first argument."
   (apply
    differentiable-wrapper
    (cons
@@ -774,6 +753,7 @@ adding the result to the destination buffer."
    x indices))
 
 (define (make-batch elem . more)
+  "Differentiable function to batch one or more arrays together."
   (let ((batch-size (1+ (length more))))
     (apply
      differentiable-wrapper
@@ -838,6 +818,7 @@ adding the result to the destination buffer."
    x))
 
 (define (adot x y n)
+  "Differentiable array dot product."
   (differentiable-wrapper
    (list
     (lambda (xs i j n-free-dims-y n-bound-dims)
@@ -866,6 +847,8 @@ adding the result to the destination buffer."
    x y n))
 
 (define (amap2 f x y)
+  "Differentiable functional mapping on corresponding rows of two arrays with
+rank > 0."
   (apply make-batch
         (list-tabulate (car (dims-of (unwrap-fwd x)))
                        (lambda (b)
@@ -873,8 +856,10 @@ adding the result to the destination buffer."
                             (i:array-cell-ref y b))))))
 
 (define (relu x)
+  "Differentiable rectified linear unit."
   (i:max 0 x))
 
 (define (logsumexp x)
+  "Differentiable RealSoftMax using the log-sum-exp trick."
   (let ((c (maximum x)))
     (i:+ c (i:log (sum (i:exp (i:- x c)))))))