<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>
"); //-->

博客專(zhuān)欄

EEPW首頁(yè) > 博客 > 獨家 | 使用TensorFlow 2創(chuàng )建自定義損失函數

獨家 | 使用TensorFlow 2創(chuàng )建自定義損失函數

發(fā)布人:數據派THU 時(shí)間:2021-05-26 來(lái)源:工程師 發(fā)布文章

作者:Arjun Sarkar

翻譯:陳之炎

校對:歐陽(yáng)錦

1.png

神經(jīng)網(wǎng)絡(luò )利用訓練數據,將一組輸入映射成一組輸出,它通過(guò)使用某種形式的優(yōu)化算法,如梯度下降、隨機梯度下降、AdaGrad、AdaDelta等等來(lái)實(shí)現,其中最新的算法包括Adam、Nadam或RMSProp。梯度下降中的“梯度”是指誤差梯度。每次迭代之后,網(wǎng)絡(luò )將其預測輸出與實(shí)際輸出進(jìn)行比較,然后計算出“誤差”。

通常,對于神經(jīng)網(wǎng)絡(luò ),尋求的是將誤差最小化。將誤差最小化的目標函數通常稱(chēng)之為成本函數或損失函數,由“損失函數”計算出的值稱(chēng)為“損失”。在各種問(wèn)題中使用的典型損失函數有:

均方誤差;

均方對數誤差;

二元交叉熵;

分類(lèi)交叉熵;

稀疏分類(lèi)交叉熵。

Tensorflow已經(jīng)包含了上述損失函數,直接調用它們即可,如下所示:

1. 將損失函數當作字符串進(jìn)行調用

model.compile (loss = ‘binary_crossentropy’,optimizer = ‘a(chǎn)dam’, metrics = [‘a(chǎn)ccuracy’])

2. 將損失函數當作對象進(jìn)行調用

from tensorflow.keras.losses importmean_squared_error
model.compile(loss = mean_squared_error,optimizer=’sgd’)

將損失函數當作對象進(jìn)行調用的優(yōu)點(diǎn)是可以在損失函數中傳遞閾值等參數。

from tensorflow.keras.losses import mean_squared_error
model.compile (loss=mean_squared_error(param=value),optimizer = ‘sgd’)

利用現有函數創(chuàng )建自定義損失函數:

利用現有函數創(chuàng )建損失函數,首先需要定義損失函數,它將接受兩個(gè)參數,y_true(真實(shí)標簽/輸出)和y_pred(預測標簽/輸出)。

def loss_function(y_true, y_pred):
***some calculation***
return loss

創(chuàng )建均方誤差損失函數 (RMSE):

定義損失函數名稱(chēng)-my_rmse。目的是返回目標(y_true)與預測(y_pred)之間的均方誤差。

RMSE的公式為:

2.jpg

  • 誤差:真實(shí)標簽與預測標簽之間的差異。

  • sqr_error:誤差的平方。

  • mean_sqr_error:誤差平方的均值。

  • sqrt_mean_sqr_error:誤差平方均值的平方根(均方根誤差)。

3.png

創(chuàng )建Huber損失函數:

