;;;; array utilities
+(define (abs->rel index dimensions)
+ (let rec ((ds (cdr dimensions))
+ (p (apply * (cdr dimensions)))
+ (r index)
+ (is '()))
+ (let ((i (quotient r p)))
+ (if (null? ds)
+ (reverse (cons i is))
+ (rec (cdr ds)
+ (quotient p (car ds))
+ (- r (* p i))
+ (cons i is))))))
+
(define (contracted-dims a b n)
(let ((dims-a (array-dimensions a))
(dims-b (array-dimensions b)))
;; corresponding generator should be `#f'.
(let* ((inputs (cons input more))
(naked-inputs (map unbox-fwd inputs))
- (output (apply function naked-inputs)))
+ (output (apply (if (procedure? function)
+ function
+ (car function))
+ naked-inputs))
+ (data (if (procedure? function)
+ '()
+ (map (lambda (f) (apply f naked-inputs))
+ (cdr function)))))
(ifn (*grad*)
output
(make-internal
((ifn prev identity (lambda (x) ((extend +) prev x)))
(ifn (internal-jacobian input)
(if (number? output)
- (jacobian-generator naked-inputs js)
+ (apply jacobian-generator naked-inputs js data)
(apply
produce-typed-array
(lambda is
- (jacobian-generator naked-inputs (append is js)))
+ (apply jacobian-generator
+ naked-inputs
+ (append is js)
+ data))
*atype*
(array-dimensions output)))
(let ((b ((internal-jacobian input) js))
(fwd (internal-forward input)))
(cond
((and (number? output) (number? fwd))
- (* (jacobian-generator naked-inputs '())
+ (* (apply jacobian-generator
+ naked-inputs '() data)
b))
((and (number? output) (array? fwd))
(array-ref
(apply
produce-typed-array
(lambda ks
- (jacobian-generator naked-inputs ks))
+ (apply jacobian-generator
+ naked-inputs ks data))
*atype*
(array-dimensions fwd))
b (array-rank fwd))))
produce-typed-array
(lambda is
((extend *)
- (jacobian-generator naked-inputs is)
+ (apply jacobian-generator
+ naked-inputs is data)
b))
*atype*
(array-dimensions output)))
(apply
produce-typed-array
(lambda ks
- (jacobian-generator naked-inputs (append is ks)))
+ (apply jacobian-generator
+ naked-inputs
+ (append is ks)
+ data))
*atype*
(array-dimensions fwd))
b (array-rank fwd))))
"Differentiable mean on arrays."
(differentiable-wrapper
(list
- (lambda (xs indices)
- (/ 1 (apply * (array-dimensions (car xs))))))
- (lambda (x)
- (let ((sum 0)
- (count 0))
- (array-for-each
- (lambda (x)
- (set! sum (+ sum x))
- (set! count (1+ count)))
- x)
- (/ sum count)))
+ (lambda (xs indices one-over-n)
+ one-over-n))
+ (let ((n 0))
+ (list
+ (lambda (x)
+ (let ((sum 0))
+ (array-for-each
+ (lambda (x)
+ (set! sum (+ sum x))
+ (set! n (1+ n)))
+ x)
+ (/ sum n)))
+ (lambda _ (/ n))))
x))
(define (i:array-ref x . indices)
array-ref
x indices))
-(define (for-flat-index proc a)
- (let* ((b (array-contents a))
- (n (array-length b)))
- (let rec ((i 0))
- (unless (= i n)
- (proc b i)
- (rec (1+ i))))))
-
(define (maximum x)
"Differentiable maximum on arrays."
(differentiable-wrapper
(list
- (lambda (xs indices)
- (let ((x (car xs))
- (m (- (inf)))
- (i 'TBD))
- (for-indices-in-range
- (lambda indices
- (let ((xi (apply array-ref x indices)))
- (when (< m xi)
- (set! m xi)
- (set! i indices))))
- (list-zeros (array-rank x))
- (array-dimensions x))
- (if (equal? i indices)
- 1
- 0))))
- (lambda (x)
- (let ((m (- (inf))))
- (array-for-each
- (lambda (x)
- (set! m (max m x)))
- x)
- m))
+ (lambda (xs indices max-indices)
+ (if (equal? indices max-indices)
+ 1
+ 0)))
+ (let ((max-index 'TBD))
+ (list
+ (lambda (x)
+ (let ((m (- (inf)))
+ (i 0))
+ (array-for-each
+ (lambda (x)
+ (when (< m x)
+ (set! m x)
+ (set! max-index i))
+ (set! i (1+ i)))
+ x)
+ m))
+ (lambda (x)
+ (abs->rel max-index (array-dimensions x)))))
x))
(define (sum x)
"Differentiable sum on arrays."
(differentiable-wrapper
- (list
- (lambda (xs indices)
- 1))
+ (list (lambda _ 1))
(lambda (x)
(let ((sum 0))
(array-for-each
(define (adot x y n)
(differentiable-wrapper
(list
- (lambda (xs indices)
- (let* ((x (car xs))
- (y (cadr xs))
+ (lambda (xs indices free-rank-x free-rank-y)
+ (let* ((y (cadr xs))
(n (caddr xs))
- (free-rank-x (- (array-rank x)
- n))
- (free-rank-y (- (array-rank y)
- n))
(out (take indices (+ free-rank-x free-rank-y)))
(in (drop indices (+ free-rank-x free-rank-y)))
(is-free1 (take out free-rank-x))
(ifn (equal? is-free1 is-free2)
0
(apply array-ref y (append is-bound js-free)))))
- (lambda (xs indices)
+ (lambda (xs indices free-rank-x free-rank-y)
(let* ((x (car xs))
- (y (cadr xs))
(n (caddr xs))
- (free-rank-x (- (array-rank x)
- n))
- (free-rank-y (- (array-rank y)
- n))
(out (take indices (+ free-rank-x free-rank-y)))
(in (drop indices (+ free-rank-x free-rank-y)))
(is-free (take out free-rank-x))
0
(apply array-ref x (append is-free js-bound)))))
#f)
- contract-arrays x y n))
+ (list
+ contract-arrays
+ (lambda (x y n) (- (array-rank x) n))
+ (lambda (x y n) (- (array-rank y) n)))
+ x y n))