]> git.vouivredigital.com Git - vouivre.git/commitdiff
Add more differentiable functions
authoradmin <admin@vouivredigital.com>
Thu, 9 Nov 2023 03:46:12 +0000 (12:46 +0900)
committeradmin <admin@vouivredigital.com>
Thu, 9 Nov 2023 03:46:12 +0000 (12:46 +0900)
grad.scm

index 5e1ea7fec84e4a992e5c10059ad0b30793e6e226..46b1a999b20c747417753bb807292cc774c930f6 100644 (file)
--- a/grad.scm
+++ b/grad.scm
    extend
    grad
    internal-jacobian
+   maximum
    mean
-   rank-of)
+   rank-of
+   sum)
   #:replace
-  ((i:* . *)
-   (i:- . -)
+  ((i:sqrt . sqrt)
    (i:exp . exp)
-   (i:fold . fold)
+   (i:expt . expt)
+   (i:log . log)
+   (i:sin . sin)
+   (i:cos . cos)
+   (i:tan . tan)
+   (i:+ . +)
+   (i:- . -)
+   (i:* . *)
+   (i:/ . /)
+   (i:max . max)
+   (i:min . min)
+   (i:abs . abs)
    (i:identity . identity)
-   (i:max . max)))
+   (i:array-ref . array-ref)
+   )
+  #:re-export
+  (fold
+   reduce))
 
 ;;;; array utilities
 
@@ -105,24 +121,6 @@ element-wise. All arrays must have the same dimension."
       0
       (array-rank x)))
 
-;; (define (zeros-out:in x y)
-;;   "Zeros in the shape of [x]:[y]."
-;;   (cond
-;;    ((and (number? x)
-;;      (number? y))
-;;     0)
-;;    ((and (array? x)
-;;      (array? y))
-;;     (apply make-array 0 (append (array-dimensions x)
-;;                             (array-dimensions y))))
-;;    ((and (number? x)
-;;      (array? y))
-;;     (apply make-array 0 (array-dimensions y)))
-;;    ((and (array? x)
-;;      (number? y))
-;;     (apply make-array 0 (array-dimensions x)))
-;;    (else (error "undefined." x y))))
-
 ;;;; differentiation
 
 (define-record-type internal
@@ -131,7 +129,8 @@ element-wise. All arrays must have the same dimension."
   (forward internal-forward)
   (jacobian internal-jacobian))
 
-(define *atype* 'f32)
+;; (define *atype* 'f64)
+(define *atype* #t)
 (define *grad* (make-parameter #f))
 
 (define* (grad f #:optional (axis 0))
@@ -223,20 +222,18 @@ element-wise. All arrays must have the same dimension."
                     ((ifn prev identity (lambda (x) ((extend +) prev x)))
                      (ifn (internal-jacobian input)
                           (if (number? output)
-                              (apply jacobian-generator
-                                     (append naked-inputs js))
+                              (jacobian-generator naked-inputs js)
                               (apply
                                produce-typed-array
                                (lambda is
-                                 (apply jacobian-generator
-                                        (append naked-inputs is js)))
+                                 (jacobian-generator naked-inputs (append is js)))
                                *atype*
                                (array-dimensions output)))
                           (let ((b ((internal-jacobian input) js))
                                 (fwd (internal-forward input)))
                             (cond
                              ((and (number? output) (number? fwd))
-                              (* (apply jacobian-generator naked-inputs)
+                              (* (jacobian-generator naked-inputs '())
                                  b))
                              ((and (number? output) (array? fwd))
                               (array-ref
@@ -244,8 +241,7 @@ element-wise. All arrays must have the same dimension."
                                 (apply
                                  produce-typed-array
                                  (lambda ks
-                                   (apply jacobian-generator
-                                          (append naked-inputs ks)))
+                                   (jacobian-generator naked-inputs ks))
                                  *atype*
                                  (array-dimensions fwd))
                                 b (array-rank fwd))))
@@ -254,8 +250,7 @@ element-wise. All arrays must have the same dimension."
                                produce-typed-array
                                (lambda is
                                  ((extend *)
-                                  (apply jacobian-generator
-                                         (append naked-inputs is))
+                                  (jacobian-generator naked-inputs is)
                                   b))
                                *atype*
                                (array-dimensions output)))
@@ -268,8 +263,7 @@ element-wise. All arrays must have the same dimension."
                                    (apply
                                     produce-typed-array
                                     (lambda ks
-                                      (apply jacobian-generator
-                                             (append naked-inputs is ks)))
+                                      (jacobian-generator naked-inputs (append is ks)))
                                     *atype*
                                     (array-dimensions fwd))
                                    b (array-rank fwd))))
@@ -281,40 +275,43 @@ element-wise. All arrays must have the same dimension."
                 (apply make-typed-array *atype* 0 (array-dimensions output)))))))))
 
 (define (ewise1 f)
-  (lambda (x . indices)
-    (if (number? x)
-       (f x)
-       (receive (is js) (split-at indices (array-rank x))
-         (if (equal? is js)
-             (f (apply array-ref x is))
-             0)))))
+  (lambda (xs indices)
+    (let ((x (car xs)))
+      (if (number? x)
+         (f x)
+         (receive (is js) (split-at indices (array-rank x))
+           (if (equal? is js)
+               (f (apply array-ref x is))
+               0))))))
 
 (define (ewise2 proc axis)
-  (lambda (x y . indices)
-    (cond
-     ((and (number? x) (number? y))
-      (proc x y))
-     ((and (number? x) (array? y))
-      (if (= axis 0)
-         (proc x (apply array-ref y indices))
-         (receive (is js) (split-at indices (array-rank y))
-           (ifn (equal? is js)
-                0
-                (proc x (apply array-ref y is))))))
-     ((and (array? x) (number? y))
-      (if (= axis 1)
-         (proc (apply array-ref x indices) y)
-         (receive (is js) (split-at indices (array-rank x))
-           (ifn (equal? is js)
-                0
-                (proc (apply array-ref x is)
-                      y)))))
-     (else
-      (receive (is js) (split-at indices (array-rank x))
-       (ifn (equal? is js)
-            0
-            (proc (apply array-ref x is)
-                  (apply array-ref y is))))))))
+  (lambda (xs indices)
+    (let ((x (car xs))
+         (y (cadr xs)))
+      (cond
+       ((and (number? x) (number? y))
+       (proc x y))
+       ((and (number? x) (array? y))
+       (if (= axis 0)
+           (proc x (apply array-ref y indices))
+           (receive (is js) (split-at indices (array-rank y))
+             (ifn (equal? is js)
+                  0
+                  (proc x (apply array-ref y is))))))
+       ((and (array? x) (number? y))
+       (if (= axis 1)
+           (proc (apply array-ref x indices) y)
+           (receive (is js) (split-at indices (array-rank x))
+             (ifn (equal? is js)
+                  0
+                  (proc (apply array-ref x is)
+                        y)))))
+       (else
+       (receive (is js) (split-at indices (array-rank x))
+         (ifn (equal? is js)
+              0
+              (proc (apply array-ref x is)
+                    (apply array-ref y is)))))))))
 
 (define (i:identity x)
   "Differentiable identity."
@@ -323,6 +320,13 @@ element-wise. All arrays must have the same dimension."
    identity
    x))
 
