Browse Source

first commit!

killua 5 years ago
commit
679483b9e6
5 changed files with 145 additions and 0 deletions
  1. 28 0
      Array
  2. 98 0
      CNN.py
  3. 19 0
      Converter.py
  4. BIN
      mnistCNN.mlmodel
  5. BIN
      mnist_cnn_keras.h5

+ 28 - 0
Array

@@ -0,0 +1,28 @@
+[[ 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3, 18,  18,  18, 126, 136, 175,  26, 166, 255, 247, 127,  0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,  30,  36,  94, 154, 170, 253, 253, 253, 253, 253, 225, 172, 253, 242, 195, 64,  0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,  49, 238, 253, 253, 253, 253, 253, 253, 253, 253, 251,  93,  82,  82,  56,  39, 0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,  18, 219, 253, 253, 253, 253, 253, 198, 182, 247, 241,   0,   0,   0,   0,   0, 0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,  80, 156, 107, 253, 253, 205,  11,   0,  43, 154,   0,   0,   0,   0,   0, 0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,  14,   1, 154, 253, 90,   0,   0,   0,   0,    0,   0,   0,   0,   0, 0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 139, 253, 190,  2,   0,   0,   0,    0,   0,   0,   0,   0, 0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  11, 190, 253,  70,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,  0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  35, 241, 225, 160, 108,   1,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 81, 240, 253, 253, 119,  25,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,  45, 186, 253, 253, 150,  27,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,  16,  93, 252, 253, 187,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0, 249, 253, 249,  64,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,  46, 130, 183, 253, 253, 207,   2,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  39, 148, 229, 253, 253, 253, 250, 182,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  24, 114, 221, 253, 253, 253, 253, 201,  78,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,  23,  66, 213, 253, 253, 253, 253, 198,  81,   2,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,  18, 171, 219, 253, 253, 253, 253, 195,  80,   9,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,  55, 172, 226, 253, 253, 253, 253, 244, 133, 11,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0, 136, 253, 253, 253, 212, 135, 132,  16,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0],
+[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0]]S

+ 98 - 0
CNN.py

@@ -0,0 +1,98 @@
+# coding=utf8
+
+import numpy as np
+from keras.datasets import mnist
+from keras.utils import np_utils
+from keras.models import Sequential
+from keras.layers import Dense, Activation, Convolution2D, MaxPooling2D, Flatten
+from keras.optimizers import Adam
+from keras import models
+
+np.random.seed(1337)  # for reproducibility
+
+# 加载数据集,下载的数据保存在'~/.keras/datasets/'
+(X_train, y_train), (X_test, y_test) = mnist.load_data()
+# print("X_train.shape", X_train.shape, "y_train.shape", y_train.shape)
+# print("X_train", X_train[0], "y_train", y_train[0])
+
+# 处理数据(对数据进行归一化)
+X_train = X_train.reshape(-1, 28, 28, 1) / 255.  # 为了让激励函数更加有效
+X_test = X_test.reshape(-1, 28, 28, 1) / 255.  #
+y_train = np_utils.to_categorical(y_train, num_classes=10)
+y_test = np_utils.to_categorical(y_test, num_classes=10)
+# print("y_train.shape", y_train.shape, "y_train", y_train[0])
+
+# 建立一个model
+model = Sequential()
+
+# 添加第一层卷积层
+model.add(Convolution2D(
+    batch_input_shape=(None, 28, 28, 1),  # 输入源的shape
+    filters=32,  # 过滤器的数量(卷积核)
+    kernel_size=5,  # 过滤器的大小
+    strides=1,  # 过滤器移动的步长
+    padding='same',  # Padding的方法
+))
+print(model.output)
+
+# 给model添加一个激励函数
+model.add(Activation('relu'))
+
+# 添加一个(max pooling) output shape (32, 14, 14)
+model.add(MaxPooling2D(
+    pool_size=2,  # 池化层的大小
+    strides=2,  # 池化移动的步长
+    padding='same',  # Padding的方法
+))
+print(model.output)
+
+# 添加第二层卷积层 output shape (64, 14, 14)
+model.add(Convolution2D(
+    filters=64,
+    kernel_size=5,
+    strides=1,
+    padding='same'
+))
+print(model.output)
+
+# 给model添加一个激励函数
+model.add(Activation('relu'))
+
+# # 添加一个(max pooling) output shape (64, 7, 7)
+model.add(MaxPooling2D(
+    pool_size=2,
+    strides=2,
+    padding='same',
+))
+print(model.output)
+
+# 设置第一个全连接层
+model.add(Flatten())  # Flatten层用来将输入“压平”
+
+print(model.output)
+
+model.add(Dense(1024))  # 全连接层
+model.add(Activation('relu'))  # 给model添加一个激励函数
+
+# 设置第二个全连接层
+model.add(Dense(10))
+model.add(Activation('softmax'))
+
+# 定义一个优化器,设置一个学习效率
+adam = Adam(lr=1e-4)
+
+# 编译模型指定优化器,损失函数及评价指标
+model.compile(optimizer=adam,
+              loss='categorical_crossentropy',
+              metrics=['accuracy'])
+
+print('Training ------------')
+model.fit(X_train, y_train, epochs=5, batch_size=64)
+
+print('\nTesting ------------')
+loss, accuracy = model.evaluate(X_test, y_test)
+
+print('\ntest loss: ', loss)
+print('\ntest accuracy: ', accuracy)
+
+models.save_model(model, 'mnist_cnn_keras.h5')

+ 19 - 0
Converter.py

@@ -0,0 +1,19 @@
+import coremltools
+
+output_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
+scale = 1 / 255.
+coreml_model = coremltools.converters.keras.convert('./mnist_cnn_keras.h5',
+                                                    input_names='image',
+                                                    image_input_names='image',
+                                                    output_names='output',
+                                                    class_labels=output_labels,
+                                                    image_scale=scale)
+
+coreml_model.author = 'sky'
+coreml_model.license = 'MIT'
+coreml_model.short_description = 'Model to classify hand written digit'
+
+coreml_model.input_description['image'] = 'Grayscale image of hand written digit'
+coreml_model.output_description['output'] = 'Predicted digit'
+
+coreml_model.save('mnistCNN.mlmodel')

BIN
mnistCNN.mlmodel


BIN
mnist_cnn_keras.h5