Open Neural Network Exchange (ONNX) 是由數個大廠合作定義的模型格式。而 ONNX Runtime 是一個可以執行 ONNX 模型的 library,它是微軟開發的。它支援多個平台,包含 Android。本文章將藉由建立一個企鵝偵測的 app 來介紹如何在 Android 上使用 ONNX Runtime。
Table of Contents
將 YOLOv8 模型輸出成 ONNX 格式
我們將使用以下文章中建立的 YOLOv8 模型。此模型可以用來偵測圖片或是影片中的企鵝寶寶。
YOLOv8 模型是 PyTorch 格式,所以我們要利用以下的指令,將它轉換成 ONNX 格式。執行後,它會輸出一個 best.onnx 檔案。我們將它重新命名為 baby_penguin.onnx。
YOLOv8Example % .venv/bin/yolo mode=export model=./runs/train/weights/best.pt format=onnx simplify=True
引入 ONNX Runtime 依賴至專案
要在 Android 上執行 ONNX Runtime,我們必須先將它引入至專案,如下。
dependencies { implementation("com.microsoft.onnxruntime:onnxruntime-android:0.9.0") implementation("com.microsoft.onnxruntime:onnxruntime-extensions-android:0.9.0") }
載入模型
首先,先將 baby_penguin.onnx 放置專案中的 assets
資料夾,並將模型載入成一個 InputStream
。呼叫 OrtEnvironment.getEnvironment()
取得一個 OrtEnvironment
,再呼叫 OrtEnvironment.createSession()
並傳入模型的 InputStream
來建立一個 OrtSession
。
class AppViewModel : ViewModel() { private val _uiState = MutableStateFlow(AppUiState()) val uiState: StateFlow<AppUiState> = _uiState.asStateFlow() private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment() private var session: OrtSession? = null fun load(assets: AssetManager) { viewModelScope.launch(Dispatchers.IO) { ortEnv.use { assets.open( "baby_penguin.onnx").use { BufferedInputStream(it).use { bis -> session = ortEnv.createSession(bis.readBytes()) } } } // Load the image assets.open("image.jpg").use { _uiState.value = AppUiState(Bitmap.createBitmap(BitmapFactory.decodeStream(it))) } } } }
執行模型
在執行模型之前,我們必須要先將圖片轉換成模型接受的輸入格式。企鵝寶寶偵測模型的輸入圖片大小為 640 x 640,因此我們將圖片縮放成 640 x 640。
Android 的 Bitmap
預設格式為 ARGB_8888
,也是說一個 pixel 由 4 個 unsigned bytes 組成,且每個值為 0 ~ 255,其順序為 alpha、紅、綠、和藍。然而,YOLOv8 接受的格式為,一個 pixel 由 3 個 floats 組成,其順序為 藍、綠、和紅,且每個值為 0 ~ 1。因此,我們在 bitmapToFlatArray()
中將 RGB 轉換成 BGR。
然後,我們呼叫 OnnxTensor.createTensor()
將 Bitmap
轉換成一個 1 x 3 x 640 x 640 的 ONNX tensor。
呼叫 OrtSession.run()
來執行模型。
class AppViewModel : ViewModel() { companion object { private const val IMAGE_WIDTH = 640 private const val IMAGE_HEIGHT = 640 private const val BATCH_SIZE = 1 private const val PIXEL_SIZE = 3 } private val _uiState = MutableStateFlow(AppUiState()) val uiState: StateFlow<AppUiState> = _uiState.asStateFlow() private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment() private var session: OrtSession? = null fun load(assets: AssetManager) { ... } fun infer(bitmap: Bitmap) { viewModelScope.launch(Dispatchers.Default) { val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false) val input = createTensor(scaledBitmap) val model = session ?: throw Exception("Model is not set") val inputName = model.inputNames.iterator().next() val output = model.run(Collections.singletonMap(inputName, input)) ... } } private fun createTensor(bitmap: Bitmap): OnnxTensor { ortEnv.use { return OnnxTensor.createTensor( ortEnv, FloatBuffer.wrap(bitmapToFlatArray(bitmap)), longArrayOf( BATCH_SIZE.toLong(), PIXEL_SIZE.toLong(), bitmap.width.toLong(), bitmap.height.toLong(), ) ) } } private fun bitmapToFlatArray(bitmap: Bitmap): FloatArray { val input = FloatArray(BATCH_SIZE * PIXEL_SIZE * bitmap.width * bitmap.height) val pixels = bitmap.width * bitmap.height val bitmapArray = IntArray(pixels) bitmap.getPixels(bitmapArray, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height) for (i in 0..<bitmap.width) { for (j in 0..<bitmap.height) { val idx = bitmap.height * i + j val pixelValue = bitmapArray[idx] input[idx] = (pixelValue shr 16 and 0xFF) / 255f input[idx + pixels] = (pixelValue shr 8 and 0xFF) / 255f input[idx + pixels * 2] = (pixelValue and 0xFF) / 255f } } return input } }
模型輸出的結果是一個 1 x 5 x 8400 的 ONNX tensor,如下圖。第二維度代表 bounding box,其大小是 5,分別是一個 bounding box 的 x、y、寬、高、和 confidence score。第三維度表示 bounding box 的個數,所以該輸出的結果包含 8400 個 bounding boxes。
我們可以在 OrtSession.Result
中取得其內部的 array,如下。這個 array 的格式會如上圖所示。
val outputArray = (result.get(0).value) as Array<*> val outputs = outputArray[0] as Array<FloatArray>
NMS
當你用 YOLOv8 的 framework 來執行 YOLOv8 模型後,你會得到已用 NMS 過濾過的 bounding boxes。然而,當我們用 ONNX Runtime 直接執行 YOLOv8 模型後,我們得到的是未用 NMS 過濾過的 bounding boxes,所以 bounding boxes 的數量才會達到 8400 個。因此,我們必須要自己實作 NMS 來過濾 bounding boxes。
如果你不熟悉 NMS 演算法的話,請參考以下文章。
以下是 NMS 演算法的實作。它不但會依據 NMS 演算法來過濾 bounding boxes,它還會將 bounding boxes 從 xywh
格式轉換成 xyxy
格式,也就是 bounding box 的左上點和右下點。
data class BoundingBox(val boundingBox: RectF, var score: Float, var clazz: Int) object PostProcessor { private const val CONFIDENCE_THRESHOLD = 0.25f private const val IOU_THRESHOLD = 0.45f private const val MAX_NMS = 30000 fun nms( result: OrtSession.Result, confidenceThreshold: Float = CONFIDENCE_THRESHOLD, iouThreshold: Float = IOU_THRESHOLD, maxNms: Int = MAX_NMS, scaleX: Float = 1f, scaleY: Float = 1f, ): List<BoundingBox> { val outputArray = (result.get(0).value) as Array<*> val outputs = outputArray[0] as Array<FloatArray> val results = mutableListOf<BoundingBox>() for (col in 0 until outputs[0].size) { var score = 0f var cls = 0 for (row in 4 until outputs.size) { if (outputs> score) { score = outputscls = row } } cls -= 4 if (score > confidenceThreshold) { val x = outputs[0]val y = outputs[1]val w = outputs[2]val h = outputs[3]val left = x - w / 2 val top = y - h / 2 val right = x + w / 2 val bottom = y + h / 2 val rect = RectF(scaleX * left, top * scaleY, scaleX * right, scaleY * bottom) val boxes = BoundingBox(rect, score, cls) results.add(boxes) } } return nms(results, iouThreshold, maxNms) } private fun nms(boxes: List<BoundingBox>, iouThreshold: Float, limit: Int): List<BoundingBox> { val selected = mutableListOf<BoundingBox>() val sortedBoxes = boxes.sortedWith { o1, o2 -> o1.score.compareTo(o2.score) } val active = BooleanArray(sortedBoxes.size) Arrays.fill(active, true) var numActive = active.size var done = false var i = 0 while (i < sortedBoxes.size && !done) { if (active[i]) { val boxA = sortedBoxes[i] selected.add(boxA) if (selected.size >= limit) break for (j in i + 1 until sortedBoxes.size) { if (active[j]) { val boxB = sortedBoxes[j] if (calculateIou(boxA.boundingBox, boxB.boundingBox) > iouThreshold) { active[j] = false numActive -= 1 if (numActive <= 0) { done = true break } } } } } i++ } return selected } private fun calculateIou(a: RectF, b: RectF): Float { val areaA = (a.right - a.left) * (a.bottom - a.top) if (areaA <= 0.0) return 0.0f val areaB = (b.right - b.left) * (b.bottom - b.top) if (areaB <= 0.0) return 0.0f val intersectionMinX = max(a.left, b.left) val intersectionMinY = max(a.top, b.top) val intersectionMaxX = min(a.right, b.right) val intersectionMaxY = min(a.bottom, b.bottom) val intersectionArea = max( intersectionMaxY - intersectionMinY, 0.0f ) * max(intersectionMaxX - intersectionMinX, 0.0f) return intersectionArea / (areaA + areaB - intersectionArea) } }
有了 NMS 演算法後,我們在呼叫 OrtSession.run()
之後,就可以呼叫 nms()
來過濾 bounding boxes,如下。
class AppViewModel : ViewModel() { companion object { private const val IMAGE_WIDTH = 640 private const val IMAGE_HEIGHT = 640 private const val BATCH_SIZE = 1 private const val PIXEL_SIZE = 3 } private val _uiState = MutableStateFlow(AppUiState()) val uiState: StateFlow<AppUiState> = _uiState.asStateFlow() private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment() private var session: OrtSession? = null fun load(assets: AssetManager) { ... } fun infer(bitmap: Bitmap) { viewModelScope.launch(Dispatchers.Default) { val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false) val input = createTensor(scaledBitmap) val model = session ?: throw Exception("Model is not set") val inputName = model.inputNames.iterator().next() val output = model.run(Collections.singletonMap(inputName, input)) val boxes = PostProcessor.nms(output) _uiState.value = _uiState.value.copy(boxes = boxes) } } ... }
結語
ONNX Runtime 使用起來相當地簡單。本文章中比較困難的地方是,在執行 YOLOv8 模型後,我們必須要自己實作 NMS。除了 ONNX Runtime 之外,我們也可以在 Android 使用 PyTorch 來執行模型,請參考以下文章。
參考
- Object detection and pose estimation on mobile with YOLOv8, ONNX RUNTIME.
- ONNX RUNTIME Inference Example – Object Detection, Github.
2 comments
NMS的程式码中有贴入div的标签。opencv好像也是可以进行onnx模型的推理
感謝指正!剛剛看了一下,原來是 EnlighterJS plugin 的問題,所以我也無從修正,以下是正確的 mns() 函式的程式碼。 {()
fun nms(
result: OrtSession.Result,
confidenceThreshold: Float = CONFIDENCE_THRESHOLD,
iouThreshold: Float = IOU_THRESHOLD,
maxNms: Int = MAX_NMS,
scaleX: Float = 1f,
scaleY: Float = 1f,
): List
val outputArray = (result.get(0).value) as Array<*>
val outputs = outputArray[0] as Array
val results = mutableListOf
for (col in 0 until outputs[0].size) {
var score = 0f
var cls = 0
for (row in 4 until outputs.size) {
if (outputs[row][col] > score) {
score = outputs[row][col]
cls = row
}
}
cls -= 4
if (score > confidenceThreshold) {
val x = outputs[0][col]
val y = outputs[1][col]
val w = outputs[2][col]
val h = outputs[3][col]
val left = x – w / 2
val top = y – h / 2
val right = x + w / 2
val bottom = y + h / 2
val rect = RectF(scaleX * left, top * scaleY, scaleX * right, scaleY * bottom)
val boxes = BoundingBox(rect, score, cls)
results.add(boxes)
}
}
return nms(results, iouThreshold, maxNms)
}