+(define (i:sqrt x)
+  "Differentiable square root."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ 1 2 (sqrt x)))))
+   (extend sqrt)
+   x))
+
 (define (i:exp x)
   "Differentiable exponential."
   (differentiable-wrapper
@@ -330,13 +334,49 @@ element-wise. All arrays must have the same dimension."
    (extend exp)
    x))
 
-(define (i:* x y)
-  "Differentiable element-wise multiplication."
+(define (i:expt x y)
+  "Differentiable power."
+  (differentiable-wrapper
+   (list (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0)
+        (ewise2 (lambda (x y) (* (expt x y) (log x))) 1))
+   (extend expt)
+   x y))
+
+(define (i:log x)
+  "Differentiable logarithm."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ x))))
+   (extend log)
+   x))
+
+(define (i:sin x)
+  "Differentiable sine."
+  (differentiable-wrapper
+   (list (ewise1 cons))
+   (extend sin)
+   x))
+
+(define (i:cos x)
+  "Differentiable cosine."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (- (sin x)))))
+   (extend cos)
+   x))
+
+(define (i:tan x)
+  "Differentiable tangent."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x) (/ (expt (cos x) 2)))))
+   (extend tan)
+   x))
+
+(define (i:+ x y)
+  "Differentiable element-wise addition."
   (differentiable-wrapper
    (list
-    (ewise2 (lambda (x y) y) 0)
-    (ewise2 (lambda (x y) x) 1))
-   (extend *)
+    (ewise2 (lambda _ +1) 0)
+    (ewise2 (lambda _ +1) 1))
+   (extend +)
    x y))
 
 (define (i:- x y)
@@ -348,6 +388,24 @@ element-wise. All arrays must have the same dimension."
    (extend -)
    x y))
 
+(define (i:* x y)
+  "Differentiable element-wise multiplication."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda (x y) y) 0)
+    (ewise2 (lambda (x y) x) 1))
+   (extend *)
+   x y))
+
+(define (i:/ x y)
+  "Differentiable element-wise division."
+  (differentiable-wrapper
+   (list
+    (ewise2 (lambda (x y) (/ y)) 0)
+    (ewise2 (lambda (x y) (- (/ x y y))) 1))
+   (extend /)
+   x y))
+
 (define (i:max x y)
   "Differentiable element-wise maximum."
   (define (dmax x y)
@@ -365,12 +423,42 @@ element-wise. All arrays must have the same dimension."
    (extend max)
    x y))
 
