From 5295b195318d8a8b700d11903ffb39fe44f3d455 Mon Sep 17 00:00:00 2001 From: admin Date: Mon, 20 Nov 2023 17:00:56 +0900 Subject: [PATCH] Catch offending generated input(s) on error --- grad-tests.scm | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/grad-tests.scm b/grad-tests.scm index 77b3ff0..bdf8fd4 100644 --- a/grad-tests.scm +++ b/grad-tests.scm @@ -23,11 +23,12 @@ random-shape random-shared random-shared-array-rank&dims>0 + random-shared-contractible with-generators ~)) (define f1s (list v:exp v:identity)) -(define f2s (list v:* v:- v:max)) +(define f2s (list v:+ v:- v:* v:max v:min)) (define-syntax-rule (with-generators (g1 g2 ...) equal expected given) (let ((times 100) @@ -39,10 +40,15 @@ (do ((i 0 (1+ i))) ((= i times) #t) (let ((zs (map-in-order (lambda (g) (g)) generators))) - (let ((r1 (apply fx zs)) - (r2 (apply fy zs))) - (unless (equal r1 r2) - (break #f zs r1 r2))))))))) + (with-exception-handler + (lambda (e) + (break #f zs)) + (lambda () + (let ((r1 (apply fx zs)) + (r2 (apply fy zs))) + (unless (equal r1 r2) + (break #f zs r1 r2)))) + #:unwind? #t))))))) (define (lambda-const-call f . consts) (lambda _ @@ -135,7 +141,7 @@ x y) #t)))) -(define* (~ x y #:optional (error 5e-2)) +(define* (~ x y #:optional (error 1e-4)) (cond ((and (number? x) (number? y)) (n~ x y error)) @@ -143,7 +149,7 @@ (a~ x y error)) (else #f))) -(define* (ngrad f #:optional (axis 0) (step 1e-4)) +(define* (ngrad f #:optional (axis 0) (step 1e-6)) "Gradient using a numerical centered difference approximation." (define (axis-add xs dh . indices) "Add `dh' to the number or array at the given `axis' of `xs', -- 2.39.5