《威利在哪里?》(Where’s Wally)是由英國插畫家馬丁·漢德福特(Martin Handford)創作的一套兒童繪本。這個書的目標就是在一張人山人海的圖片中找出一個特定的人物——威利(Wally)。“Where’s Wally”的商標已在28個國家進行了注冊,為方便語言翻譯,每一個國家都會給威利起一個新名字,最成功的是北美版的“Where’s Waldo”,在這里,威利改名成了沃爾多(Waldo)。
現在,機器學習博主Tadej Magajna另辟蹊徑,利用深度學習解開“威利在哪里”的問題。與傳統的計算機視覺圖像處理方法不同的是,它只使用了少數幾個標記出威利位置的圖片樣本,就訓練成了一套“尋找威利”的系統。
訓練過的圖像評估模型和檢測腳本發布在作者的GitHub repo上。
本文介紹了用TensorFlow物體檢測API訓練神經網絡、并用相應的Python腳本尋找威利的過程。大致分為以下幾步:
將圖片打標簽后創建數據集,其中標簽注明了威利在圖片中的位置,用x,y表示;
用TensorFlow物體檢測API獲取并配置神經網絡模型;
在數據集上訓練模型;
用導出的圖像測試模型;
開始前,請確保你已經按照說明安裝了TensorFlow物體檢測API。
創建數據集
雖說深度學習中最重要的環節是處理神經網絡,但不幸的是,數據科學家們總要花費大量時間準備訓練數據。
最簡單的機器學習問題最終得到的通常是一個標量(如數字檢測器)或是一個分類字符串。TensorFlow物體檢測API在訓練數據是則將上述兩個結果結合了起來。它由一系列圖像組成,并包含目標對象的標簽和他們在圖像中的位置。由于在二維圖像中,兩個點足以在對象周圍繪制邊界框,所以圖像的定位只有兩個點。
為了創建訓練集,我們需要準備一組Where’s Wally的插畫,并標出威利的位置。在此之前已經有人做出了一套解出威利在哪里的訓練集。
最右邊的四列描述了威利所在的位置
創建數據集的最后一步就是將標簽(.csv)和圖片(.jpeg)打包,存入單一二分類文件中(.tfrecord)。詳細過程可參考這里,訓練和評估過程也可以在作者的GitHub上找到。
準備模型
TensorFlow物體檢測API提供了一組性能不同的模型,它們要么精度高,但速度慢,要么速度快,但精度低。這些模型都在公開數據集上經過了預訓練。
雖然模型可以從頭開始訓練,隨機初始化網絡權重,但這可能需要幾周的時間。相反,這里作者采用了一種稱為遷移學習(Transfer Learning)的方法。
這種方法是指,用一個經常訓練的模型解決一般性問題,然后再將它重新訓練,用于解決我們的問題。也就是說,與其從頭開始訓練新模型,不如從預先訓練過的模型中獲取知識,將其轉移到新模型的訓練中,這是一種非常節省時間的方法。
作者使用了在COCO數據集上訓練過的搭載Inception v2模型的RCNN。該模型包含一個.ckpycheckpoint文件,可以利用它開始訓練。
配置文件下載完成后,請確保將“PATHTOBE_CONFIGURED”字段替換成指向checkpoint文件、訓練和評估的.tfrecord文件和標簽映射文件的路徑。
最后需要配置的文件是labels.txt映射文件,其中包含我們所有不同對象的標簽。由于我們尋找的都是同一個類型的對象(威利),所以標簽文件如下:
item {
id: 1
name: 'waldo'
}
最終應該得到:
一個有著checkpoint文件的預訓練模型;
經過訓練并評估的.tfrecord數據集;
標簽映射文件;
指向上述文件的配置文件。
然后就可以開始訓練啦。
訓練
TensorFlow物體檢測API提供了一個十分容易上手的Python腳本,可以在本地訓練模型。它位于models/research/object_detection中,可以通過以下命令運行:
python train.py --logtostderr --pipeline_config_path= PATH_TO_PIPELINE_CONFIG --train_dir=PATH_TO_TRAIN_DIR
PATH_TO_PIPELINE_CONFIG是通往配置文件的路徑,PATH_TO_TRAIN_DIR是新創建的directory,用來儲存checkpoint和模型。
train.py的輸出看起來是這樣:
用最重要的信息查看是否有損失,這是各個樣本在訓練或驗證時出現錯誤的總和。當然,你肯定希望它降得越低越好,因為如果它在緩慢地下降,就意味著你的模型正在學習(要么就是過擬合了你的數據……)。
你還可以用Tensorboard顯示更詳細的訓練數據。
腳本將在一定時間后自動存儲checkpoint文件,萬一計算機半路崩潰,你還可以恢復這些文件。也就是說,當你想完成模型的訓練時,隨時都可以終止腳本。
但是什么時候停止學習呢?一般是當我們的評估集損失停止減少或達到非常低的時候(在這個例子中低于0.01)。
測試
現在,我們可以將模型用于實際測試啦。
首先,我們需要從儲存的checkpoint中輸出一個推理圖(interference graph),利用的腳本如下:
python export_inference_graph.py — pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECPOINT --output_directory OUTPUT_PATH
產生的推理圖就是用來Python腳本用來找到威利的工具。
作者寫了幾個簡單目標定位的腳本,其中find_wally.py和find_wally_pretty.py都可以在他的GitHub上找到,并且運行起來也很簡單:
python find_wally.py
或者
python find_wally_pretty.py
不過當你在自己的模型或圖像上運行腳本時,記得改變model-path和image-path的變量。
結語
模型的表現出乎意料地好。它不僅從數據集中成功地找到了威利,還能在隨機從網上找的圖片中找到威利。
但是如果威利在圖中特別大,模型就找不到了。我們總覺得,不應該是目標物體越大越好找嗎?這樣的結果表明,作者用于訓練的圖像并不多,模型可能對訓練數據過度擬合了。
-
神經網絡
+關注
關注
42文章
4779瀏覽量
101172 -
python
+關注
關注
56文章
4807瀏覽量
85039 -
tensorflow
+關注
關注
13文章
329瀏覽量
60631
原文標題:如何用神經網絡“尋找威利”
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
高階API構建模型和數據集使用
嵌入式中的人工神經網絡的相關資料分享
輕量化神經網絡的相關資料下載
圖像預處理和改進神經網絡推理的簡要介紹
如何使用TensorFlow將神經網絡模型部署到移動或嵌入式設備上
傳統檢測、深度神經網絡框架、檢測技術的物體檢測算法全概述
![傳統<b class='flag-5'>檢測</b>、深度<b class='flag-5'>神經網絡</b>框架、<b class='flag-5'>檢測</b>技術的<b class='flag-5'>物體檢測</b>算法全概述](https://file.elecfans.com/web1/M00/CB/E2/pIYBAF-RL12APp_iAAAQtTRMCs8010.jpg)
評論