#:use-module ((guile) #:select (1+) #:prefix guile:)
#:use-module ((rnrs base) #:prefix rnrs:)
#:use-module ((srfi srfi-1) #:prefix srfi-1:)
+ #:use-module ((vouivre autodiff) #:prefix v:)
#:use-module (vouivre curry)
#:export
(∘
⊙
- flip
fdiff
+ flip
rdiff)
#:replace
(boolean?
reduce-right
map
1+
- identity))
+ identity
+ amap2
+ adot
+ logsumexp
+ maximum
+ mean
+ relu
+ sum
+ fdiff
+ rdiff))
;; abbreviation
(define-syntax cudefine (identifier-syntax curried-untyped-define))
(cudefine (angle x) (rnrs:angle x))
(∷ sqrt (0 . 0))
-(cudefine (sqrt x) (rnrs:sqrt x))
+(cudefine (sqrt x) (v:sqrt x))
(∷ exp (0 . 0))
-(cudefine (exp x) (rnrs:exp x))
+(cudefine (exp x) (v:exp x))
(∷ expt (0 . (0 . 0)))
-(cudefine (expt x y) (rnrs:expt x y))
+(cudefine (expt x y) (v:expt x y))
(∷ log (0 . 0))
-(cudefine (log x) (rnrs:log x))
+(cudefine (log x) (v:log x))
(∷ sin (0 . 0))
-(cudefine (sin x) (rnrs:sin x))
+(cudefine (sin x) (v:sin x))
(∷ cos (0 . 0))
-(cudefine (cos x) (rnrs:cos x))
+(cudefine (cos x) (v:cos x))
(∷ tan (0 . 0))
-(cudefine (tan x) (rnrs:tan x))
+(cudefine (tan x) (v:tan x))
(∷ asin (0 . 0))
(cudefine (asin x) (rnrs:asin x))
(cudefine (string-append x y) (rnrs:string-append x y))
(∷ + (0 . (0 . 0)))
-(cudefine (+ x y) (rnrs:+ x y))
+(cudefine (+ x y) (v:+ x y))
(∷ - (0 . (0 . 0)))
-(cudefine (- x y) (rnrs:- x y))
+(cudefine (- x y) (v:- x y))
(∷ * (0 . (0 . 0)))
-(cudefine (* x y) (rnrs:* x y))
+(cudefine (* x y) (v:* x y))
(∷ / (0 . (0 . 0)))
-(cudefine (/ x y) (rnrs:/ x y))
+(cudefine (/ x y) (v:/ x y))
(∷ max (0 . (0 . 0)))
-(cudefine (max x y) (rnrs:max x y))
+(cudefine (max x y) (v:max x y))
(∷ min (0 . (0 . 0)))
-(cudefine (min x y) (rnrs:min x y))
+(cudefine (min x y) (v:min x y))
(∷ abs (0 . 0))
-(cudefine (abs x) (rnrs:abs x))
+(cudefine (abs x) (v:abs x))
(∷ truncate (0 . 0))
(cudefine (truncate x) (rnrs:truncate x))
(definec (∘ g f) (λc x (g (f x))))
(definec (⊙ f g) (∘ g f))
(definec (flip f) (λc y (λc x (f x y))))
+
+(∷ adot (0 . (0 . (0 . 0))))
+(cudefine (adot x y n) (v:adot x y n))
+
+(∷ amap2 ((0 . (0 . 0)) . (0 . (0 . 0))))
+(cudefine (amap2 f x y) (v:amap2 f x y))
+
+(∷ logsumexp (0 . 0))
+(cudefine (logsumexp x) (v:logsumexp x))
+
+(∷ maximum (0 . 0))
+(cudefine (maximum x) (v:maximum x))
+
+(∷ mean (0 . 0))
+(cudefine (mean x) (v:mean x))
+
+(∷ relu (0 . 0))
+(cudefine (relu x) (v:relu x))
+
+(∷ sum (0 . 0))
+(cudefine (sum x) (v:sum x))
+
+(∷ fdiff ((0 . 0) . (0 . 0)))
+(cudefine (fdiff x) (v:fdiff x))
+
+(∷ rdiff ((0 . 0) . (0 . 0)))
+(cudefine (rdiff x) (v:rdiff x))