Source: *Coding Chef's 3-Minute Deep Learning, Keras Flavor* (코딩셰프의 3분 딥러닝, 케라스맛) by Kim Sung-jin (김성진)
코드
# DNN_CIFAR-10 / processing data more complex than MNIST (R, G, B channels)
import matplotlib.pyplot as plt
import numpy as np
from keras import datasets
from keras import layers, models
from keras.utils import to_categorical

from ex_2_1_keras_ann import plot_acc, plot_loss


class DNN(models.Sequential):
    """Fully connected classifier with two ReLU hidden layers and dropout.

    Args:
        Nin: number of input features (length of each flattened image).
        Nh_l: sizes of the two hidden layers, e.g. ``[100, 50]``.
        Pd_l: dropout rates applied after each hidden layer, e.g. ``[0.0, 0.0]``.
        Nout: number of output classes (width of the softmax layer).

    The model is compiled with categorical cross-entropy and Adam, so it
    expects one-hot encoded targets.
    """

    def __init__(self, Nin, Nh_l, Pd_l, Nout):
        super().__init__()
        # First hidden layer; input_shape fixes the flattened input size.
        self.add(layers.Dense(Nh_l[0], activation='relu',
                              input_shape=(Nin,), name='Hidden-1'))
        # Dropout randomly deactivates the given fraction of units during
        # training to reduce overfitting.
        self.add(layers.Dropout(Pd_l[0]))
        self.add(layers.Dense(Nh_l[1], activation='relu', name='Hidden-2'))
        self.add(layers.Dropout(Pd_l[1]))
        self.add(layers.Dense(Nout, activation='softmax'))
        self.compile(loss='categorical_crossentropy', optimizer='adam',
                     metrics=['accuracy'])


def Data_func(num_classes=10):
    """Load CIFAR-10, flatten each image to a vector and scale pixels to [0, 1].

    Args:
        num_classes: width of the one-hot label vectors (CIFAR-10 has 10).

    Returns:
        ``((X_train, Y_train), (X_test, Y_test))``. ``X_*`` are float32 arrays
        of shape ``(samples, features)``; ``Y_*`` are one-hot label arrays of
        shape ``(samples, num_classes)``.
    """
    (X_train, y_train), (X_test, y_test) = datasets.cifar10.load_data()

    # Convert integer labels 0..9 into one-hot vectors:
    # 1 -> 0100000000, 2 -> 0010000000, 9 -> 0000000001.
    # num_classes is passed explicitly so the train and test widths always
    # agree, even if a split happens to lack the highest label.
    Y_train = to_categorical(y_train, num_classes)
    Y_test = to_categorical(y_test, num_classes)

    # Flatten every image (all pixel and channel values) into one feature row.
    X_train = X_train.reshape(X_train.shape[0], -1)
    X_test = X_test.reshape(X_test.shape[0], -1)

    # Scale to [0, 1] in float32. Dividing the uint8 array directly by 255.0
    # would produce float64 and double the memory footprint.
    X_train = X_train.astype(np.float32) / 255.0
    X_test = X_test.astype(np.float32) / 255.0
    return (X_train, Y_train), (X_test, Y_test)


def main():
    """Train the DNN on CIFAR-10, print test loss/accuracy and plot curves."""
    Nh_l = [100, 50]
    Pd_l = [0.0, 0.0]
    number_of_class = 10
    Nout = number_of_class

    (X_train, Y_train), (X_test, Y_test) = Data_func(number_of_class)
    model = DNN(X_train.shape[1], Nh_l, Pd_l, Nout)
    history = model.fit(X_train, Y_train, epochs=10, batch_size=100,
                        validation_split=0.2)

    performance_test = model.evaluate(X_test, Y_test, batch_size=100)
    print('Test Loss and Accuracy ->', performance_test)

    plot_loss(history)
    plt.show()
    plot_acc(history)
    plt.show()


if __name__ == '__main__':
    main()
실행 결과
'IT > 머신러닝' 카테고리의 다른 글
Ubuntu 18.04 LTS에서 Tensorflow Gpu 설치 ( CUDA10.0, cuDNN v7.3.1 ) (5) | 2018.10.08 |
---|---|
[Keras] CNN 기본예제(mnist) (0) | 2018.09.19 |
[Keras] DNN 기본 예제 (0) | 2018.08.31 |
[Keras] ANN 기본 예제 (0) | 2018.08.31 |
[Keras] 기본 예제 (0) | 2018.08.31 |