<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>
"); //-->

博客專(zhuān)欄

EEPW首頁(yè) > 博客 > 為機器學(xué)習模型設置最佳閾值:0.5是二元分類(lèi)的最佳閾值嗎

為機器學(xué)習模型設置最佳閾值:0.5是二元分類(lèi)的最佳閾值嗎

發(fā)布人:數據派THU 時(shí)間:2022-12-23 來(lái)源:工程師 發(fā)布文章

對于二元分類(lèi),分類(lèi)器輸出一個(gè)實(shí)值分數,然后通過(guò)對該值進(jìn)行閾值的區分產(chǎn)生二元的相應。例如,邏輯回歸輸出一個(gè)概率(一個(gè)介于0.0和1.0之間的值);得分等于或高于0.5的觀(guān)察結果產(chǎn)生正輸出(許多其他模型默認使用0.5閾值)。


但是使用默認的0.5閾值是不理想的。在本文中,我將展示如何從二元分類(lèi)器中選擇最佳閾值。本文將使用Ploomber并行執行我們的實(shí)驗,并使用sklearn-evaluation生成圖。


圖片


這里以訓練邏輯回歸為例。假設我們正在開(kāi)發(fā)一個(gè)內容審核系統,模型標記包含有害內容的帖子(圖片、視頻等);然后,人工會(huì )查看并決定內容是否被刪除。


構建簡(jiǎn)單的二元分類(lèi)器


下面的代碼片段訓練我們的分類(lèi)器:

 import matplotlib.pyplot as plt import matplotlib as mpl from sklearn import datasets from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn_evaluation.plot import ConfusionMatrix
# matplotlib settings mpl.rcParams['figure.figsize'] = (4, 4) mpl.rcParams['figure.dpi'] = 150
# create sample dataset X, y = datasets.make_classification(1000, 10, n_informative=5, class_sep=0.4) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
# fit model clf = LogisticRegression() _ = clf.fit(X_train, y_train)


現在讓我們對測試集進(jìn)行預測,并通過(guò)混淆矩陣評估性能:

 # predict on the test set y_pred = clf.predict(X_test)
# plot confusion matrix cm_dot_five = ConfusionMatrix(y_test, y_pred) cm_dot_five


圖片


混淆矩陣總結了模型在四個(gè)區域的性能:


圖片


我們希望在左上和右下象限中獲得盡可能多的觀(guān)察值(從測試集),因為這些是我們的模型得到正確的觀(guān)察值。其他象限是模型錯誤。


改變模型的閾值將改變混淆矩陣中的值。在前面的示例中,使用clf.predict,返回一個(gè)二元響應(即使用0.5作為閾值);但是我們可以使用clf.predict_proba函數獲取原始概率并使用自定義閾值:

 y_score = clf.predict_proba(X_test)


我們可以通過(guò)設置一個(gè)較低的閾值(即標記更多的帖子為有害的)來(lái)讓我們的分類(lèi)器更具侵略性,并創(chuàng )建一個(gè)新的混淆矩陣:

 cm_dot_four = ConfusionMatrix(y_score[:, 1] >= 0.4, y_pred)


sklearn-evaluation庫可以輕松比較兩個(gè)矩陣:

 cm_dot_five + cm_dot_four


圖片


三角形的上面來(lái)自0.5的閾值,下面來(lái)自0.4的閾值:


  • 兩個(gè)模型對相同數量的觀(guān)測結果都預測為0(這是一個(gè)巧合)。0.5閾值:(90 + 56 = 146)。0.4閾值:(78 + 68 = 146)

  • 降低閾值會(huì )導致更多的假陰性(從56例降至68例)

  • 降低閾值將大大增加真陽(yáng)性(從92例增加154例)


微小的閾值變化極大地影響了混淆矩陣。我們只分析了兩個(gè)閾值。那么如果能夠分析跨所有值的模型性能,我們就可以好地理解閾值動(dòng)態(tài)。但是在此之前,需要定義用于模型評估的新指標。


到目前為止,我們都是用絕對數字來(lái)評估我們的模型。為了便于比較和評估,我們現在將定義兩個(gè)標準化指標(它們的值在0.0和1.0之間)。


精度precision是標記的觀(guān)察事件的比例(例如,我們的模型認為有害的帖子,它們是有害的)。召回 recall是我們的模型檢索到的實(shí)際事件的比例(即,從所有有害的帖子中,我們能夠檢測到它們的哪個(gè)比例)。


圖片


以上圖片來(lái)自維基百科,可以很好的說(shuō)明這兩個(gè)指標是如何計算的,精確度和召回率都是比例關(guān)系,所以它們都是0比1的比例。


