#:use-module (srfi srfi-1)
#:use-module (srfi srfi-9)
#:use-module (vouivre misc)
+ #:use-module (vouivre promises)
#:export
(*atype*
adot
contract-arrays
differentiable-wrapper
dot
+ do-times
+ ewise1
+ ewise2
extend
grad
- internal-jacobian
+ make-batch
+ make-internal
maximum
mean
rank-of
- sum
- )
+ sum)
#:replace
((i:sqrt . sqrt)
(i:exp . exp)
(i:abs . abs)
(i:identity . identity)
(i:array-ref . array-ref)
- )
+ (i:array-cell-ref . array-cell-ref))
#:re-export
(fold
reduce))
;;;; array utilities
-(define (abs->rel index dimensions)
- (let rec ((ds (cdr dimensions))
- (p (apply * (cdr dimensions)))
- (r index)
- (is '()))
- (let ((i (quotient r p)))
- (if (null? ds)
- (reverse (cons i is))
- (rec (cdr ds)
- (quotient p (car ds))
- (- r (* p i))
- (cons i is))))))
-
-(define (contracted-dims a b n)
- (let ((dims-a (array-dimensions a))
- (dims-b (array-dimensions b)))
- (if (or (> n (array-rank a))
- (> n (array-rank b)))
- (error "can't contract arrays with size lower than" n)
- (append (take dims-a (- (array-rank a)
- n))
- (drop dims-b n)))))
+(define (rel->abs indices dimensions)
+ (let rec ((s 0)
+ (p 1)
+ (is (reverse indices))
+ (ds (reverse dimensions)))
+ (if (null? is)
+ s
+ (rec (+ s (* p (car is)))
+ (* p (car ds))
+ (cdr is)
+ (cdr ds)))))
(define(do-times n proc)
(let rec ((i 0))
;;(define *atype* 'f32)
(define *atype* #t)
(define *grad* (make-parameter #f))
+(define *j* (make-parameter #f))
+
+(define-syntax-rule (w/j val body ...)
+ (parameterize ((*j* val))
+ body ...))
+
+(define (unwrap-fwd x)
+ (if (internal? x)
+ (internal-forward x)
+ x))
+
+(define (unwrap-jac x)
+ (if (internal? x)
+ (internal-jacobian x)
+ x))
+
+(define (dims-of x)
+ (if (number? x)
+ '()
+ (array-dimensions x)))
+
+(define (mov dst-buf n-dst-dims generator naked-inputs data j)
+ (do-times
+ n-dst-dims
+ (lambda (i)
+ (array-set!
+ dst-buf
+ (apply generator naked-inputs i j data)
+ i)))
+ dst-buf)
+
+(define (add dst-buf n-dst-dims generator naked-inputs data j)
+ (do-times
+ n-dst-dims
+ (lambda (i)
+ (array-set!
+ dst-buf
+ (+ (array-ref dst-buf i)
+ (apply generator naked-inputs i j data))
+ i)))
+ dst-buf)
+
+(define (movc dst-buf n-dst-dims src-buf n-src-dims
+ generator naked-inputs data)
+ "Contract the Jacobian column produced by the generator with the source buffer
+storing the result in the destination buffer."
+ (let ((s 0))
+ (do-times
+ n-dst-dims
+ (lambda (i)
+ (set! s 0)
+ (do-times
+ n-src-dims
+ (lambda (k)
+ (set! s (+ s (* (apply generator naked-inputs i k data)
+ (array-ref src-buf k))))))
+ (array-set! dst-buf s i))))
+ dst-buf)
+
+(define (addc dst-buf n-dst-dims src-buf n-src-dims
+ generator naked-inputs data)
+ "Contract the Jacobian column produced by the generator with the source buffer
+adding the result to the destination buffer."
+ (let ((s 0))
+ (do-times
+ n-dst-dims
+ (lambda (i)
+ (set! s (array-ref dst-buf i))
+ (do-times
+ n-src-dims
+ (lambda (k)
+ (set! s (+ s (* (apply generator naked-inputs i k data)
+ (array-ref src-buf k))))))
+ (array-set! dst-buf s i))))
+ dst-buf)
(define* (grad f #:optional (axis 0))
(define (wrap x i)
(if (= i axis)
- (make-internal x #f)
+ (make-internal x 'input)
x))
(lambda xs
- (let* ((wrapped-xs (map wrap xs (list-tabulate (length xs) identity)))
- (y (internal-forward (list-ref wrapped-xs axis))))
- ((if (*grad*)
- identity
- (lambda (boxed)
- (let ((fx (internal-forward boxed))
- (gfx (internal-jacobian boxed)))
- (cond
- ((and (number? fx) (number? y))
- (gfx '()))
- ((and (number? fx) (array? y))
- (apply produce-typed-array
- (lambda js (gfx js))
- *atype*
- (array-dimensions y)))
- ((and (array? fx) (number? y))
- (gfx '()))
- (else
- (let* ((dims-out (array-dimensions fx))
- (dims-in (array-dimensions y))
- (a (apply make-typed-array *atype* *unspecified* (append dims-out dims-in))))
- (for-indices-in-range
- (lambda js
- (let ((out (gfx js)))
- (for-indices-in-range
- (lambda is
- (apply
- array-set!
- a
- (apply array-ref out is)
- (append is js)))
- (list-zeros (array-rank fx))
- dims-out)))
- (list-zeros (array-rank y))
- dims-in)
- a))))))
- (parameterize (((@@ (vouivre grad) *grad*) y))
- (apply f wrapped-xs))))))
-
-(define (unbox-fwd x)
- (if (internal? x)
- (internal-forward x)
- x))
-
-;; `n' is the number of arguments to `function'.
-;; `jacobian-generators is not a `Vec' but a `List' we only use the former to
-;; show its length.
+ (parameterize (((@@ (vouivre grad) *grad*) #t)
+ ((@@ (vouivre grad) *promises*) (cons '() #f)))
+ (let* ((internal (apply f (map-indexed wrap xs)))
+ (fx (internal-forward internal))
+ (y (list-ref xs axis)) ; variable to differentiate w.r.t
+ (pre-Jx (internal-jacobian internal))
+ (Jx (cond
+ ;; TODO: implement 'input case and test 'zero and 'input
+ ((eq? pre-Jx 'zero)
+ (lambda (j)
+ (lambda (i)
+ 0)))
+ ((eq? pre-Jx 'input)
+ (error "TBD."))
+ (else
+ (lambda (j)
+ (reset-promises (car (*promises*)))
+ (let ((column-jac (w/j j (force pre-Jx))))
+ (lambda (i)
+ (array-ref column-jac i))))))))
+ (cond
+ ((and (number? fx) (number? y))
+ ((Jx 0) 0))
+ ((and (number? fx) (array? y))
+ (let* ((y-dims (array-dimensions y))
+ (a (apply make-array *unspecified* y-dims))
+ (ac (array-contents a)))
+ (do-times
+ (apply * y-dims)
+ (lambda (j)
+ (array-set! ac ((Jx j) 0)
+ j)))
+ a))
+ ((and (array? fx) (number? y))
+ (let* ((fx-dims (array-dimensions fx))
+ (a (apply make-array *unspecified* fx-dims))
+ (ac (array-contents a))
+ (Jx (Jx 0)))
+ (do-times
+ (apply * fx-dims)
+ (lambda (i)
+ (array-set! ac (Jx i)
+ i)))
+ a))
+ (else
+ (let* ((fx-dims (array-dimensions fx))
+ (y-dims (array-dimensions y))
+ (n-fx-dims (apply * fx-dims))
+ (n-y-dims (apply * y-dims))
+ (a (apply make-array *unspecified* (append fx-dims y-dims)))
+ (ac (array-contents a)))
+ (do-times
+ n-y-dims
+ (lambda (j)
+ (let ((Jx (Jx j)))
+ (do-times
+ n-fx-dims
+ (lambda (i)
+ (array-set! ac (Jx i)
+ (+ j (* n-y-dims i))))))))
+ a)))))))
+
+
+;; In the comment that follows:
+;;
+;; `n' is the number of arguments to `proc'.
+;; `generators is not a `Vec' but a `List' we only use the former to illustrate
+;; its length.
;; `X1', ..., `Xn' are the types of inputs and thus `Array's of some dimension.
;; `I' is the type of multi-indices indexing the output of `function'.
;; `J' is the type of multi-indices indexing the input array being differentiated.
+;; `|I|' (resp. `|J|') is the type of absolute indices of `I' (resp. `J').
;; `Array I' is the type of arrays indexed by multi-indices of `I'.
;; `[X]' means that `X' is boxed in an internal as when returned by
-;; `differentiable-wrapper' with the array being `X' and the function
-;; taking a multi-index of `J' to return an array of the same shape as `X'.
-;; (∷ (→ (Vec n (→ X1 ... Xn I J Number))
-;; (→ X1 ... Xn (Array I))
+;; `differentiable-wrapper' with the array being `X' and the promise that
+;; given a |J| we will get the change of `X' with a change of the
+;; the differentiated argument at multi-index `J'.
+;; (∷ (→ (Vec n (→ X1 ... Xn |I| |J| Number))
+;; (→ X1 ... Xn (Array I))*
;; [X1] ... [Xn]
-;; (Internal (Array I) (→ J (Array I)))))
-(define (differentiable-wrapper jacobian-generators function input . more)
- ;; NOTE: Both the jacobian generators and the function act on naked inputs
- ;; (numbers or arrays not inside an internal object). The generators
- ;; additionaly take indices -- one for each dimension of what we are
- ;; differentiating with respect to (the content of `(*grad*)'). A
- ;; generator returns a jacobian column expressing the change of the
- ;; function's output when changing `(*grad*)' along the given indices.
- ;; In cases where an argument isn't meant to be differentiable its
- ;; corresponding generator should be `#f'.
- (let* ((inputs (cons input more))
- (naked-inputs (map unbox-fwd inputs))
- (output (apply (if (procedure? function)
- function
- (car function))
- naked-inputs))
- (data (if (procedure? function)
- '()
- (map (lambda (f) (apply f naked-inputs))
- (cdr function)))))
+;; (Internal (Array I) (Promise |J| (Array |I|)))))
+;;
+;; (*) We extend this definition to allow `proc' to be a list of procedures
+;; the head of which is as described above and the remaining elements
+;; are procedures of the same arguments but returning values that are
+;; then fed as extra data to the generators.
+;;
+;; NOTE: In cases where an argument isn't meant to be differentiable its
+;; corresponding generator should be `#f'.
+(define (differentiable-wrapper generators proc* arg . more)
+ (let* ((args (cons arg more))
+ (proc (if (procedure? proc*)
+ proc*
+ (car proc*)))
+ (naked-args (map unwrap-fwd args))
+ (out (apply proc naked-args)))
(ifn (*grad*)
- output
- (make-internal
- output
- (lambda (js)
- (or
- (fold
- (lambda (jacobian-generator input prev)
- (ifn (internal? input)
- prev
- ((ifn prev identity (lambda (x) ((extend +) prev x)))
- (ifn (internal-jacobian input)
- (if (number? output)
- (apply jacobian-generator naked-inputs js data)
- (apply
- produce-typed-array
- (lambda is
- (apply jacobian-generator
- naked-inputs
- (append is js)
- data))
- *atype*
- (array-dimensions output)))
- (let ((b ((internal-jacobian input) js))
- (fwd (internal-forward input)))
- (cond
- ((and (number? output) (number? fwd))
- (* (apply jacobian-generator
- naked-inputs '() data)
- b))
- ((and (number? output) (array? fwd))
- (array-ref
- (contract-arrays
- (apply
- produce-typed-array
- (lambda ks
- (apply jacobian-generator
- naked-inputs ks data))
- *atype*
- (array-dimensions fwd))
- b (array-rank fwd))))
- ((and (array? output) (number? fwd))
- (apply
- produce-typed-array
- (lambda is
- ((extend *)
- (apply jacobian-generator
- naked-inputs is data)
- b))
- *atype*
- (array-dimensions output)))
- (else
- (apply
- produce-typed-array
- (lambda is
- (array-ref
- (contract-arrays
- (apply
- produce-typed-array
- (lambda ks
- (apply jacobian-generator
- naked-inputs
- (append is ks)
- data))
- *atype*
- (array-dimensions fwd))
- b (array-rank fwd))))
- *atype*
- (array-dimensions output)))))))))
- #f jacobian-generators inputs)
- (if (number? output)
- 0
- (apply make-typed-array *atype* 0 (array-dimensions output)))))))))
+ out
+ (let* ((data (if (procedure? proc*)
+ '()
+ (map (lambda (g)
+ (apply g naked-args))
+ (cdr proc*))))
+ (n-out-dims (apply * (dims-of out)))
+ (buf (make-array *unspecified* n-out-dims)))
+ (make-internal
+ out
+ (fold
+ (lambda (generator arg prev)
+ (if (or (not (internal? arg))
+ (eq? 'zero (internal-jacobian arg)))
+ prev
+ (let ((Jx (internal-jacobian arg))
+ (n-fwd-dims (apply * (dims-of (unwrap-fwd arg)))))
+ (if (eq? Jx 'input)
+ (if (eq? prev 'zero)
+ (delay
+ (mov buf n-out-dims
+ generator naked-args data (*j*)))
+ (delay
+ (add (force prev) n-out-dims
+ generator naked-args data (*j*))))
+ (if (eq? prev 'zero)
+ (delay
+ (movc buf n-out-dims (force Jx) n-fwd-dims
+ generator naked-args data))
+ (delay
+ (addc (force prev) n-out-dims
+ (force Jx) n-fwd-dims
+ generator naked-args data)))))))
+ 'zero generators args))))))
(define (ewise1 f)
- (lambda (xs indices)
+ (lambda (xs i j)
(let ((x (car xs)))
(if (number? x)
(f x)
- (receive (is js) (split-at indices (array-rank x))
- (if (equal? is js)
- (f (apply array-ref x is))
- 0))))))
+ (ifn (= i j)
+ 0
+ (f (array-ref (array-contents x)
+ j)))))))
(define (ewise2 proc axis)
- (lambda (xs indices)
+ (lambda (xs i j)
(let ((x (car xs))
(y (cadr xs)))
(cond
(proc x y))
((and (number? x) (array? y))
(if (= axis 0)
- (proc x (apply array-ref y indices))
- (receive (is js) (split-at indices (array-rank y))
- (ifn (equal? is js)
- 0
- (proc x (apply array-ref y is))))))
+ (proc x (array-ref (array-contents y)
+ i))
+ (ifn (= i j)
+ 0
+ (proc x (array-ref (array-contents y)
+ j)))))
((and (array? x) (number? y))
(if (= axis 1)
- (proc (apply array-ref x indices) y)
- (receive (is js) (split-at indices (array-rank x))
- (ifn (equal? is js)
- 0
- (proc (apply array-ref x is)
- y)))))
+ (proc (array-ref (array-contents x)
+ i)
+ y)
+ (ifn (= i j)
+ 0
+ (proc (array-ref (array-contents x)
+ j)
+ y))))
(else
- (receive (is js) (split-at indices (array-rank x))
- (ifn (equal? is js)
- 0
- (proc (apply array-ref x is)
- (apply array-ref y is)))))))))
+ (ifn (= i j)
+ 0
+ (proc (array-ref (array-contents x)
+ j)
+ (array-ref (array-contents y)
+ j))))))))
(define (i:identity x)
"Differentiable identity."
"Differentiable mean on arrays."
(differentiable-wrapper
(list
- (lambda (xs indices one-over-n)
+ (lambda (xs i j one-over-n)
one-over-n))
(let ((n 0))
- (list
- (lambda (x)
- (let ((sum 0))
- (array-for-each
- (lambda (x)
- (set! sum (+ sum x))
- (set! n (1+ n)))
- x)
- (/ sum n)))
- (lambda _ (/ n))))
+ (list
+ (lambda (x)
+ (let ((sum 0))
+ (array-for-each
+ (lambda (x)
+ (set! sum (+ sum x))
+ (set! n (1+ n)))
+ x)
+ (/ sum n)))
+ (lambda _ (/ n))))
x))
(define (i:array-ref x . indices)
(apply
differentiable-wrapper
(cons
- (lambda (xs js)
- (if (equal? js indices)
+ (lambda (xs i j abs-index)
+ (if (= j abs-index)
1
0))
- (list-tabulate (length indices) not))
- array-ref
+ (map not indices))
+ (list
+ array-ref
+ (lambda (x . indices)
+ (rel->abs indices (array-dimensions x))))
+ x indices))
+
+(define (i:array-cell-ref x . indices)
+ (apply
+ differentiable-wrapper
+ (cons
+ (lambda (xs i j abs-index n-rst-dims)
+ (receive (j-ref j-rst) (euclidean/ j n-rst-dims)
+ (if (and (= j-ref abs-index)
+ (= j-rst i))
+ 1
+ 0)))
+ (map not indices))
+ (list
+ array-cell-ref
+ (lambda (x . indices)
+ (rel->abs indices (take (array-dimensions x)
+ (length indices))))
+ (lambda (x . indices)
+ (apply * (drop (array-dimensions x)
+ (length indices)))))
x indices))
+(define (make-batch elem . more)
+ (let ((batch-size (1+ (length more))))
+ (apply
+ differentiable-wrapper
+ (list-tabulate
+ batch-size
+ (lambda (b)
+ (lambda (xs i j n-rest-dims)
+ (receive (i-batch i-rest) (euclidean/ i n-rest-dims)
+ (if (and (= i-batch b)
+ (= i-rest j))
+ 1
+ 0
+ )))))
+ (list
+ (lambda (elem . more)
+ (let ((a (apply make-typed-array *atype* *unspecified* batch-size
+ (dims-of elem))))
+ (for-each
+ (lambda (x b)
+ (array-cell-set! a x b))
+ (cons elem more)
+ (list-tabulate batch-size identity))
+ a))
+ (lambda (elem . more)
+ (apply * (dims-of elem))))
+ elem more)))
+
(define (maximum x)
"Differentiable maximum on arrays."
(differentiable-wrapper
(list
- (lambda (xs indices max-indices)
- (if (equal? indices max-indices)
+ (lambda (xs i j max-index)
+ (if (= j max-index)
1
0)))
(let ((max-index 'TBD))
(set! i (1+ i)))
x)
m))
- (lambda (x)
- (abs->rel max-index (array-dimensions x)))))
+ (lambda _ max-index)))
x))
(define (sum x)
sum))
x))
-(define (amap2 f x y)
- (define (unbox-with proc x)
- (ifn (internal? x)
- x
- (proc x)))
- (define (dims-of x)
- (if (number? x)
- '()
- (array-dimensions x)))
- (define (boxed-ref x i)
- (let ((xi (array-cell-ref (unbox-with internal-forward x)
- i)))
- (ifn (internal? x)
- xi
- (make-internal
- xi
- (if (internal-jacobian x)
- (lambda (js)
- (array-cell-ref
- ((internal-jacobian x)
- js)
- i))
- (lambda (js)
- (if (number? xi)
- (if (= i (car js))
- 1
- 0)
- (apply
- produce-typed-array
- (lambda indices
- (if (and (= i (car js))
- (equal? indices (cdr js)))
- 1
- 0))
- (array-type xi)
- (array-dimensions xi)))))))))
- (define (boxed-fi f i x y)
- (f (boxed-ref x i)
- (boxed-ref y i)))
- (if (internal? f)
- (error "unsupported application of `amap2' to internal object" f)
- (let ((bs (first (array-dimensions (unbox-with internal-forward x))))
- (f0 (boxed-fi f 0 x y)))
- (ifn (internal? f0)
- (let ((fwd (apply make-typed-array
- ;; TODO: use the correct type based on f0.
- *atype* *unspecified* bs (dims-of f0))))
- (for-indices-in-range
- (lambda (i)
- (array-cell-set! fwd (boxed-fi f i x y) i))
- '(0) (list bs))
- fwd)
- (let ((jac (make-array *unspecified* bs))
- (fwd (apply make-typed-array
- ;; TODO: use the correct type based on f0.
- *atype* *unspecified*
- bs (dims-of (internal-forward f0)))))
- (for-indices-in-range
- (lambda (i)
- (let* ((fi (boxed-fi f i x y))
- (Jfi (internal-jacobian fi)))
- (array-set! jac (internal-jacobian fi) i)
- (array-cell-set! fwd (internal-forward fi) i)))
- '(0) (list bs))
- (make-internal
- fwd
- (lambda (js)
- (let ((a (apply make-typed-array
- ;; TODO: use the correct type based on f0.
- *atype* *unspecified*
- bs (dims-of (internal-forward f0)))))
- (for-indices-in-range
- (lambda (batch)
- (array-cell-set!
- a ((array-ref jac batch) js)
- batch))
- '(0) (list bs))
- a))))))))
-
(define (adot x y n)
(differentiable-wrapper
(list
- (lambda (xs indices free-rank-x free-rank-y)
- (let* ((y (cadr xs))
- (n (caddr xs))
- (out (take indices (+ free-rank-x free-rank-y)))
- (in (drop indices (+ free-rank-x free-rank-y)))
- (is-free1 (take out free-rank-x))
- (js-free (drop out free-rank-x))
- (is-free2 (take in free-rank-x))
- (is-bound (drop in free-rank-x)))
- (ifn (equal? is-free1 is-free2)
- 0
- (apply array-ref y (append is-bound js-free)))))
- (lambda (xs indices free-rank-x free-rank-y)
- (let* ((x (car xs))
- (n (caddr xs))
- (out (take indices (+ free-rank-x free-rank-y)))
- (in (drop indices (+ free-rank-x free-rank-y)))
- (is-free (take out free-rank-x))
- (js-free1 (drop out free-rank-x))
- (js-free2 (drop in n))
- (js-bound (take in n)))
- (ifn (equal? js-free1 js-free2)
- 0
- (apply array-ref x (append is-free js-bound)))))
+ (lambda (xs i j n-free-dims-y n-bound-dims)
+ (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+ (receive (j-free j-bound) (euclidean/ j n-bound-dims)
+ (ifn (= i-x j-free)
+ 0
+ (array-ref (array-contents (cadr xs))
+ (+ i-y (* n-free-dims-y j-bound)))))))
+ (lambda (xs i j n-free-dims-y n-bound-dims)
+ (receive (i-x i-y) (euclidean/ i n-free-dims-y)
+ (receive (j-bound j-free) (euclidean/ j n-free-dims-y)
+ (ifn (= i-y j-free)
+ 0
+ (array-ref (array-contents (car xs))
+ (+ j-bound (* n-bound-dims i-x)))))))
#f)
(list
contract-arrays
- (lambda (x y n) (- (array-rank x) n))
- (lambda (x y n) (- (array-rank y) n)))
+ (lambda (x y n)
+ (apply * (drop (array-dimensions y)
+ n)))
+ (lambda (x y n)
+ (apply * (take (array-dimensions y)
+ n))))
x y n))
+
+(define (amap2 f x y)
+ (apply make-batch
+ (list-tabulate (car (dims-of (unwrap-fwd x)))
+ (lambda (b)
+ (f (i:array-cell-ref x b)
+ (i:array-cell-ref y b))))))