library(tidyverse)
library(data.table)
library(keras)
library(caret)
library(DT)
library(caretEnsemble)
library(tictoc)
# Load the labelled training data.
train_data <- fread("data/train.csv")

# Keep a reproducible random subset of 5000 training rows; everything
# below works on this subset.
set.seed(100)
N <- nrow(train_data)
sample_one <- sample(N, 5000)
train_data <- train_data[sample_one]

# Load the test data.
test_data <- fread("data/test.csv")
MNIST Digits Recognition
Read data and Load libraries
Frequency of digits
# Bar chart with one bar per distinct value of `label`, showing how many
# rows of `train_data` carry that label.
ggplot(data = train_data, mapping = aes(x = factor(label))) +
  geom_bar()
Randomly sample 12 digits
# Image coordinates: one (x, y) pair for each cell of a 28 x 28 grid.
# x cycles 1..28 fastest while y steps from 28 down to 1, so the first
# 28 pixels are drawn on the top row of the plot.
pixel_grid <- expand.grid(1:28, 28:1)
xy_axis <- data.frame(x = pixel_grid[, 1], y = pixel_grid[, 2])

# Get 12 images: draw 12 random rows and drop the first column
# (presumably `label`; the remaining columns are used as pixel values).
# NOTE(review): the object is called `sample_10` but holds 12 images.
set.seed(100)
sample_10 <- train_data[sample(.N, 12), -1] %>% as.matrix()

# Interactive table of the sampled rows; scroll horizontally because the
# table is one column per pixel.
datatable(sample_10, options = list(scrollX = TRUE))

# Transpose so that every image becomes a column and every pixel a row,
# matching the row-per-pixel layout of `xy_axis`.
sample_10 <- t(sample_10)

# Attach the coordinates to the pixel values and keep the row names in a
# `pixel` column.
plot_data <- cbind(xy_axis, sample_10)
setDT(plot_data, keep.rownames = "pixel")

# Observe the first records
head(plot_data) %>% datatable()
Plot 12 digits
# Reshape to long format: one row per (pixel, image) pair. The image
# identifier ends up in `variable` and the pixel intensity in `value`,
# which are the columns the plot below facets and fills by.
plot_data_m <- melt(plot_data, id.vars = c("pixel", "x", "y"))

# Plot the image using ggplot(): one raster facet per sampled image,
# dark ink on a white background, with every axis and panel decoration
# removed.
ggplot(plot_data_m, aes(x, y, fill = value)) +
  geom_raster() +
  facet_wrap(~variable) +
  scale_fill_gradient(low = "white", high = "black", guide = "none") +
  theme(
    axis.line = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    axis.title = element_blank(),
    panel.background = element_blank(),
    panel.border = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    plot.background = element_blank()
  )
Prepare data for model fitting
- Hold out 25% of the sampled training data as our own labelled test set, and fit the model on the remaining 75%.
# Split the sampled data: 75% of the rows go to model fitting and the
# remaining 25% are kept aside as our own labelled test set.
# NOTE(review): no seed is set immediately before this split; it relies
# on the RNG state left by the earlier set.seed() call.
N <- nrow(train_data)
sample_train <- sample(N, size = round(0.75 * N))
test_own <- train_data[-sample_train]
train_data2 <- train_data[sample_train, ]

# One-hot encode the labels into 10 classes.
train_y <- to_categorical(train_data2$label, 10)

# Predictors: every column except the first, converted to a matrix and
# divided by 255 (presumably rescaling 0-255 pixel intensities to [0, 1]).
train_x <- train_data2[, -1] %>%
  as.matrix()
train_x <- train_x / 255
Construct model layers
# Fully connected classifier on 784 inputs:
#   784 -> 784 -> 392 ReLU hidden layers, each followed by batch
#   normalisation and dropout (0.3, 0.3, 0.2), then a 10-unit softmax
#   output layer giving one probability per class.
model <- keras_model_sequential() %>%
  layer_dense(units = 784, activation = "relu", input_shape = c(784)) %>%
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 784, activation = "relu") %>%
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 392, activation = "relu") %>%
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.2) %>%
  layer_dense(units = 10, activation = "softmax")
