(*atype*
adot
amap2
+ contract-arrays
differentiable-wrapper
dot
extend
maximum
mean
rank-of
- sum)
+ sum
+ )
#:replace
((i:sqrt . sqrt)
(i:exp . exp)
n))
(drop dims-b n)))))
+(define(do-times n proc)
+ (let rec ((i 0))
+ (unless (= i n)
+ (proc i)
+ (rec (1+ i)))))
+
(define (contract-arrays a b n)
- (let* ((dims (contracted-dims a b n))
- (dims-b (array-dimensions b))
- (nb-dims-a (array-rank a))
- (nb-dims-b (array-rank b))
- (nb-fix-dims-a (- nb-dims-a n))
- (r (apply make-typed-array *atype* *unspecified* dims)))
- (for-indices-in-range
- (lambda r-indices
- (apply
- array-set!
- r
- (let ((s 0))
- (for-indices-in-range
- (lambda free-indices
- (set! s (+ s (* (apply
- array-ref
- a
- (append (take r-indices nb-fix-dims-a)
- free-indices))
- (apply
- array-ref
- b
- (append free-indices
- (drop r-indices nb-fix-dims-a)))))))
- (list-zeros n)
- (take dims-b n))
- s)
- r-indices))
- (list-zeros (length dims))
- dims)
+ (let* ((dims-a (array-dimensions a))
+ (dims-b (array-dimensions b))
+ (free-dims-a (take dims-a (- (array-rank a) n)))
+ (free-dims-b (drop dims-b n))
+ (bound-dims (take dims-b n))
+ (n-free-dims-a (apply * free-dims-a))
+ (n-free-dims-b (apply * free-dims-b))
+ (n-bound-dims (apply * bound-dims))
+ (s 0)
+ (r (apply make-typed-array *atype* *unspecified* (append free-dims-a
+ free-dims-b)))
+ (ac (array-contents a))
+ (bc (array-contents b))
+ (rc (array-contents r)))
+ (do-times
+ n-free-dims-a
+ (lambda (i)
+ (let ((i-k (* n-bound-dims i))
+ (i-j (* n-free-dims-b i)))
+ (do-times
+ n-free-dims-b
+ (lambda (j)
+ (set! s 0)
+ (do-times
+ n-bound-dims
+ (lambda (k)
+ (set! s (+ s (* (array-ref ac (+ i-k k))
+ (array-ref bc (+ (* n-free-dims-b k) j)))))))
+ (array-set! rc s (+ i-j j)))))))
r))
;;;; utilities that work on both numbers and arrays
(forward internal-forward)
(jacobian internal-jacobian))
-;; (define *atype* 'f64)
+;;(define *atype* 'f32)
(define *atype* #t)
(define *grad* (make-parameter #f))