]> git.vouivredigital.com Git - vouivre.git/commitdiff
Test gradient of `amap2' on non-internal inputs but internal result
authoradmin <admin@vouivredigital.com>
Sun, 5 Nov 2023 13:01:41 +0000 (22:01 +0900)
committeradmin <admin@vouivredigital.com>
Sun, 5 Nov 2023 13:01:41 +0000 (22:01 +0900)
grad-tests.scm

index 4fb7a24309f7e0654208d7bc692a4e0f8fd4f65d..fd26c0a6fe0592dad18cd339b96f9a0038199901 100644 (file)
@@ -16,6 +16,7 @@
    random-array-shape
    random-func1
    random-func2
+   random-func2-rank&dims>0
    random-input
    random-list-element
    random-non-empty-array
@@ -272,6 +273,14 @@ and, when it's an array, at the given index."
       (with-generators
        (random-func2-rank&dims>0 gx gy)
        ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1))))
+(test-assert
+    (~ ((ngrad v:amap2 1) v:* #(1 2 3) #(10 20 30))
+       ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30))))
+(test-assert
+    (let ((x #(10 20 30))
+         (y #(10 20 30)))
+      (~ ((ngrad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))
+        ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)))))
 
 ;; `v:adot'
 (define (random-shared-contractible)
@@ -302,6 +311,13 @@ according to the number generated by the third one."
        ~
        (lambda (a b n) ((ngrad v:adot 0) a b n))
        (lambda (a b n) ((v:grad v:adot 0) a b n)))))
+(test-assert
+    (receive (gx gy gz) (random-shared-contractible)
+      (with-generators
+       (gx gy gz)
+       ~
+       (lambda (a b n) ((ngrad v:adot 1) a b n))
+       (lambda (a b n) ((v:grad v:adot 1) a b n)))))
 
 ;; chain rule
 (test-assert