# Comp
Compile model
# Compile the model: categorical cross-entropy loss (the targets are
# one-hot encoded), Adam optimiser, accuracy reported during training.
# `learning_rate` replaces the deprecated `lr` argument name.
model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_adam(learning_rate = 0.001), # Start with a default learning rate
  metrics = c("accuracy")
)
Fit model
# Learning rate schedule: keep the rate unchanged while `epoch` is below
# 10, afterwards multiply the current rate by exp(-0.1) on every call.
#
# epoch: epoch index passed in by the scheduler callback (the training
#        log shows the first reduced rate at "Epoch 11", i.e. the index
#        is zero-based).
# lr:    learning rate currently in use.
#
# Returns the learning rate to use for this epoch.
lr_schedule <- function(epoch, lr) {
  if (epoch < 10) {
    lr
  } else {
    lr * exp(-0.1)
  }
}
# Add the callback for the learning rate scheduler: wrap `lr_schedule`
# in a Keras callback and collect it in the list handed to fit().
callbacks_list <- list(
  callback_learning_rate_scheduler(schedule = lr_schedule)
)
# Train the model for 30 epochs in batches of 128, holding out 20% of
# the training rows for validation and applying the callbacks in
# `callbacks_list`. The returned training history is kept in `hist`.
hist <- model %>% fit(
  train_x, train_y,
  epochs = 30,
  batch_size = 128,
  validation_split = 0.2,
  callbacks = callbacks_list
)
Epoch 1/30
24/24 - 1s - loss: 0.7792 - accuracy: 0.7577 - val_loss: 1.2017 - val_accuracy: 0.6920 - lr: 0.0010 - 1s/epoch - 54ms/step
Epoch 2/30
24/24 - 0s - loss: 0.2316 - accuracy: 0.9210 - val_loss: 0.9687 - val_accuracy: 0.7400 - lr: 0.0010 - 269ms/epoch - 11ms/step
Epoch 3/30
24/24 - 0s - loss: 0.1203 - accuracy: 0.9627 - val_loss: 0.7605 - val_accuracy: 0.8307 - lr: 0.0010 - 260ms/epoch - 11ms/step
Epoch 4/30
24/24 - 0s - loss: 0.0769 - accuracy: 0.9747 - val_loss: 0.6542 - val_accuracy: 0.8267 - lr: 0.0010 - 259ms/epoch - 11ms/step
Epoch 5/30
24/24 - 0s - loss: 0.0651 - accuracy: 0.9773 - val_loss: 0.6276 - val_accuracy: 0.8093 - lr: 0.0010 - 254ms/epoch - 11ms/step
Epoch 6/30
24/24 - 0s - loss: 0.0527 - accuracy: 0.9837 - val_loss: 0.4751 - val_accuracy: 0.8573 - lr: 0.0010 - 255ms/epoch - 11ms/step
Epoch 7/30
24/24 - 0s - loss: 0.0388 - accuracy: 0.9910 - val_loss: 0.3972 - val_accuracy: 0.8893 - lr: 0.0010 - 255ms/epoch - 11ms/step
Epoch 8/30
24/24 - 0s - loss: 0.0278 - accuracy: 0.9933 - val_loss: 0.3200 - val_accuracy: 0.9080 - lr: 0.0010 - 254ms/epoch - 11ms/step
Epoch 9/30
24/24 - 0s - loss: 0.0250 - accuracy: 0.9937 - val_loss: 0.3428 - val_accuracy: 0.9000 - lr: 0.0010 - 251ms/epoch - 10ms/step
Epoch 10/30
24/24 - 0s - loss: 0.0252 - accuracy: 0.9923 - val_loss: 0.2699 - val_accuracy: 0.9200 - lr: 0.0010 - 252ms/epoch - 11ms/step
Epoch 11/30
24/24 - 0s - loss: 0.0148 - accuracy: 0.9970 - val_loss: 0.2602 - val_accuracy: 0.9267 - lr: 9.0484e-04 - 255ms/epoch - 11ms/step
Epoch 12/30
24/24 - 0s - loss: 0.0146 - accuracy: 0.9973 - val_loss: 0.2578 - val_accuracy: 0.9267 - lr: 8.1873e-04 - 254ms/epoch - 11ms/step
Epoch 13/30
24/24 - 0s - loss: 0.0131 - accuracy: 0.9970 - val_loss: 0.2626 - val_accuracy: 0.9320 - lr: 7.4082e-04 - 255ms/epoch - 11ms/step
Epoch 14/30
24/24 - 0s - loss: 0.0143 - accuracy: 0.9973 - val_loss: 0.2580 - val_accuracy: 0.9320 - lr: 6.7032e-04 - 257ms/epoch - 11ms/step
Epoch 15/30
24/24 - 0s - loss: 0.0128 - accuracy: 0.9947 - val_loss: 0.2379 - val_accuracy: 0.9360 - lr: 6.0653e-04 - 256ms/epoch - 11ms/step
Epoch 16/30
24/24 - 0s - loss: 0.0111 - accuracy: 0.9973 - val_loss: 0.2523 - val_accuracy: 0.9373 - lr: 5.4881e-04 - 253ms/epoch - 11ms/step
Epoch 17/30
24/24 - 0s - loss: 0.0087 - accuracy: 0.9977 - val_loss: 0.2548 - val_accuracy: 0.9333 - lr: 4.9659e-04 - 257ms/epoch - 11ms/step
Epoch 18/30
24/24 - 0s - loss: 0.0104 - accuracy: 0.9977 - val_loss: 0.2530 - val_accuracy: 0.9387 - lr: 4.4933e-04 - 257ms/epoch - 11ms/step
Epoch 19/30
24/24 - 0s - loss: 0.0052 - accuracy: 1.0000 - val_loss: 0.2474 - val_accuracy: 0.9427 - lr: 4.0657e-04 - 255ms/epoch - 11ms/step
Epoch 20/30
24/24 - 0s - loss: 0.0046 - accuracy: 0.9993 - val_loss: 0.2475 - val_accuracy: 0.9387 - lr: 3.6788e-04 - 256ms/epoch - 11ms/step
Epoch 21/30
24/24 - 0s - loss: 0.0050 - accuracy: 0.9993 - val_loss: 0.2397 - val_accuracy: 0.9440 - lr: 3.3287e-04 - 256ms/epoch - 11ms/step
Epoch 22/30
24/24 - 0s - loss: 0.0067 - accuracy: 0.9990 - val_loss: 0.2428 - val_accuracy: 0.9480 - lr: 3.0119e-04 - 259ms/epoch - 11ms/step
Epoch 23/30
24/24 - 0s - loss: 0.0043 - accuracy: 0.9993 - val_loss: 0.2456 - val_accuracy: 0.9467 - lr: 2.7253e-04 - 255ms/epoch - 11ms/step
Epoch 24/30
24/24 - 0s - loss: 0.0042 - accuracy: 0.9990 - val_loss: 0.2488 - val_accuracy: 0.9493 - lr: 2.4660e-04 - 259ms/epoch - 11ms/step
Epoch 25/30
24/24 - 0s - loss: 0.0030 - accuracy: 1.0000 - val_loss: 0.2514 - val_accuracy: 0.9480 - lr: 2.2313e-04 - 252ms/epoch - 10ms/step
Epoch 26/30
24/24 - 0s - loss: 0.0038 - accuracy: 0.9997 - val_loss: 0.2507 - val_accuracy: 0.9467 - lr: 2.0190e-04 - 256ms/epoch - 11ms/step
Epoch 27/30
24/24 - 0s - loss: 0.0026 - accuracy: 1.0000 - val_loss: 0.2489 - val_accuracy: 0.9480 - lr: 1.8268e-04 - 253ms/epoch - 11ms/step
Epoch 28/30
24/24 - 0s - loss: 0.0038 - accuracy: 0.9990 - val_loss: 0.2530 - val_accuracy: 0.9467 - lr: 1.6530e-04 - 259ms/epoch - 11ms/step
Epoch 29/30
24/24 - 0s - loss: 0.0029 - accuracy: 0.9997 - val_loss: 0.2565 - val_accuracy: 0.9467 - lr: 1.4957e-04 - 257ms/epoch - 11ms/step
Epoch 30/30
24/24 - 0s - loss: 0.0036 - accuracy: 0.9997 - val_loss: 0.2558 - val_accuracy: 0.9467 - lr: 1.3534e-04 - 253ms/epoch - 11ms/step
# Plot the `hist` object produced by the fit() call.
plot(hist)
Own test
# Prepare the held-out rows the same way as the training predictors:
# drop the first column, convert to a matrix and divide by 255.
test_own_x <- as.matrix(test_own[, -1]) / 255