+(define (i:min x y)
+  "Differentiable element-wise minimum."
+  (define (dmin x y)
+    (cond
+     ((< x y)
+      1)
+     ((= x y)
+      1/2)
+     (else
+      0)))
+  (differentiable-wrapper
+   (list
+    (ewise2 dmin 0)
+    (ewise2 (flip dmin) 1))
+   (extend min)
+   x y))
+
+(define (i:abs x)
+  "Differentiable absolute."
+  (differentiable-wrapper
+   (list (ewise1 (lambda (x)
+                  (cond ((> x 0)
+                         +1)
+                        ((= x 0)
+                         1/2)
+                        ((< x 0)
+                         -1)))))
+   (extend abs)
+   x))
+
 (define (mean x)
   "Differentiable mean on arrays."
   (differentiable-wrapper
    (list
-    (lambda (x . indices)
-      (/ 1 (apply * (array-dimensions x)))))
+    (lambda (xs indices)
+      (/ 1 (apply * (array-dimensions (car xs))))))
    (lambda (x)
      (let ((sum 0)
           (count 0))
@@ -382,6 +470,70 @@ element-wise. All arrays must have the same dimension."
        (/ sum count)))
    x))
 
+(define (i:array-ref x . indices)
+  "Differentiable array-ref w.r.t `x'."
+  (apply
+   differentiable-wrapper
+   (cons
+    (lambda (xs js)
+      (if (equal? js indices)
+         1
+         0))
+    (list-tabulate (length indices) not))
+   array-ref
+   x indices))
+
+(define (for-flat-index proc a)
+  (let* ((b (array-contents a))
+        (n (array-length b)))
+    (let rec ((i 0))
+      (unless (= i n)
+       (proc b i)
+       (rec (1+ i))))))
+
+(define (maximum x)
+  "Differentiable maximum on arrays."
+  (differentiable-wrapper
+   (list
+    (lambda (xs indices)
+      (let ((x (car xs))
+           (m (- (inf)))
+           (i 'TBD))
+       (for-indices-in-range
+        (lambda indices
+          (let ((xi (apply array-ref x indices)))
+            (when (< m xi)
+              (set! m xi)
+              (set! i indices))))
+        (list-zeros (array-rank x))
+        (array-dimensions x))
+       (if (equal? i indices)
+           1
+           0))))
+   (lambda (x)
+     (let ((m (- (inf))))
+       (array-for-each
+       (lambda (x)
+         (set! m (max m x)))
+       x)
+       m))
+   x))
+
+(define (sum x)
+  "Differentiable sum on arrays."
+  (differentiable-wrapper
+   (list
+    (lambda (xs indices)
+      1))
+   (lambda (x)
+     (let ((sum 0))
+       (array-for-each
+       (lambda (x)
+         (set! sum (+ sum x)))
+       x)
+       sum))
+   x))
+
 (define (amap2 f x y)
   (define (unbox-with proc x)
     (ifn (internal? x)
@@ -416,7 +568,8 @@ element-wise. All arrays must have the same dimension."
                                  (equal? indices (cdr js)))
                             1
                             0))
-                      *atype* (array-dimensions xi)))))))))
+                      (array-type xi)
+                      (array-dimensions xi)))))))))
   (define (boxed-fi f i x y)
     (f (boxed-ref x i)
        (boxed-ref y i)))
@@ -426,6 +579,7 @@ element-wise. All arrays must have the same dimension."
            (f0 (boxed-fi f 0 x y)))
        (ifn (internal? f0)
             (let ((fwd (apply make-typed-array
+                              ;; TODO: use the correct type based on f0.
                               *atype* *unspecified* bs (dims-of f0))))
               (for-indices-in-range
                (lambda (i)
@@ -434,6 +588,7 @@ element-wise. All arrays must have the same dimension."
               fwd)
             (let ((jac (make-array *unspecified* bs))
                   (fwd (apply make-typed-array
+                              ;; TODO: use the correct type based on f0.
                               *atype* *unspecified*
                               bs (dims-of (internal-forward f0)))))
               (for-indices-in-range
@@ -447,6 +602,7 @@ element-wise. All arrays must have the same dimension."
                fwd
                (lambda (js)
                  (let ((a (apply make-typed-array
+                                 ;; TODO: use the correct type based on f0.
                                  *atype* *unspecified*
                                  bs (dims-of (internal-forward f0)))))
                    (for-indices-in-range
@@ -460,8 +616,11 @@ element-wise. All arrays must have the same dimension."
 (define (adot x y n)
   (differentiable-wrapper
    (list
-    (lambda (x y n . indices)
-      (let* ((free-rank-x (- (array-rank x)
+    (lambda (xs indices)
+      (let* ((x (car xs))
+            (y (cadr xs))
+            (n (caddr xs))
+            (free-rank-x (- (array-rank x)
                             n))
             (free-rank-y (- (array-rank y)
                             n))
@@ -474,8 +633,11 @@ element-wise. All arrays must have the same dimension."
        (ifn (equal? is-free1 is-free2)
             0
             (apply array-ref y (append is-bound js-free)))))
-    (lambda (x y n . indices)
-      (let* ((free-rank-x (- (array-rank x)
+    (lambda (xs indices)
+      (let* ((x (car xs))
+            (y (cadr xs))
+            (n (caddr xs))
+            (free-rank-x (- (array-rank x)
                             n))
             (free-rank-y (- (array-rank y)
                             n))