From 7c78f78b27207cc10b726485e97abc14c9ed6c9a Mon Sep 17 00:00:00 2001 From: admin Date: Tue, 28 Nov 2023 03:32:15 +0900 Subject: [PATCH] Make differentiable functions partially applicable --- examples/base.scm | 69 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/examples/base.scm b/examples/base.scm index dc637cf..1df1022 100644 --- a/examples/base.scm +++ b/examples/base.scm @@ -19,12 +19,13 @@ #:use-module ((guile) #:select (1+) #:prefix guile:) #:use-module ((rnrs base) #:prefix rnrs:) #:use-module ((srfi srfi-1) #:prefix srfi-1:) + #:use-module ((vouivre autodiff) #:prefix v:) #:use-module (vouivre curry) #:export (∘ ⊙ - flip fdiff + flip rdiff) #:replace (boolean? @@ -163,7 +164,16 @@ reduce-right map 1+ - identity)) + identity + amap2 + adot + logsumexp + maximum + mean + relu + sum + fdiff + rdiff)) ;; abbreviation (define-syntax cudefine (identifier-syntax curried-untyped-define)) @@ -352,25 +362,25 @@ (cudefine (angle x) (rnrs:angle x)) (∷ sqrt (0 . 0)) -(cudefine (sqrt x) (rnrs:sqrt x)) +(cudefine (sqrt x) (v:sqrt x)) (∷ exp (0 . 0)) -(cudefine (exp x) (rnrs:exp x)) +(cudefine (exp x) (v:exp x)) (∷ expt (0 . (0 . 0))) -(cudefine (expt x y) (rnrs:expt x y)) +(cudefine (expt x y) (v:expt x y)) (∷ log (0 . 0)) -(cudefine (log x) (rnrs:log x)) +(cudefine (log x) (v:log x)) (∷ sin (0 . 0)) -(cudefine (sin x) (rnrs:sin x)) +(cudefine (sin x) (v:sin x)) (∷ cos (0 . 0)) -(cudefine (cos x) (rnrs:cos x)) +(cudefine (cos x) (v:cos x)) (∷ tan (0 . 0)) -(cudefine (tan x) (rnrs:tan x)) +(cudefine (tan x) (v:tan x)) (∷ asin (0 . 0)) (cudefine (asin x) (rnrs:asin x)) @@ -505,25 +515,25 @@ (cudefine (string-append x y) (rnrs:string-append x y)) (∷ + (0 . (0 . 0))) -(cudefine (+ x y) (rnrs:+ x y)) +(cudefine (+ x y) (v:+ x y)) (∷ - (0 . (0 . 0))) -(cudefine (- x y) (rnrs:- x y)) +(cudefine (- x y) (v:- x y)) (∷ * (0 . (0 . 0))) -(cudefine (* x y) (rnrs:* x y)) +(cudefine (* x y) (v:* x y)) (∷ / (0 . (0 . 0))) -(cudefine (/ x y) (rnrs:/ x y)) +(cudefine (/ x y) (v:/ x y)) (∷ max (0 . (0 . 0))) -(cudefine (max x y) (rnrs:max x y)) +(cudefine (max x y) (v:max x y)) (∷ min (0 . (0 . 0))) -(cudefine (min x y) (rnrs:min x y)) +(cudefine (min x y) (v:min x y)) (∷ abs (0 . 0)) -(cudefine (abs x) (rnrs:abs x)) +(cudefine (abs x) (v:abs x)) (∷ truncate (0 . 0)) (cudefine (truncate x) (rnrs:truncate x)) @@ -605,3 +615,30 @@ (definec (∘ g f) (λc x (g (f x)))) (definec (⊙ f g) (∘ g f)) (definec (flip f) (λc y (λc x (f x y)))) + +(∷ adot (0 . (0 . (0 . 0)))) +(cudefine (adot x y n) (v:adot x y n)) + +(∷ amap2 ((0 . (0 . 0)) . (0 . (0 . 0)))) +(cudefine (amap2 f x y) (v:amap2 f x y)) + +(∷ logsumexp (0 . 0)) +(cudefine (logsumexp x) (v:logsumexp x)) + +(∷ maximum (0 . 0)) +(cudefine (maximum x) (v:maximum x)) + +(∷ mean (0 . 0)) +(cudefine (mean x) (v:mean x)) + +(∷ relu (0 . 0)) +(cudefine (relu x) (v:relu x)) + +(∷ sum (0 . 0)) +(cudefine (sum x) (v:sum x)) + +(∷ fdiff ((0 . 0) . (0 . 0))) +(cudefine (fdiff x) (v:fdiff x)) + +(∷ rdiff ((0 . 0) . (0 . 0))) +(cudefine (rdiff x) (v:rdiff x)) -- 2.39.5