運行實(shí)驗


我們將根據幾個(gè)閾值獲得精度、召回率和其他統計信息,以便更好地理解閾值如何影響它們。我們還將多次重復這個(gè)實(shí)驗來(lái)測量可變性。


本節中的命令都是bash命令。需要在終端中執行它們,如果使用Jupyter可以使用%%sh魔法命令。


這里使用Ploomber Cloud運行我們的實(shí)驗。因為它允許我們并行運行實(shí)驗并快速檢索結果。


創(chuàng )建了一個(gè)適合一個(gè)模型的Notebook,并為幾個(gè)閾值計算統計數據,并行執行同一個(gè)Notebook20次。

 curl -O  https://raw.githubusercontent.com/ploomber/posts/master/threshold/fit.  ipynb?utm_source=medium&utm_medium=blog&utm_campaign=threshold


讓執行這個(gè)Notebook(文件中的配置會(huì )告訴Ploomber Cloud并行運行它20次):

 ploomber cloud nb fit.ipynb


幾分鐘后,我們就會(huì )看到的20個(gè)實(shí)驗完成了:

 ploomber cloud status @latest --summary
status     count -------- ------- finished       20
Pipeline finished. Check outputs: $ ploomber cloud products


讓我們下載存儲在.csv文件中的實(shí)驗結果:

 ploomber cloud download 'threshold-selection/*.csv' --summary


可視化實(shí)驗結果


將加載所有實(shí)驗的結果,并一次性將它們繪制出來(lái)。


 from glob import glob
import pandas as pd import numpy as np paths = glob('threshold-selection/**/*.csv') metrics = [pd.read_csv(path) for path in paths]
for idx, df in enumerate(metrics):        plt.plot(df.threshold, df.precision, color='blue', alpha=0.2,              label='precision' if idx == 0 else None)    plt.plot(df.threshold, df.recall, color='green', alpha=0.2,              label='recall' if idx == 0 else None)    plt.plot(df.threshold, df.f1, color='orange', alpha=0.2,              label='f1' if idx == 0 else None)

plt.grid() plt.legend() plt.xlabel('Threshold') plt.ylabel('Metric value')
for handle in plt.legend().legendHandles:    handle.set_alpha(1)
ax = plt.twinx()
for idx, df in enumerate(metrics):    ax.plot(df.threshold, df.n_flagged,            label='flagged' if idx == 0 else None,            color='red', alpha=0.2)
plt.ylabel('Flagged') ax.legend(loc=0) ax.legend().legendHandles[0].set_alpha(1)


圖片


左邊的刻度(從0到1)是我們的三個(gè)指標:精度、召回率和F1。F1分為精度與查全率的調和平均值,F1分的最佳值為1.0,最差值為0.0;F1對精度和召回率都是相同對待的,所以你可以看到它在兩者之間保持平衡。如果你正在處理一個(gè)精確度和召回率都很重要的用例,那么最大化F1是一種可以幫助你優(yōu)化分類(lèi)器閾值的方法。


這里還包括一條紅色曲線(xiàn)(右側的比例),顯示我們的模型標記為有害內容的案例數量。


在這個(gè)的內容審核示例中,可能有X個(gè)的工作人員來(lái)人工審核模型標記的有害帖子,但是他們人數是有限的,因此考慮標記帖子的總數可以幫助我們更好地選擇閾值:例如每天只能檢查5000個(gè)帖子,那么模型找到10,000帖并不會(huì )帶來(lái)任何的提高。如果我人工每天可以處理10000貼,但是模型只標記了100貼,那么顯然也是浪費的。


當設置較低的閾值時(shí),有較高的召回率(我們檢索了大部分實(shí)際上有害的帖子),但精度較低(包含了許多無(wú)害的帖子)。如果我們提高閾值,情況就會(huì )反轉:召回率下降(錯過(guò)了許多有害的帖子),但精確度很高(大多數標記的帖子都是有害的)。


所以在為我們的二元分類(lèi)器選擇閾值時(shí),我們必須在精度或召回率上妥協(xié),因為沒(méi)有一個(gè)分類(lèi)器是完美的。我們來(lái)討論一下如何推理選擇合適的閾值。


選擇最佳閾值


