]> git.vouivredigital.com Git - vouivre.git/commitdiff
Catch offending generated input(s) on error
authoradmin <admin@vouivredigital.com>
Mon, 20 Nov 2023 08:00:56 +0000 (17:00 +0900)
committeradmin <admin@vouivredigital.com>
Mon, 20 Nov 2023 08:00:56 +0000 (17:00 +0900)
grad-tests.scm

index 77b3ff0afa1f15c3370a83833ddcc8e11a039d58..bdf8fd4e905d650bc67caafacebfb45898f7c4ca 100644 (file)
    random-shape
    random-shared
    random-shared-array-rank&dims>0
+   random-shared-contractible
    with-generators
    ~))
 
 (define f1s (list v:exp v:identity))
-(define f2s (list v:* v:- v:max))
+(define f2s (list v:+ v:- v:* v:max v:min))
 
 (define-syntax-rule (with-generators (g1 g2 ...) equal expected given)
   (let ((times 100)
        (do ((i 0 (1+ i)))
           ((= i times) #t)
         (let ((zs (map-in-order (lambda (g) (g)) generators)))
-          (let ((r1 (apply fx zs))
-                (r2 (apply fy zs)))
-            (unless (equal r1 r2)
-              (break #f zs r1 r2)))))))))
+          (with-exception-handler
+              (lambda (e)
+                (break #f zs))
+            (lambda ()
+              (let ((r1 (apply fx zs))
+                    (r2 (apply fy zs)))
+                (unless (equal r1 r2)
+                  (break #f zs r1 r2))))
+            #:unwind? #t)))))))
 
 (define (lambda-const-call f . consts)
   (lambda _
       x y)
      #t))))
 
-(define* (~ x y #:optional (error 5e-2))
+(define* (~ x y #:optional (error 1e-4))
   (cond
    ((and (number? x) (number? y))
     (n~ x y error))
     (a~ x y error))
    (else #f)))
 
-(define* (ngrad f #:optional (axis 0) (step 1e-4))
+(define* (ngrad f #:optional (axis 0) (step 1e-6))
   "Gradient using a numerical centered difference approximation."
   (define (axis-add xs dh . indices)
     "Add `dh' to the number or array at the given `axis' of `xs',