4.png圖2:Huber損失函數(綠色)和平方誤差損失函數(藍色)(來(lái)源:Qwertyus— Own work,CCBY-SA4.0,https://commons.wikimedia.org/w/index.php?curid=34836380)

Huber損失函數的計算公式:

5.jpg

在此處,δ是閾值,a是誤差(將計算出a,即實(shí)際標簽和預測標簽之間的差異)。

當|a|≤δ時(shí),loss = 1/2*(a)2

當 |a|>δ時(shí),loss = δ(|a|—(1/2)*δ)

源代碼:

6.png

詳細說(shuō)明:

首先,定義一個(gè)函數—— my huber loss,它需要兩個(gè)參數:y_true和y_pred,

設置閾值threshold = 1。

計算誤差error a = y_true-y_pred。接下來(lái),檢查誤差的絕對值是否小于或等于閾值,is_small_error返回一個(gè)布爾值(真或假)。

當|a|≤δ時(shí),loss= 1/2*(a)2,計算small_error_loss, 誤差的平方除以2。否則,當|a| >δ時(shí),則損失等于δ(|a|-(1/2)*δ),用big_error_loss來(lái)計算這個(gè)值。

最后,在返回語(yǔ)句中,首先檢查is_small_error是真還是假,如果它為真,函數返回small_error_loss,否則返回big_error_loss,使用tf.where來(lái)實(shí)現。

可以使用下述代碼來(lái)編譯模型:

7.png

在上述代碼中,將閾值設為1。

如果需要調整超參數(閾值),并在編譯過(guò)程中加入一個(gè)新的閾值的話(huà),必須使用wrapper函數進(jìn)行封裝,也就是說(shuō),將損失函數封裝成另一個(gè)外部函數。在這里需要用到封裝函數(wrapper function),因為損失函數在默認情況下只能接受y_true和y_pred值,而且不能向原始損失函數添加任何其他參數。

使用封裝后的Huber損失函數

封裝函數的源代碼:

8.png

此時(shí),閾值不是硬編碼,可以在模型編譯過(guò)程中傳遞該閾值。

9.png

使用類(lèi)實(shí)現Huber損失函數(OOP)

10.png

其中,MyHuberLoss是類(lèi)名稱(chēng),隨后從tensorflow.keras.losses繼承父類(lèi)“Loss”, MyHuberLoss繼承了Loss類(lèi),之后可以將MyHuberLoss當作損失函數來(lái)使用。

__init__   初始化該類(lèi)中的對象。執行類(lèi)實(shí)例化對象時(shí)調用函數,init函數返回閾值,調用函數得到y_true和y_pred參數,將閾值聲明為一個(gè)類(lèi)變量,可以給它賦一個(gè)初始值。

在__init__函數中,將閾值設置為self.threshold。在調用函數中,self.threshold引用所有的閾值類(lèi)變量。在model.compile中使用這個(gè)損失函數:

11.png

創(chuàng )建對比性損失(用于Siamese網(wǎng)絡(luò )):

12.jpg

Siamese網(wǎng)絡(luò )可以用來(lái)比較兩幅圖像是否相似,Siamese網(wǎng)絡(luò )使用的損失函數為對比性損失。

在上文的公式中,Y_true是關(guān)于圖像相似性細節的張量,如果圖像相似,則為1,如果圖像不相似,則為0。

D是圖像對之間的歐氏距離的張量。邊際為一個(gè)常量,用它來(lái)設置將圖像區別為相似或不同的最小距離。如果為Y_true=1,則方程的第一部分為D2,第二部分為0,所以,當Y_true接近1時(shí),D2的權重則更重。

如果Y_true=0,則方程的第一部分變?yōu)?,第二部分會(huì )產(chǎn)生一些結果,這給了最大項更多的權重,給了D平方項更少的權重,此時(shí),最大項在損失計算中占了優(yōu)勢。

使用封裝器函數實(shí)現對比損失函數:

13.png

結論

在Tensorflow中沒(méi)有的損失函數都可以利用函數、包裝函數或類(lèi)似的類(lèi)來(lái)創(chuàng )建。

原文標題:

Creating custom Loss functionsusing TensorFlow 2

原文鏈接:

https://towardsdatascience.com/creating-custom-loss-functions-using-tensorflow-2-96c123d5ce6c

*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。

負離子發(fā)生器相關(guān)文章:負離子發(fā)生器原理
離子色譜儀相關(guān)文章:離子色譜儀原理


關(guān)鍵詞: Python

相關(guān)推薦

技術(shù)專(zhuān)區

關(guān)閉
国产精品自在自线亚洲|国产精品无圣光一区二区|国产日产欧洲无码视频|久久久一本精品99久久K精品66|欧美人与动牲交片免费播放
<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>