深度學(xué)習實(shí)現場(chǎng)景字符識別模型|代碼干貨
文字是人從日常交流中語(yǔ)音中演化出來(lái),用來(lái)記錄信息的重要工具。文字對于人類(lèi)意義非凡,以中國為例,中國地大物博,各個(gè)地方的口音都不統一,但是人們使用同一套書(shū)寫(xiě)體系,使得即使遠隔千里,我們依然能夠通過(guò)文字進(jìn)行無(wú)障礙的溝通。文字也能夠跨越時(shí)空,給予了我們了解古人的通道。隨著(zhù)計算機的誕生,文字也進(jìn)行了數字化的進(jìn)程,但是不同于人類(lèi),讓計算機能夠正確地進(jìn)行字符識別是一個(gè)復雜又艱巨但意義重大的工作。從計算機誕生開(kāi)始,無(wú)數的研究者在這方面做了很多工作與嘗試,但面臨的困難艱巨。
其中場(chǎng)景文字識別中主要面臨的困難是:
(1)場(chǎng)景復雜變化很大;
(2)字體形態(tài)顏色多變;
(3)光照條件變化大;
(4)文字排列方式不確定;
(5)文本行與文本行之間的距離,大小格式,字體變化大。
而深度學(xué)習的引入,使得在我們在復雜場(chǎng)景下進(jìn)行字符識別更為便利。
本項目通過(guò)使用pytorch搭建resnet遷移學(xué)習模型實(shí)現對復雜場(chǎng)景下字符的識別。其模型訓練過(guò)程如下圖可見(jiàn):
# 1.基本介紹#
文字是人從日常交流中語(yǔ)音中演化出來(lái),用來(lái)記錄信息的重要工具。文字對于人類(lèi)意義非凡,以中國為例,中國地大物博,各個(gè)地方的口音都不統一,但是人們使用同一套書(shū)寫(xiě)體系,使得即使遠隔千里,我們依然能夠通過(guò)文字進(jìn)行無(wú)障礙的溝通。文字也能夠跨越時(shí)空,給予了我們了解古人的通道。隨著(zhù)計算機的誕生,文字也進(jìn)行了數字化的進(jìn)程,但是不同于人類(lèi),讓計算機能夠正確地進(jìn)行字符識別是一個(gè)復雜又艱巨但意義重大的工作。從計算機誕生開(kāi)始,無(wú)數的研究者在這方面做了很多工作與嘗試,但面臨的困難艱巨。
1.1 環(huán)境要求
本次環(huán)境使用的是python3.6.5+windows平臺。
主要用的庫有:Opencv-python模塊、Pillow模塊、PyTorch模塊。
Opencv-python模塊:
opencv-python是一個(gè)Python綁定庫,旨在解決計算機視覺(jué)問(wèn)題。其使用Numpy,這是一個(gè)高度優(yōu)化的數據庫操作庫,具有MATLAB風(fēng)格的語(yǔ)法。所有Opencv數組結構都轉換為Numpy數組。這也使得與使用Numpy的其他庫(如Scipy和Matplotlib)集成更容易。
Pillow模塊:
Pillow是Python里的圖像處理庫,它提供了了廣泛的文件格式支持和強大的圖像處理能力,主要包括圖像儲存、圖像顯示、格式轉換以及基本的圖像處理操作等。
PyTorch模塊
PyTorch是一個(gè)基于Torch的Python開(kāi)源機器學(xué)習庫,用于自然語(yǔ)言處理等應用程序。它主要由Facebookd的人工智能小組開(kāi)發(fā),不僅能夠實(shí)現強大的GPU加速,同時(shí)還支持動(dòng)態(tài)神經(jīng)網(wǎng)絡(luò ),這一點(diǎn)是現在很多其他的主流框架都不支持的。PyTorch還提供了兩個(gè)高級功能:1.具有強大的GPU加速的張量計算2.包含自動(dòng)求導系統的深度神經(jīng)網(wǎng)絡(luò ) 除了Facebook之外,Twitter、GMU和Salesforce等機構都采用了PyTorch。
1.2 遷移模型
遷移學(xué)習是通過(guò)從已學(xué)習的相關(guān)任務(wù)中轉移知識來(lái)改進(jìn)學(xué)習的新任務(wù),雖然大多數機器學(xué)習算法都是為了解決單個(gè)任務(wù)而設計的,但是促進(jìn)遷移學(xué)習的算法的開(kāi)發(fā)是機器學(xué)習社區持續關(guān)注的話(huà)題。
由下圖可以看出遷移學(xué)習和傳統機器學(xué)習的區別,在傳統機器學(xué)習的學(xué)習過(guò)程中,我們試圖單獨學(xué)習每一個(gè)學(xué)習任務(wù),即生成多個(gè)學(xué)習系統;而在遷移學(xué)習中,我們試圖將在前幾個(gè)任務(wù)上學(xué)到的知識轉移到目前的學(xué)習任務(wù)上,從而將其結合起來(lái)。
# 2.算法模型#
在這里我們使用的是resnet模型對圖像進(jìn)行特征提取。其中圖像特征提取通常使用卷積神經(jīng)網(wǎng)絡(luò )進(jìn)行特征學(xué)習,由于字符識別相較于物體分類(lèi)的不同,通常不會(huì )完全照搬分類(lèi)網(wǎng)絡(luò )來(lái)直接進(jìn)行圖形特征提取,會(huì )在分類(lèi)網(wǎng)絡(luò )的基礎上為了適應目標任務(wù)的改進(jìn)。
由于卷積神經(jīng)網(wǎng)絡(luò )會(huì )受到感受野的限制,因此提出了需要使用序列特征提取模型對特征進(jìn)行建模,學(xué)習卷積神經(jīng)網(wǎng)絡(luò )提取到的圖像特征之間的上下文關(guān)系。
2.1 數據集準備
在這里我們將訓練的數據集分成了訓練集、測試集和驗證集三部分。其中準備的數據集如下:
2.2 數據處理
為了保證每次運行模型效果基本相同,這里設置隨機種子,同時(shí)torch.backends.cudnn.deterministic將這個(gè)flag置為T(mén)rue。然后進(jìn)行圖像變換transforms,shffule=True在表示不同批次的數據遍歷時(shí),打亂順序。num_workers=0表示使用0個(gè)子進(jìn)程來(lái)加載數據。代碼如下:
SVHNDataset(train_path, train_label, transforms.Compose([ # 圖像尺寸變換(resize) ——transforms.Resize transforms.Resize((64, 128)), # 隨機裁剪:transforms.RandomCrop。size(sequence 或int) transforms.RandomCrop((60, 120)), # 修改亮度、對比度和飽和度:transforms.ColorJitter。亮度。對比度。飽和度。。。 transforms.ColorJitter(0.3, 0.3, 0.2), # 隨機旋轉:transforms.RandomRotation。degrees(sequence 或float或int) -要選擇的度數范圍 transforms.RandomRotation(5), # 將PIL Image或者 ndarray 轉換為tensor,并且歸一化至[0-1] transforms.ToTensor(), # 標準化:transforms.Normalize。用平均值和標準偏差歸一化張量圖像。mean每個(gè)通道的均值序列。std每個(gè)通道的標準偏差序列。 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])), batch_size=40, shuffle=True, num_workers=0 )
2.3 resnet模型搭建
這里采用的resnet18模型是由17個(gè)卷積層(conv)+1個(gè)全連接層(fc)構成。其中使用resnet模型的主要優(yōu)勢在于,當逐漸增加神經(jīng)網(wǎng)絡(luò )的深度時(shí),網(wǎng)絡(luò )難以學(xué)習恒等函數的參數,導致最后的訓練效果往往達不到預期,也會(huì )影響網(wǎng)絡(luò )性能。殘差網(wǎng)絡(luò )學(xué)習恒等函數比較容易,可將添加的網(wǎng)絡(luò )層看成一個(gè)個(gè)殘差塊。例如,一個(gè)20層的普通網(wǎng)絡(luò ),每?jì)蓪又g通過(guò)跳躍連接構成一個(gè)殘差塊,那么這個(gè)普通網(wǎng)絡(luò )就成為一個(gè)由10個(gè)殘差塊構成的殘差網(wǎng)絡(luò )。網(wǎng)絡(luò )性能不僅沒(méi)有下降,而且甚至有所提高。普通網(wǎng)絡(luò )轉化為殘差網(wǎng)絡(luò )也比較容易,只需要加入殘差塊即可。殘差網(wǎng)絡(luò )大大提高了網(wǎng)絡(luò )層數,通過(guò)殘差映射的方式進(jìn)行擬合,簡(jiǎn)單易操作,同時(shí)提高了準確率。
設置resnet18網(wǎng)絡(luò )模型,進(jìn)行遷移學(xué)習,保留resnet18網(wǎng)絡(luò )的卷積網(wǎng)絡(luò )部分,并保留預訓練參數。然后設計自適應平均池化函數,即不管之前的特征圖尺寸為多少,只要設置為(1,1),那么最終特征圖大小都為(1,1),然后把resnet18模型除了最后一個(gè)全連接層之外的各個(gè)網(wǎng)絡(luò )層提取出來(lái),并設置5個(gè)全連接層,分別對應5個(gè)可能的街道字符的識別。
def __init__(self): super(SVHN_Model1, self).__init__() model_conv = models.resnet18(pretrained=True) model_conv.avgpool = nn.AdaptiveAvgPool2d(1) model_conv = nn.Sequential(*list(model_conv.children())[:-1]) self.cnn = model_conv self.fc1 = nn.Linear(512, 11) self.fc2 = nn.Linear(512, 11) self.fc3 = nn.Linear(512, 11) self.fc4 = nn.Linear(512, 11) self.fc5 = nn.Linear(512, 11) def forward(self, img): feat = self.cnn(img) feat = feat.view(feat.shape[0], -1) c1 = self.fc1(feat) c2 = self.fc2(feat) c3 = self.fc3(feat) c4 = self.fc4(feat) c5 = self.fc5(feat) # c6 = self.fc6(feat) return c1, c2, c3, c4, c5 # , c6
完整代碼鏈接:
https://pan.baidu.com/s/1UpIq9XSlWxSotE0fama3Vw 提取碼:gcwu
*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。