extend
grad
internal-jacobian
+ maximum
mean
- rank-of)
+ rank-of
+ sum)
#:replace
- ((i:* . *)
- (i:- . -)
+ ((i:sqrt . sqrt)
(i:exp . exp)
- (i:fold . fold)
+ (i:expt . expt)
+ (i:log . log)
+ (i:sin . sin)
+ (i:cos . cos)
+ (i:tan . tan)
+ (i:+ . +)
+ (i:- . -)
+ (i:* . *)
+ (i:/ . /)
+ (i:max . max)
+ (i:min . min)
+ (i:abs . abs)
(i:identity . identity)
- (i:max . max)))
+ (i:array-ref . array-ref)
+ )
+ #:re-export
+ (fold
+ reduce))
;;;; array utilities
0
(array-rank x)))
-;; (define (zeros-out:in x y)
-;; "Zeros in the shape of [x]:[y]."
-;; (cond
-;; ((and (number? x)
-;; (number? y))
-;; 0)
-;; ((and (array? x)
-;; (array? y))
-;; (apply make-array 0 (append (array-dimensions x)
-;; (array-dimensions y))))
-;; ((and (number? x)
-;; (array? y))
-;; (apply make-array 0 (array-dimensions y)))
-;; ((and (array? x)
-;; (number? y))
-;; (apply make-array 0 (array-dimensions x)))
-;; (else (error "undefined." x y))))
-
;;;; differentiation
(define-record-type internal
(forward internal-forward)
(jacobian internal-jacobian))
-(define *atype* 'f32)
+;; (define *atype* 'f64)
+(define *atype* #t)
(define *grad* (make-parameter #f))
(define* (grad f #:optional (axis 0))
((ifn prev identity (lambda (x) ((extend +) prev x)))
(ifn (internal-jacobian input)
(if (number? output)
- (apply jacobian-generator
- (append naked-inputs js))
+ (jacobian-generator naked-inputs js)
(apply
produce-typed-array
(lambda is
- (apply jacobian-generator
- (append naked-inputs is js)))
+ (jacobian-generator naked-inputs (append is js)))
*atype*
(array-dimensions output)))
(let ((b ((internal-jacobian input) js))
(fwd (internal-forward input)))
(cond
((and (number? output) (number? fwd))
- (* (apply jacobian-generator naked-inputs)
+ (* (jacobian-generator naked-inputs '())
b))
((and (number? output) (array? fwd))
(array-ref
(apply
produce-typed-array
(lambda ks
- (apply jacobian-generator
- (append naked-inputs ks)))
+ (jacobian-generator naked-inputs ks))
*atype*
(array-dimensions fwd))
b (array-rank fwd))))
produce-typed-array
(lambda is
((extend *)
- (apply jacobian-generator
- (append naked-inputs is))
+ (jacobian-generator naked-inputs is)
b))
*atype*
(array-dimensions output)))
(apply
produce-typed-array
(lambda ks
- (apply jacobian-generator
- (append naked-inputs is ks)))
+ (jacobian-generator naked-inputs (append is ks)))
*atype*
(array-dimensions fwd))
b (array-rank fwd))))
(apply make-typed-array *atype* 0 (array-dimensions output)))))))))
(define (ewise1 f)
- (lambda (x . indices)
- (if (number? x)
- (f x)
- (receive (is js) (split-at indices (array-rank x))
- (if (equal? is js)
- (f (apply array-ref x is))
- 0)))))
+ (lambda (xs indices)
+ (let ((x (car xs)))
+ (if (number? x)
+ (f x)
+ (receive (is js) (split-at indices (array-rank x))
+ (if (equal? is js)
+ (f (apply array-ref x is))
+ 0))))))
(define (ewise2 proc axis)
- (lambda (x y . indices)
- (cond
- ((and (number? x) (number? y))
- (proc x y))
- ((and (number? x) (array? y))
- (if (= axis 0)
- (proc x (apply array-ref y indices))
- (receive (is js) (split-at indices (array-rank y))
- (ifn (equal? is js)
- 0
- (proc x (apply array-ref y is))))))
- ((and (array? x) (number? y))
- (if (= axis 1)
- (proc (apply array-ref x indices) y)
- (receive (is js) (split-at indices (array-rank x))
- (ifn (equal? is js)
- 0
- (proc (apply array-ref x is)
- y)))))
- (else
- (receive (is js) (split-at indices (array-rank x))
- (ifn (equal? is js)
- 0
- (proc (apply array-ref x is)
- (apply array-ref y is))))))))
+ (lambda (xs indices)
+ (let ((x (car xs))
+ (y (cadr xs)))
+ (cond
+ ((and (number? x) (number? y))
+ (proc x y))
+ ((and (number? x) (array? y))
+ (if (= axis 0)
+ (proc x (apply array-ref y indices))
+ (receive (is js) (split-at indices (array-rank y))
+ (ifn (equal? is js)
+ 0
+ (proc x (apply array-ref y is))))))
+ ((and (array? x) (number? y))
+ (if (= axis 1)
+ (proc (apply array-ref x indices) y)
+ (receive (is js) (split-at indices (array-rank x))
+ (ifn (equal? is js)
+ 0
+ (proc (apply array-ref x is)
+ y)))))
+ (else
+ (receive (is js) (split-at indices (array-rank x))
+ (ifn (equal? is js)
+ 0
+ (proc (apply array-ref x is)
+ (apply array-ref y is)))))))))
(define (i:identity x)
"Differentiable identity."
identity
x))
+(define (i:sqrt x)
+ "Differentiable square root."
+ (differentiable-wrapper
+ (list (ewise1 (lambda (x) (/ 1 2 (sqrt x)))))
+ (extend sqrt)
+ x))
+
(define (i:exp x)
"Differentiable exponential."
(differentiable-wrapper
(extend exp)
x))
-(define (i:* x y)
- "Differentiable element-wise multiplication."
+(define (i:expt x y)
+ "Differentiable power."
+ (differentiable-wrapper
+ (list (ewise2 (lambda (x y) (* y (expt x (1- y)))) 0)
+ (ewise2 (lambda (x y) (* (expt x y) (log x))) 1))
+ (extend expt)
+ x y))
+
+(define (i:log x)
+ "Differentiable logarithm."
+ (differentiable-wrapper
+ (list (ewise1 (lambda (x) (/ x))))
+ (extend log)
+ x))
+
+(define (i:sin x)
+ "Differentiable sine."
+ (differentiable-wrapper
+ (list (ewise1 cons))
+ (extend sin)
+ x))
+
+(define (i:cos x)
+ "Differentiable cosine."
+ (differentiable-wrapper
+ (list (ewise1 (lambda (x) (- (sin x)))))
+ (extend cos)
+ x))
+
+(define (i:tan x)
+ "Differentiable tangent."
+ (differentiable-wrapper
+ (list (ewise1 (lambda (x) (/ (expt (cos x) 2)))))
+ (extend tan)
+ x))
+
+(define (i:+ x y)
+ "Differentiable element-wise addition."
(differentiable-wrapper
(list
- (ewise2 (lambda (x y) y) 0)
- (ewise2 (lambda (x y) x) 1))
- (extend *)
+ (ewise2 (lambda _ +1) 0)
+ (ewise2 (lambda _ +1) 1))
+ (extend +)
x y))
(define (i:- x y)
(extend -)
x y))
+(define (i:* x y)
+ "Differentiable element-wise multiplication."
+ (differentiable-wrapper
+ (list
+ (ewise2 (lambda (x y) y) 0)
+ (ewise2 (lambda (x y) x) 1))
+ (extend *)
+ x y))
+
+(define (i:/ x y)
+ "Differentiable element-wise division."
+ (differentiable-wrapper
+ (list
+ (ewise2 (lambda (x y) (/ y)) 0)
+ (ewise2 (lambda (x y) (- (/ x y y))) 1))
+ (extend /)
+ x y))
+
(define (i:max x y)
"Differentiable element-wise maximum."
(define (dmax x y)
(extend max)
x y))
+(define (i:min x y)
+ "Differentiable element-wise minimum."
+ (define (dmin x y)
+ (cond
+ ((< x y)
+ 1)
+ ((= x y)
+ 1/2)
+ (else
+ 0)))
+ (differentiable-wrapper
+ (list
+ (ewise2 dmin 0)
+ (ewise2 (flip dmin) 1))
+ (extend min)
+ x y))
+
+(define (i:abs x)
+ "Differentiable absolute."
+ (differentiable-wrapper
+ (list (ewise1 (lambda (x)
+ (cond ((> x 0)
+ +1)
+ ((= x 0)
+ 1/2)
+ ((< x 0)
+ -1)))))
+ (extend abs)
+ x))
+
(define (mean x)
"Differentiable mean on arrays."
(differentiable-wrapper
(list
- (lambda (x . indices)
- (/ 1 (apply * (array-dimensions x)))))
+ (lambda (xs indices)
+ (/ 1 (apply * (array-dimensions (car xs))))))
(lambda (x)
(let ((sum 0)
(count 0))
(/ sum count)))
x))
+(define (i:array-ref x . indices)
+ "Differentiable array-ref w.r.t `x'."
+ (apply
+ differentiable-wrapper
+ (cons
+ (lambda (xs js)
+ (if (equal? js indices)
+ 1
+ 0))
+ (list-tabulate (length indices) not))
+ array-ref
+ x indices))
+
+(define (for-flat-index proc a)
+ (let* ((b (array-contents a))
+ (n (array-length b)))
+ (let rec ((i 0))
+ (unless (= i n)
+ (proc b i)
+ (rec (1+ i))))))
+
+(define (maximum x)
+ "Differentiable maximum on arrays."
+ (differentiable-wrapper
+ (list
+ (lambda (xs indices)
+ (let ((x (car xs))
+ (m (- (inf)))
+ (i 'TBD))
+ (for-indices-in-range
+ (lambda indices
+ (let ((xi (apply array-ref x indices)))
+ (when (< m xi)
+ (set! m xi)
+ (set! i indices))))
+ (list-zeros (array-rank x))
+ (array-dimensions x))
+ (if (equal? i indices)
+ 1
+ 0))))
+ (lambda (x)
+ (let ((m (- (inf))))
+ (array-for-each
+ (lambda (x)
+ (set! m (max m x)))
+ x)
+ m))
+ x))
+
+(define (sum x)
+ "Differentiable sum on arrays."
+ (differentiable-wrapper
+ (list
+ (lambda (xs indices)
+ 1))
+ (lambda (x)
+ (let ((sum 0))
+ (array-for-each
+ (lambda (x)
+ (set! sum (+ sum x)))
+ x)
+ sum))
+ x))
+
(define (amap2 f x y)
(define (unbox-with proc x)
(ifn (internal? x)
(equal? indices (cdr js)))
1
0))
- *atype* (array-dimensions xi)))))))))
+ (array-type xi)
+ (array-dimensions xi)))))))))
(define (boxed-fi f i x y)
(f (boxed-ref x i)
(boxed-ref y i)))
(f0 (boxed-fi f 0 x y)))
(ifn (internal? f0)
(let ((fwd (apply make-typed-array
+ ;; TODO: use the correct type based on f0.
*atype* *unspecified* bs (dims-of f0))))
(for-indices-in-range
(lambda (i)
fwd)
(let ((jac (make-array *unspecified* bs))
(fwd (apply make-typed-array
+ ;; TODO: use the correct type based on f0.
*atype* *unspecified*
bs (dims-of (internal-forward f0)))))
(for-indices-in-range
fwd
(lambda (js)
(let ((a (apply make-typed-array
+ ;; TODO: use the correct type based on f0.
*atype* *unspecified*
bs (dims-of (internal-forward f0)))))
(for-indices-in-range
(define (adot x y n)
(differentiable-wrapper
(list
- (lambda (x y n . indices)
- (let* ((free-rank-x (- (array-rank x)
+ (lambda (xs indices)
+ (let* ((x (car xs))
+ (y (cadr xs))
+ (n (caddr xs))
+ (free-rank-x (- (array-rank x)
n))
(free-rank-y (- (array-rank y)
n))
(ifn (equal? is-free1 is-free2)
0
(apply array-ref y (append is-bound js-free)))))
- (lambda (x y n . indices)
- (let* ((free-rank-x (- (array-rank x)
+ (lambda (xs indices)
+ (let* ((x (car xs))
+ (y (cadr xs))
+ (n (caddr xs))
+ (free-rank-x (- (array-rank x)
n))
(free-rank-y (- (array-rank y)
n))