更新時(shí)間:2023-03-07 來(lái)源:黑馬程序員 瀏覽量:
在sklearn中,模型都是現(xiàn)成的。tf.Keras是一個(gè)神經(jīng)網(wǎng)絡(luò)庫(kù),我們需要根據(jù)數(shù)據(jù)和標(biāo)簽值構(gòu)建神經(jīng)網(wǎng)絡(luò)。神經(jīng)網(wǎng)絡(luò)可以發(fā)現(xiàn)特征與標(biāo)簽之間的復(fù)雜關(guān)系。神經(jīng)網(wǎng)絡(luò)是一個(gè)高度結(jié)構(gòu)化的圖,其中包含一個(gè)或多個(gè)隱藏層。每個(gè)隱藏層都包含一個(gè)或多個(gè)神經(jīng)元。神經(jīng)網(wǎng)絡(luò)有多種類別,該程序使用的是密集型神經(jīng)網(wǎng)絡(luò),也稱為全連接神經(jīng)網(wǎng)絡(luò):一個(gè)層中的神經(jīng)元將從上一層中的每個(gè)神經(jīng)元獲取輸入連接。例如,圖 2 顯示了一個(gè)密集型神經(jīng)網(wǎng)絡(luò),其中包含 1 個(gè)輸入層、2 個(gè)隱藏層以及 1 個(gè)輸出層,如下圖所示:
上圖 中的模型經(jīng)過訓(xùn)練并饋送未標(biāo)記的樣本時(shí),它會(huì)產(chǎn)生 3 個(gè)預(yù)測(cè)結(jié)果:相應(yīng)鳶尾花屬于指定品種的可能性。對(duì)于該示例,輸出預(yù)測(cè)結(jié)果的總和是 1.0。該預(yù)測(cè)結(jié)果分解如下:山鳶尾為 0.02,變色鳶尾為 0.95,維吉尼亞鳶尾為 0.03。這意味著該模型預(yù)測(cè)某個(gè)無(wú)標(biāo)簽鳶尾花樣本是變色鳶尾的概率為 95%。
TensorFlow tf.keras API 是創(chuàng)建模型和層的首選方式。通過該 API,您可以輕松地構(gòu)建模型并進(jìn)行實(shí)驗(yàn),而將所有部分連接在一起的復(fù)雜工作則由 Keras 處理。
tf.keras.Sequential 模型是層的線性堆疊。該模型的構(gòu)造函數(shù)會(huì)采用一系列層實(shí)例;在本示例中,采用的是 2 個(gè)密集層(分別包含 10 個(gè)節(jié)點(diǎn))以及 1 個(gè)輸出層(包含 3 個(gè)代表標(biāo)簽預(yù)測(cè)的節(jié)點(diǎn))。第一個(gè)層的 input_shape 參數(shù)對(duì)應(yīng)該數(shù)據(jù)集中的特征數(shù)量:
# 利用sequential方式構(gòu)建模型model = Sequential([ # 隱藏層1,激活函數(shù)是relu,輸入大小有input_shape指定 Dense(10, activation="relu", input_shape=(4,)), # 隱藏層2,激活函數(shù)是relu Dense(10, activation="relu"), # 輸出層 Dense(3,activation="softmax")])
通過model.summary可以查看模型的架構(gòu):
激活函數(shù)可決定層中每個(gè)節(jié)點(diǎn)的輸出形狀。這些非線性關(guān)系很重要,如果沒有它們,模型將等同于單個(gè)層。激活函數(shù)有很多,但隱藏層通常使用 ReLU。
隱藏層和神經(jīng)元的理想數(shù)量取決于問題和數(shù)據(jù)集。與機(jī)器學(xué)習(xí)的多個(gè)方面一樣,選擇最佳的神經(jīng)網(wǎng)絡(luò)形狀需要一定的知識(shí)水平和實(shí)驗(yàn)基礎(chǔ)。一般來(lái)說,增加隱藏層和神經(jīng)元的數(shù)量通常會(huì)產(chǎn)生更強(qiáng)大的模型,而這需要更多數(shù)據(jù)才能有效地進(jìn)行訓(xùn)練。
模型訓(xùn)練和預(yù)測(cè)
在訓(xùn)練和評(píng)估階段,我們都需要計(jì)算模型的損失。這樣可以衡量模型的預(yù)測(cè)結(jié)果與預(yù)期標(biāo)簽有多大偏差,也就是說,模型的效果有多差。我們希望盡可能減小或優(yōu)化這個(gè)值,所以我們?cè)O(shè)置優(yōu)化策略和損失函數(shù),以及模型精度的計(jì)算方法:
# 設(shè)置模型的相關(guān)參數(shù):優(yōu)化器,損失函數(shù)和評(píng)價(jià)指標(biāo)mode l.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])
接下來(lái)與在sklearn中相同,分別調(diào)用fit和predict方法進(jìn)行預(yù)測(cè)即可。
# 模型訓(xùn)練:epochs,訓(xùn)練樣本送入到網(wǎng)絡(luò)中的次數(shù),batch_size:每次訓(xùn)練的送入到網(wǎng)絡(luò)中的樣本個(gè)數(shù) model.fit(train_X, train_y_ohe, epochs=10, batch_size=1, verbose=1);
上述代碼完成的是:
1. 迭代每個(gè)epoch。通過一次數(shù)據(jù)集即為一個(gè)epoch。
2. 在一個(gè)epoch中,遍歷訓(xùn)練 Dataset 中的每個(gè)樣本,并獲取樣本的特征 (x) 和標(biāo)簽 (y)。
3. 根據(jù)樣本的特征進(jìn)行預(yù)測(cè),并比較預(yù)測(cè)結(jié)果和標(biāo)簽。衡量預(yù)測(cè)結(jié)果的不準(zhǔn)確性,并使用所得的值計(jì)算模型的損失和梯度。
4. 使用 optimizer 更新模型的變量。
5. 對(duì)每個(gè)epoch重復(fù)執(zhí)行以上步驟,直到模型訓(xùn)練完成。
訓(xùn)練過程展示如下:
Epoch 1/10 75/75 [==============================] - 0s 616us/step - loss: 0.0585 - accuracy: 0.9733 Epoch 2/10 75/75 [==============================] - 0s 535us/step - loss: 0.0541 - accuracy: 0.9867 Epoch 3/10 75/75 [==============================] - 0s 545us/step - loss: 0.0650 - accuracy: 0.9733 Epoch 4/10 75/75 [==============================] - 0s 542us/step - loss: 0.0865 - accuracy: 0.9733 Epoch 5/10 75/75 [==============================] - 0s 510us/step - loss: 0.0607 - accuracy: 0.9733 Epoch 6/10 75/75 [==============================] - 0s 659us/step - loss: 0.0735 - accuracy: 0.9733 Epoch 7/10 75/75 [==============================] - 0s 497us/step - loss: 0.0691 - accuracy: 0.9600 Epoch 8/10 75/75 [==============================] - 0s 497us/step - loss: 0.0724 - accuracy: 0.9733 Epoch 9/10 75/75 [==============================] - 0s 493us/step - loss: 0.0645 - accuracy: 0.9600 Epoch 10/10 75/75 [==============================] - 0s 482us/step - loss: 0.0660 - accuracy: 0.9867
與sklearn中不同,對(duì)訓(xùn)練好的模型進(jìn)行評(píng)估時(shí),與sklearn.score方法對(duì)應(yīng)的是tf.keras.evaluate()方法,返回的是損失函數(shù)和在compile模型時(shí)要求的指標(biāo):
# 計(jì)算模型的損失和準(zhǔn)確率 loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=1) print("Accuracy = {:.2f}".format(accuracy))
分類器的準(zhǔn)確率為:
3/3 [==============================] - 0s 591us/step - loss: 0.1031 - accuracy: 0.9733 Accuracy = 0.97
到此我們對(duì)tf.kears的使用有了一個(gè)基本的認(rèn)知,在接下來(lái)的課程中會(huì)給大家解釋神經(jīng)網(wǎng)絡(luò)以及在計(jì)算機(jī)視覺中的常用的CNN的使用。