From: admin Date: Sun, 5 Nov 2023 13:01:41 +0000 (+0900) Subject: Test gradient of `amap2' on non-internal inputs but internal result X-Git-Tag: v0.2.0~21 X-Git-Url: https://git.vouivredigital.com/?a=commitdiff_plain;h=20bd990137fdd4af1a7c7818850c43b7bd9f8c75;p=vouivre.git Test gradient of `amap2' on non-internal inputs but internal result --- diff --git a/grad-tests.scm b/grad-tests.scm index 4fb7a24..fd26c0a 100644 --- a/grad-tests.scm +++ b/grad-tests.scm @@ -16,6 +16,7 @@ random-array-shape random-func1 random-func2 + random-func2-rank&dims>0 random-input random-list-element random-non-empty-array @@ -272,6 +273,14 @@ and, when it's an array, at the given index." (with-generators (random-func2-rank&dims>0 gx gy) ~ (apply-grad-amap2 ngrad 1) (apply-grad-amap2 v:grad 1)))) +(test-assert + (~ ((ngrad v:amap2 1) v:* #(1 2 3) #(10 20 30)) + ((v:grad v:amap2 1) v:* #(1 2 3) #(10 20 30)))) +(test-assert + (let ((x #(10 20 30)) + (y #(10 20 30))) + (~ ((ngrad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3)) + ((v:grad (lambda (a) (v:amap2 (lambda (x y) (v:* a a)) x y))) #(1 2 3))))) ;; `v:adot' (define (random-shared-contractible) @@ -302,6 +311,13 @@ according to the number generated by the third one." ~ (lambda (a b n) ((ngrad v:adot 0) a b n)) (lambda (a b n) ((v:grad v:adot 0) a b n))))) +(test-assert + (receive (gx gy gz) (random-shared-contractible) + (with-generators + (gx gy gz) + ~ + (lambda (a b n) ((ngrad v:adot 1) a b n)) + (lambda (a b n) ((v:grad v:adot 1) a b n))))) ;; chain rule (test-assert