MNISTNet

Usage

use MNISTNet;

or

import MNISTNet;
config param layerDebug = false
type dtype = real(32)
class CNN : Module(?)
var conv1 : owned(Conv2D(eltType))
var conv2 : owned(Conv2D(eltType))
var dropout1 : owned(Dropout(eltType))
var dropout2 : owned(Dropout(eltType))
var flatten : owned(Flatten(eltType))
var fc1 : owned(Linear(eltType))
var fc2 : owned(Linear(eltType))
proc init(type eltType = dtype)
override proc forward(input: Tensor(eltType)) : Tensor(eltType)
config const diag = false
var cnn = new CNN(dtype)
var model = Network.loadModel(specFile = "../scripts/models/cnn/specification.json", weightsFolder = "../scripts/models/cnn/", dtype = dtype)
config const testImgSize = 28
var img = Tensor.load("data/datasets/mnist/image_idx_0_7_7.chdata") : dtype
const modelPath = "data/models/mnist_cnn/"
var output = cnn(img)
config const imageCount = 0
var images = forall i in 0..<imageCount do Tensor.load("data/datasets/mnist/image_idx_" + i : string + ".chdata") : dtype
var preds : [images.domain] int
config const numTimes = 1
config const printResults = false
var cnn2 = new Sequential(real, (new Conv2D(real, channels = 1, features = 32, kernel = 3, stride = 1)?, new Conv2D(real, channels = 32, features = 64, kernel = 3, stride = 1)?, new Dropout(real, 0.25)?, new Dropout(real, 0.5)?, new Flatten(real)?, new Linear(real, 9216, 128)?, new Linear(real, 128, 10)?))