]> git.vouivredigital.com Git - vouivre.git/commitdiff
Download and load the MNIST dataset
authoradmin <admin@vouivredigital.com>
Sat, 25 Nov 2023 08:22:06 +0000 (17:22 +0900)
committeradmin <admin@vouivredigital.com>
Sat, 25 Nov 2023 08:22:06 +0000 (17:22 +0900)
mnist.scm [new file with mode: 0644]

diff --git a/mnist.scm b/mnist.scm
new file mode 100644 (file)
index 0000000..eabbe5e
--- /dev/null
+++ b/mnist.scm
@@ -0,0 +1,95 @@
+;;;; Copyright (C) 2023 Vouivre Digital Corporation
+;;;;
+;;;; This file is part of Vouivre.
+;;;;
+;;;; Vouivre is free software: you can redistribute it and/or
+;;;; modify it under the terms of the GNU General Public
+;;;; License as published by the Free Software Foundation, either
+;;;; version 3 of the License, or (at your option) any later version.
+;;;;
+;;;; Vouivre is distributed in the hope that it will be useful,
+;;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
+;;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+;;;; General Public License for more details.
+;;;;
+;;;; You should have received a copy of the GNU General Public
+;;;; License along with Vouivre. If not, see <https://www.gnu.org/licenses/>.
+
+(define-module (vouivre mnist)
+  #:use-module (guix build download)
+  #:use-module (guix build utils)
+  #:use-module (ice-9 binary-ports)
+  #:use-module (rnrs bytevectors)
+  #:use-module (srfi srfi-1)
+  #:use-module (vouivre misc)
+  #:use-module (web uri)
+  #:export (load-mnist))
+
+;; NOTE: The directory and url must not include any trailing '/' character.
+(define directory "mnist")
+(define url "http://yann.lecun.com/exdb/mnist")
+(define trn-imgs-fname "train-images-idx3-ubyte")
+(define trn-lbls-fname "train-labels-idx1-ubyte")
+
+(define (exists? fname)
+  "Return `#t' if the file with the given name exists and `#f' otherwise."
+  (catch 'system-error
+    (lambda ()
+      (with-input-from-file fname
+       (lambda () #t)
+       #:binary #t))
+    (lambda _ #f)))
+
+(define (load-mnist nb-items download?)
+  "Return the given number of data points from the MNIST dataset downloading it
+if needed and requested in the ./mnist directory.
+
+The data is a cons cell containing an array (nb-items, height, width) of
+training images and an array (nb-items) of corresponding labels."
+  (define (read-uint bytes)
+    (bytevector-uint-ref (get-bytevector-n (current-input-port)
+                                          bytes)
+                        0
+                        (endianness big)
+                        bytes))  
+  (apply
+   cons
+   (map
+    (lambda (base-name magic rank)
+      (let ((fname (string-append directory "/" base-name)))
+       (let redo ((download? download?))
+         (if (exists? fname)
+             (with-input-from-file fname
+               (lambda ()
+                 (when (not (= magic (read-uint 4)))
+                   (error "Unsupported file magic number."))
+                 (let* ((n (min nb-items (read-uint 4)))
+                        (dims (list-tabulate rank (lambda (x) (read-uint 4))))
+                        (n-dims (apply * n dims))
+                        (a (apply make-typed-array 'u8 0 n dims))
+                        (ac (array-contents a)))
+                   (let lp ((i 0))
+                     (if (= i n-dims)
+                         a
+                         (begin
+                           (array-set! ac (read-uint 1) i)
+                           (lp (1+ i)))))))
+               #:binary #t)
+             (ifn download?
+                  (error (string-append "The MNIST dataset doesn't exist. If you tried with `download?' to `#t' already, to no avail, you can download the files manually from " url ", and extract them to a \"mnist\" directory at the root of the project. You can also file a bug report."))
+                  (let ((gzname (string-append fname ".gz")))
+                    (invoke "mkdir" "-p" directory)
+                    (call-with-output-file gzname
+                      (lambda (port)
+                        (put-bytevector
+                         port
+                         (get-bytevector-all
+                          (http-fetch
+                           (string->uri
+                            (string-append url "/" base-name ".gz"))))))
+                      #:binary #t)
+                    (invoke "gunzip" gzname)
+                    (redo #f)))))))
+    (list trn-imgs-fname trn-lbls-fname)
+    (list 2051 2049)
+    (list 2 0))))