# Predicted class probabilities: one row per held-out image, one column
# per output unit of the model.
test_own_pred <- model %>% predict(test_own_x)
40/40 - 0s - 146ms/epoch - 4ms/step
# Turn each row of probabilities into a predicted digit: which.max()
# gives the 1-based index of the largest probability, so subtract 1 to
# get labels 0-9.
test_own_pred <- apply(test_own_pred, 1, which.max) - 1

# Compare predictions with the true labels. Both factors get the full
# 0-9 level set so the table stays 10 x 10 and the levels always match,
# even if some digit is never predicted or is absent from the hold-out.
confusionMatrix(
  data = factor(test_own_pred, levels = 0:9),
  reference = factor(test_own$label, levels = 0:9)
)
Confusion Matrix and Statistics
Reference
Prediction 0 1 2 3 4 5 6 7 8 9
0 122 0 1 0 1 0 0 1 0 1
1 0 133 0 0 0 0 0 0 0 0
2 0 2 104 1 1 0 0 0 0 0
3 0 0 3 119 0 4 0 0 4 5
4 0 1 0 2 118 0 0 0 0 4
5 0 0 0 2 0 95 1 0 0 0
6 0 0 0 0 0 1 147 0 1 0
7 1 0 3 0 0 1 0 136 0 2
8 2 0 1 3 0 1 1 0 111 1
9 0 0 0 0 3 0 0 4 2 104
Overall Statistics
Accuracy : 0.9512
95% CI : (0.9378, 0.9625)
No Information Rate : 0.1192
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.9457
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
Sensitivity 0.9760 0.9779 0.9286 0.9370 0.9593 0.9314
Specificity 0.9964 1.0000 0.9965 0.9858 0.9938 0.9974
Pos Pred Value 0.9683 1.0000 0.9630 0.8815 0.9440 0.9694
Neg Pred Value 0.9973 0.9973 0.9930 0.9928 0.9956 0.9939
Prevalence 0.1000 0.1088 0.0896 0.1016 0.0984 0.0816
Detection Rate 0.0976 0.1064 0.0832 0.0952 0.0944 0.0760
Detection Prevalence 0.1008 0.1064 0.0864 0.1080 0.1000 0.0784
Balanced Accuracy 0.9862 0.9890 0.9625 0.9614 0.9766 0.9644
Class: 6 Class: 7 Class: 8 Class: 9
Sensitivity 0.9866 0.9645 0.9407 0.8889
Specificity 0.9982 0.9937 0.9920 0.9921
Pos Pred Value 0.9866 0.9510 0.9250 0.9204
Neg Pred Value 0.9982 0.9955 0.9938 0.9886
Prevalence 0.1192 0.1128 0.0944 0.0936
Detection Rate 0.1176 0.1088 0.0888 0.0832
Detection Prevalence 0.1192 0.1144 0.0960 0.0904
Balanced Accuracy 0.9924 0.9791 0.9664 0.9405