MNISTNet
Usage
use MNISTNet;
or
import MNISTNet;
- config param layerDebug = false
- type dtype = real(32)
- class CNN : Module(?)
- var conv1 : owned(Conv2D(eltType))
- var conv2 : owned(Conv2D(eltType))
- var dropout1 : owned(Dropout(eltType))
- var dropout2 : owned(Dropout(eltType))
- var flatten : owned(Flatten(eltType))
- var fc1 : owned(Linear(eltType))
- var fc2 : owned(Linear(eltType))
- proc init(type eltType = dtype)
- override proc forward(input: Tensor(eltType)) : Tensor(eltType)
- config const diag = false
- var cnn = new CNN(dtype)
- var model = Network.loadModel(specFile = "../scripts/models/cnn/specification.json", weightsFolder = "../scripts/models/cnn/", dtype = dtype)
- config const testImgSize = 28
- var img = Tensor.load("data/datasets/mnist/image_idx_0_7_7.chdata") : dtype
- const modelPath = "data/models/mnist_cnn/"
- var output = cnn(img)
- config const imageCount = 0
- var images = forall i in 0..<imageCount do Tensor.load("data/datasets/mnist/image_idx_" + i : string + ".chdata") : dtype
- var preds : [images.domain] int
- config const numTimes = 1
- config const printResults = false
- var cnn2 = new Sequential(real, (new Conv2D(real, channels = 1, features = 32, kernel = 3, stride = 1)?, new Conv2D(real, channels = 32, features = 64, kernel = 3, stride = 1)?, new Dropout(real, 0.25)?, new Dropout(real, 0.5)?, new Flatten(real)?, new Linear(real, 9216, 128)?, new Linear(real, 128, 10)?))