右邊的數據會(huì )產(chǎn)生噪聲(較大的閾值)。需要稍微清理一下,我們將重新創(chuàng )建這個(gè)圖,我們將繪制2.5%、50%和97.5%的百分位數,而不是繪制所有值。

 shape = (df.shape[0], len(metrics)) precision = np.zeros(shape) recall = np.zeros(shape) f1 = np.zeros(shape) n_flagged = np.zeros(shape) for i, df in enumerate(metrics):    precision[:, i] = df.precision.values    recall[:, i] = df.recall.values    f1[:, i] = df.f1.values    n_flagged[:, i] = df.n_flagged.values precision_ = np.quantile(precision, q=0.5, axis=1) recall_ = np.quantile(recall, q=0.5, axis=1) f1_ = np.quantile(f1, q=0.5, axis=1) n_flagged_ = np.quantile(n_flagged, q=0.5, axis=1) plt.plot(df.threshold, precision_, color='blue', label='precision') plt.plot(df.threshold, recall_, color='green', label='recall') plt.plot(df.threshold, f1_, color='orange', label='f1')
plt.fill_between(df.threshold, precision_interval[0],                  precision_interval[1], color='blue',                  alpha=0.2)
plt.fill_between(df.threshold, recall_interval[0],                  recall_interval[1], color='green',                  alpha=0.2)

plt.fill_between(df.threshold, f1_interval[0],                  f1_interval[1], color='orange',                  alpha=0.2) plt.xlabel('Threshold') plt.ylabel('Metric value') plt.legend()
ax = plt.twinx() ax.plot(df.threshold, n_flagged_, color='red', label='flagged') ax.fill_between(df.threshold, n_flagged_interval[0],                n_flagged_interval[1], color='red',                alpha=0.2)
ax.legend(loc=3)
plt.ylabel('Flagged') plt.grid()


圖片


我們可以根據自己的需求選擇閾值,例如檢索盡可能多的有害帖子(高召回率)是否更重要?還是要有更高的確定性,我們標記的必須是有害的(高精度)?


如果兩者都同等重要,那么在這些條件下優(yōu)化的常用方法就是最大化F-1分數:



















 idx = np.argmax(f1_) prec_lower, prec_upper = precision_interval[0][idx], precision_interval[1][idx] rec_lower, rec_upper = recall_interval[0][idx], recall_interval[1][idx] threshold = df.threshold[idx]
print(f'Max F1 score: {f1_[idx]:.2f}') print('Metrics when maximizing F1 score:') print(f' - Threshold: {threshold:.2f}') print(f' - Precision range: ({prec_lower:.2f}, {prec_upper:.2f})') print(f' - Recall range: ({rec_lower:.2f}, {rec_upper:.2f})')
#結果 Max F1 score: 0.71 Metrics when maximizing F1 score:  - Threshold: 0.26  - Precision range: (0.58, 0.61)  - Recall range: (0.86, 0.90)


在很多情況下很難決定這個(gè)折中,所以加入一些約束條件會(huì )有一些幫助。


假設我們有10個(gè)人審查有害的帖子,他們可以一起檢查5000個(gè)。那么讓我們看看指標,如果我們修改了閾值,讓它標記了大約5000個(gè)帖子:


















 idx = np.argmax(n_flagged_ <= 5000)
prec_lower, prec_upper = precision_interval[0][idx], precision_interval[1][idx] rec_lower, rec_upper = recall_interval[0][idx], recall_interval[1][idx] threshold = df.threshold[idx]
print('Metrics when limiting to a maximum of 5,000 flagged events:') print(f' - Threshold: {threshold:.2f}') print(f' - Precision range: ({prec_lower:.2f}, {prec_upper:.2f})') print(f' - Recall range: ({rec_lower:.2f}, {rec_upper:.2f})')
# 結果 Metrics when limiting to a maximum of 5,000 flagged events:  - Threshold: 0.82  - Precision range: (0.77, 0.81)  - Recall range: (0.25, 0.36)


如果需要進(jìn)行匯報,我們可以在在展示結果時(shí)展示一些替代方案:比如在當前約束條件下(5000個(gè)帖子)的模型性能,以及如果我們增加團隊(比如通過(guò)增加一倍的規模),我們可以做得更好。


總結


二元分類(lèi)器的最佳閾值是針對業(yè)務(wù)結果進(jìn)行優(yōu)化并考慮到流程限制的閾值。通過(guò)本文中描述的過(guò)程,你可以更好地為用例決定最佳閾值。


如果你對這篇文章有任何問(wèn)題,請隨時(shí)留言。


另外,Ploomber Cloud!提供一些免費的算力!如果你需要一些免費的服務(wù)可以試試它。



*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專(zhuān)區

關(guān)閉
国产精品自在自线亚洲|国产精品无圣光一区二区|国产日产欧洲无码视频|久久久一本精品99久久K精品66|欧美人与动牲交片免费播放
<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>