]> git.vouivredigital.com Git - vouivre.git/commitdiff
Change generic for typed and unspecified arrays
authoradmin <admin@vouivredigital.com>
Mon, 6 Nov 2023 02:16:45 +0000 (11:16 +0900)
committeradmin <admin@vouivredigital.com>
Mon, 6 Nov 2023 02:16:45 +0000 (11:16 +0900)
grad-tests.scm
grad.scm
misc.scm

index fd26c0a6fe0592dad18cd339b96f9a0038199901..77b3ff0afa1f15c3370a83833ddcc8e11a039d58 100644 (file)
@@ -59,9 +59,9 @@
       (random-array-shape)))
 
 (define* (random-array #:optional shape)
-  (let ((a (apply make-array 0 (or shape (random-array-shape)))))
-    (array-map! a random:uniform)
-    a))
+  (apply produce-typed-array
+        (lambda _ (random:uniform))
+        v:*atype* (or shape (random-array-shape))))
 
 (define (random-non-empty-array)
   "Random array of at least one element."
       x y)
      #t))))
 
-(define* (~ x y #:optional (error 1e-4))
+(define* (~ x y #:optional (error 5e-2))
   (cond
    ((and (number? x) (number? y))
     (n~ x y error))
     (a~ x y error))
    (else #f)))
 
-(define* (ngrad f #:optional (axis 0) (step 1e-6))
+(define* (ngrad f #:optional (axis 0) (step 1e-4))
   "Gradient using a numerical centered difference approximation."
   (define (axis-add xs dh . indices)
     "Add `dh' to the number or array at the given `axis' of `xs',
@@ -175,11 +175,12 @@ and, when it's an array, at the given index."
        ((and (number? fxs)
             (array? x))
        (apply
-        produce-array
+        produce-typed-array
         (lambda indices
           (/ (- (apply f (apply axis-add xs step indices))
                 (apply f (apply axis-add xs (- step) indices)))
              (* 2 step)))
+        v:*atype*
         (array-dimensions x)))
        ((and (array? fxs)
             (number? x))
@@ -191,8 +192,7 @@ and, when it's an array, at the given index."
        ((and (array? fxs)
             (array? x))
        (let ((a (apply
-                 make-array
-                 0
+                 make-typed-array v:*atype* *unspecified*
                  (append (array-dimensions fxs)
                          (array-dimensions x)))))
          (for-indices-in-range
index 81484a08400dc38b9e59884fa6adee3e19dadd0f..315f4c0e61834ec8592abf26830e1036dbc44c63 100644 (file)
--- a/grad.scm
+++ b/grad.scm
@@ -4,7 +4,8 @@
   #:use-module (srfi srfi-9)
   #:use-module (vouivre misc)
   #:export
-  (adot
+  (*atype*
+   adot
    amap2
    differentiable-wrapper
    dot
@@ -39,7 +40,7 @@
          (nb-dims-a (array-rank a))
          (nb-dims-b (array-rank b))
          (nb-fix-dims-a (- nb-dims-a n))
-         (r (apply make-array 0 dims)))
+         (r (apply make-typed-array *atype* *unspecified* dims)))
     (for-indices-in-range
       (lambda r-indices
         (apply
@@ -80,9 +81,10 @@ element-wise. All arrays must have the same dimension."
   (lambda xs
     (if-let (x (find array? xs))
            (apply
-            produce-array
+            produce-typed-array
             (lambda is
               (apply-elemwise f is xs))
+            *atype*
             (array-dimensions x))
            (apply f xs))))
 
@@ -103,23 +105,23 @@ element-wise. All arrays must have the same dimension."
       0
       (array-rank x)))
 
-(define (zeros-out:in x y)
-  "Zeros in the shape of [x]:[y]."
-  (cond
-   ((and (number? x)
-        (number? y))
-    0)
-   ((and (array? x)
-        (array? y))
-    (apply make-array 0 (append (array-dimensions x)
-                               (array-dimensions y))))
-   ((and (number? x)
-        (array? y))
-    (apply make-array 0 (array-dimensions y)))
-   ((and (array? x)
-        (number? y))
-    (apply make-array 0 (array-dimensions x)))
-   (else (error "undefined." x y))))
+;; (define (zeros-out:in x y)
+;;   "Zeros in the shape of [x]:[y]."
+;;   (cond
+;;    ((and (number? x)
+;;      (number? y))
+;;     0)
+;;    ((and (array? x)
+;;      (array? y))
+;;     (apply make-array 0 (append (array-dimensions x)
+;;                             (array-dimensions y))))
+;;    ((and (number? x)
+;;      (array? y))
+;;     (apply make-array 0 (array-dimensions y)))
+;;    ((and (array? x)
+;;      (number? y))
+;;     (apply make-array 0 (array-dimensions x)))
+;;    (else (error "undefined." x y))))
 
 ;;;; differentiation
 
@@ -129,6 +131,7 @@ element-wise. All arrays must have the same dimension."
   (forward internal-forward)
   (jacobian internal-jacobian))
 
+(define *atype* 'f32)
 (define *grad* (make-parameter #f))
 
 (define* (grad f #:optional (axis 0))
@@ -148,15 +151,16 @@ element-wise. All arrays must have the same dimension."
                ((and (number? fx) (number? y))
                 (gfx '()))
                ((and (number? fx) (array? y))
-                (apply produce-array
+                (apply produce-typed-array
                        (lambda js (gfx js))
+                       *atype*
                        (array-dimensions y)))
                ((and (array? fx) (number? y))
                 (gfx '()))
                (else
                 (let* ((dims-out (array-dimensions fx))
                        (dims-in (array-dimensions y))
-                       (a (apply make-array 0 (append dims-out dims-in))))
+                       (a (apply make-typed-array *atype* *unspecified* (append dims-out dims-in))))
                   (for-indices-in-range
                    (lambda js
                      (let ((out (gfx js)))
@@ -171,8 +175,7 @@ element-wise. All arrays must have the same dimension."
                         dims-out)))
                    (list-zeros (array-rank y))
                    dims-in)
-                  a)
-                )))))
+                  a))))))
        (parameterize (((@@ (vouivre grad) *grad*) y))
         (apply f wrapped-xs))))))
 
@@ -223,10 +226,11 @@ element-wise. All arrays must have the same dimension."
                               (apply jacobian-generator
                                      (append naked-inputs js))
                               (apply
-                               produce-array
+                               produce-typed-array
                                (lambda is
                                  (apply jacobian-generator
                                         (append naked-inputs is js)))
+                               *atype*
                                (array-dimensions output)))
                           (let ((b ((internal-jacobian input) js))
                                 (fwd (internal-forward input)))
@@ -238,39 +242,43 @@ element-wise. All arrays must have the same dimension."
                               (array-ref
                                (contract-arrays
                                 (apply
-                                 produce-array
+                                 produce-typed-array
                                  (lambda ks
                                    (apply jacobian-generator
                                           (append naked-inputs ks)))
+                                 *atype*
                                  (array-dimensions fwd))
                                 b (array-rank fwd))))
                              ((and (array? output) (number? fwd))
                               (apply
-                               produce-array
+                               produce-typed-array
                                (lambda is
                                  ((extend *)
                                   (apply jacobian-generator
                                          (append naked-inputs is))
                                   b))
+                               *atype*
                                (array-dimensions output)))
                              (else
                               (apply
-                               produce-array
+                               produce-typed-array
                                (lambda is
                                  (array-ref
                                   (contract-arrays
                                    (apply
-                                    produce-array
+                                    produce-typed-array
                                     (lambda ks
                                       (apply jacobian-generator
                                              (append naked-inputs is ks)))
+                                    *atype*
                                     (array-dimensions fwd))
                                    b (array-rank fwd))))
+                               *atype*
                                (array-dimensions output)))))))))
              #f jacobian-generators inputs)
             (if (number? output)
                 0
-                (apply make-array 0 (array-dimensions output)))))))))
+                (apply make-typed-array *atype* 0 (array-dimensions output)))))))))
 
 (define (ewise1 f)
   (lambda (x . indices)
@@ -374,47 +382,41 @@ element-wise. All arrays must have the same dimension."
        (/ sum count)))
    x))
 
-;;        ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30))
-;;        (let ((x #2((1 2) (3 4) (5 6))) (y #2((1 2) (3 4) (5 6)))) ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)))
 (define (amap2 f x y)
   (define (unbox-with proc x)
     (ifn (internal? x)
         x
         (proc x)))
-(define (dims-of x)
+  (define (dims-of x)
     (if (number? x)
        '()
        (array-dimensions x)))
   (define (boxed-ref x i)
-    (ifn (internal? x)
-        (array-cell-ref x i)
-        (make-internal (array-cell-ref (internal-forward x)
-                                       i)
-                       (if (internal-jacobian x)
-                           (lambda (js)
-                             (array-cell-ref
-                              ((internal-jacobian x)
-                               js)
-                              i))
-                           (lambda (js)
-                             (let* ((x (internal-forward x))
-                                    (xi (array-cell-ref x i)))
-                               (if (number? xi)
-                                   (if (= i (car js))
-                                       1
-                                       0)
-                                   (let ((a (apply make-array 0 (dims-of xi))))
-                                     (for-indices-in-range
-                                      (lambda out
-                                        (when (and (= i (car js))
-                                                   (equal? out (cdr js)))
-                                          (apply
-                                           array-set!
-                                           a 1
-                                           out)))
-                                      (list-zeros (rank-of xi))
-                                      (dims-of xi))
-                                     a))))))))
+    (let ((xi (array-cell-ref (unbox-with internal-forward x)
+                             i)))
+      (ifn (internal? x)
+          xi
+          (make-internal
+           xi
+           (if (internal-jacobian x)
+               (lambda (js)
+                 (array-cell-ref
+                  ((internal-jacobian x)
+                   js)
+                  i))
+               (lambda (js)
+                 (if (number? xi)
+                     (if (= i (car js))
+                         1
+                         0)
+                     (apply
+                      produce-typed-array
+                      (lambda indices
+                        (if (and (= i (car js))
+                                 (equal? indices (cdr js)))
+                            1
+                            0))
+                      *atype* (array-dimensions xi)))))))))
   (define (boxed-fi f i x y)
     (f (boxed-ref x i)
        (boxed-ref y i)))
@@ -423,14 +425,17 @@ element-wise. All arrays must have the same dimension."
       (let ((bs (first (array-dimensions (unbox-with internal-forward x))))
            (f0 (boxed-fi f 0 x y)))
        (ifn (internal? f0)
-            (let ((fwd (apply make-array 0 bs (dims-of f0))))
+            (let ((fwd (apply make-typed-array
+                              *atype* *unspecified* bs (dims-of f0))))
               (for-indices-in-range
                (lambda (i)
                  (array-cell-set! fwd (boxed-fi f i x y) i))
                '(0) (list bs))
               fwd)
-            (let ((jac (make-array 0 bs))
-                  (fwd (apply make-array 0 bs (dims-of (internal-forward f0)))))
+            (let ((jac (make-array *unspecified* bs))
+                  (fwd (apply make-typed-array
+                              *atype* *unspecified*
+                              bs (dims-of (internal-forward f0)))))
               (for-indices-in-range
                (lambda (i)
                  (let* ((fi (boxed-fi f i x y))
@@ -441,7 +446,9 @@ element-wise. All arrays must have the same dimension."
               (make-internal
                fwd
                (lambda (js)
-                 (let ((a (apply make-array 0 bs (dims-of (internal-forward f0)))))
+                 (let ((a (apply make-typed-array
+                                 *atype* *unspecified*
+                                 bs (dims-of (internal-forward f0)))))
                    (for-indices-in-range
                     (lambda (batch)
                       (array-cell-set!
index 72c991be59203fc458a82f03c759575c668c3111..0437d7ca425cf40c092aec21d74a66d73bb31653 100644 (file)
--- a/misc.scm
+++ b/misc.scm
@@ -10,7 +10,8 @@
    ifn
    list-zeros
    map-indexed
-   produce-array))
+   produce-array
+   produce-typed-array))
 
 (define (flip f)
   "Returns a procedure behaving as `f', but with arguments taken in reverse
@@ -57,11 +58,14 @@ order."
 
 ;;;; array utilities
 
-(define (produce-array f . dims)
-  (let ((a (apply make-array 0 dims)))
+(define (produce-typed-array f type . dims)
+  (let ((a (apply make-typed-array type *unspecified* dims)))
     (array-index-map! a f)
     a))
 
+(define (produce-array f . dims)
+  (apply produce-typed-array f #t dims))
+
 (define (array-map proc array . more)
   (let ((x (array-copy array)))
     (apply array-map! x proc array more)