在 Android 上使用 ONNX Runtime 執行 YOLOv8 模型

Photo by Sam Hojati on Unsplash
Photo by Sam Hojati on Unsplash
Open Neural Network Exchange (ONNX) 是由數個大廠合作定義的模型格式。而 ONNX Runtime 是一個可以執行 ONNX 模型的 library,它是微軟開發的。它支援多個平台,包含 Android。

Open Neural Network Exchange (ONNX) 是由數個大廠合作定義的模型格式。而 ONNX Runtime 是一個可以執行 ONNX 模型的 library,它是微軟開發的。它支援多個平台,包含 Android。本文章將藉由建立一個企鵝偵測的 app 來介紹如何在 Android 上使用 ONNX Runtime。

完整程式碼可以在 下載。

將 YOLOv8 模型輸出成 ONNX 格式

我們將使用以下文章中建立的 YOLOv8 模型。此模型可以用來偵測圖片或是影片中的企鵝寶寶。

YOLOv8 模型是 PyTorch 格式,所以我們要利用以下的指令,將它轉換成 ONNX 格式。執行後,它會輸出一個 best.onnx 檔案。我們將它重新命名為 baby_penguin.onnx。

YOLOv8Example % .venv/bin/yolo mode=export model=./runs/train/weights/best.pt format=onnx simplify=True

引入 ONNX Runtime 依賴至專案

要在 Android 上執行 ONNX Runtime,我們必須先將它引入至專案,如下。

dependencies {
    implementation("com.microsoft.onnxruntime:onnxruntime-android:0.9.0")
    implementation("com.microsoft.onnxruntime:onnxruntime-extensions-android:0.9.0")
}

載入模型

首先,先將 baby_penguin.onnx 放置專案中的 assets 資料夾,並將模型載入成一個 InputStream。呼叫 OrtEnvironment.getEnvironment() 取得一個 OrtEnvironment,再呼叫 OrtEnvironment.createSession() 並傳入模型的 InputStream 來建立一個 OrtSession

class AppViewModel : ViewModel() {
    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) {
        viewModelScope.launch(Dispatchers.IO) {
            ortEnv.use {
                assets.open( "baby_penguin.onnx").use {
                    BufferedInputStream(it).use { bis ->
                        session = ortEnv.createSession(bis.readBytes())
                    }
                }
            }

            // Load the image
            assets.open("image.jpg").use {
                _uiState.value = AppUiState(Bitmap.createBitmap(BitmapFactory.decodeStream(it)))
            }
        }
    }
}

執行模型

在執行模型之前,我們必須要先將圖片轉換成模型接受的輸入格式。企鵝寶寶偵測模型的輸入圖片大小為 640 x 640,因此我們將圖片縮放成 640 x 640。

Android 的 Bitmap 預設格式為 ARGB_8888,也是說一個 pixel 由 4 個 unsigned bytes 組成,且每個值為 0 ~ 255,其順序為 alpha、紅、綠、和藍。然而,YOLOv8 接受的格式為,一個 pixel 由 3 個 floats 組成,其順序為 藍、綠、和紅,且每個值為 0 ~ 1。因此,我們在 bitmapToFlatArray() 中將 RGB 轉換成 BGR。

然後,我們呼叫 OnnxTensor.createTensor()Bitmap 轉換成一個 1 x 3 x 640 x 640 的 ONNX tensor。

呼叫 OrtSession.run() 來執行模型。

class AppViewModel : ViewModel() {
    companion object {
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640

        private const val BATCH_SIZE = 1
        private const val PIXEL_SIZE = 3
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) {
        ...
    }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            val input = createTensor(scaledBitmap)
            val model = session ?: throw Exception("Model is not set")
            val inputName = model.inputNames.iterator().next()
            val output = model.run(Collections.singletonMap(inputName, input))
            ...
        }
    }

    private fun createTensor(bitmap: Bitmap): OnnxTensor {
        ortEnv.use {
            return OnnxTensor.createTensor(
                ortEnv,
                FloatBuffer.wrap(bitmapToFlatArray(bitmap)),
                longArrayOf(
                    BATCH_SIZE.toLong(),
                    PIXEL_SIZE.toLong(),
                    bitmap.width.toLong(),
                    bitmap.height.toLong(),
                )
            )
        }
    }

    private fun bitmapToFlatArray(bitmap: Bitmap): FloatArray {
        val input = FloatArray(BATCH_SIZE * PIXEL_SIZE * bitmap.width * bitmap.height)

        val pixels = bitmap.width * bitmap.height
        val bitmapArray = IntArray(pixels)
        bitmap.getPixels(bitmapArray, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        for (i in 0..<bitmap.width) {
            for (j in 0..<bitmap.height) {
                val idx = bitmap.height * i + j
                val pixelValue = bitmapArray[idx]
                input[idx] = (pixelValue shr 16 and 0xFF) / 255f
                input[idx + pixels] = (pixelValue shr 8 and 0xFF) / 255f
                input[idx + pixels * 2] = (pixelValue and 0xFF) / 255f
            }
        }

        return input
    }
}

模型輸出的結果是一個 1 x 5 x 8400 的 ONNX tensor,如下圖。第二維度代表 bounding box,其大小是 5,分別是一個 bounding box 的 x、y、寬、高、和 confidence score。第三維度表示 bounding box 的個數,所以該輸出的結果包含 8400 個 bounding boxes。

PyTorch Output Tensor, 1 x 5 x 8400.
PyTorch Output Tensor, 1 x 5 x 8400.

我們可以在 OrtSession.Result 中取得其內部的 array,如下。這個 array 的格式會如上圖所示。

val outputArray = (result.get(0).value) as Array<*>
val outputs = outputArray[0] as Array<FloatArray>

NMS

當你用 YOLOv8 的 framework 來執行 YOLOv8 模型後,你會得到已用 NMS 過濾過的 bounding boxes。然而,當我們用 ONNX Runtime 直接執行 YOLOv8 模型後,我們得到的是未用 NMS 過濾過的 bounding boxes,所以 bounding boxes 的數量才會達到 8400 個。因此,我們必須要自己實作 NMS 來過濾 bounding boxes。

如果你不熟悉 NMS 演算法的話,請參考以下文章。

以下是 NMS 演算法的實作。它不但會依據 NMS 演算法來過濾 bounding boxes,它還會將 bounding boxes 從 xywh 格式轉換成 xyxy 格式,也就是 bounding box 的左上點和右下點。

data class BoundingBox(val boundingBox: RectF, var score: Float, var clazz: Int)

object PostProcessor {
    private const val CONFIDENCE_THRESHOLD = 0.25f
    private const val IOU_THRESHOLD = 0.45f
    private const val MAX_NMS = 30000

    fun nms(
        result: OrtSession.Result,
        confidenceThreshold: Float = CONFIDENCE_THRESHOLD,
        iouThreshold: Float = IOU_THRESHOLD,
        maxNms: Int = MAX_NMS,
        scaleX: Float = 1f,
        scaleY: Float = 1f,
    ): List<BoundingBox> {
        val outputArray = (result.get(0).value) as Array<*>
        val outputs = outputArray[0] as Array<FloatArray>
        val results = mutableListOf<BoundingBox>()

        for (col in 0 until outputs[0].size) {
            var score = 0f
            var cls = 0
            for (row in 4 until outputs.size) {
                if (outputs			
> score) { score = outputs
cls = row } } cls -= 4 if (score > confidenceThreshold) { val x = outputs[0]
val y = outputs[1]
val w = outputs[2]
val h = outputs[3]
val left = x - w / 2 val top = y - h / 2 val right = x + w / 2 val bottom = y + h / 2 val rect = RectF(scaleX * left, top * scaleY, scaleX * right, scaleY * bottom) val boxes = BoundingBox(rect, score, cls) results.add(boxes) } } return nms(results, iouThreshold, maxNms) } private fun nms(boxes: List<BoundingBox>, iouThreshold: Float, limit: Int): List<BoundingBox> { val selected = mutableListOf<BoundingBox>() val sortedBoxes = boxes.sortedWith { o1, o2 -> o1.score.compareTo(o2.score) } val active = BooleanArray(sortedBoxes.size) Arrays.fill(active, true) var numActive = active.size var done = false var i = 0 while (i < sortedBoxes.size && !done) { if (active[i]) { val boxA = sortedBoxes[i] selected.add(boxA) if (selected.size >= limit) break for (j in i + 1 until sortedBoxes.size) { if (active[j]) { val boxB = sortedBoxes[j] if (calculateIou(boxA.boundingBox, boxB.boundingBox) > iouThreshold) { active[j] = false numActive -= 1 if (numActive <= 0) { done = true break } } } } } i++ } return selected } private fun calculateIou(a: RectF, b: RectF): Float { val areaA = (a.right - a.left) * (a.bottom - a.top) if (areaA <= 0.0) return 0.0f val areaB = (b.right - b.left) * (b.bottom - b.top) if (areaB <= 0.0) return 0.0f val intersectionMinX = max(a.left, b.left) val intersectionMinY = max(a.top, b.top) val intersectionMaxX = min(a.right, b.right) val intersectionMaxY = min(a.bottom, b.bottom) val intersectionArea = max( intersectionMaxY - intersectionMinY, 0.0f ) * max(intersectionMaxX - intersectionMinX, 0.0f) return intersectionArea / (areaA + areaB - intersectionArea) } }

有了 NMS 演算法後,我們在呼叫 OrtSession.run() 之後,就可以呼叫 nms() 來過濾 bounding boxes,如下。

class AppViewModel : ViewModel() {
    companion object {
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640

        private const val BATCH_SIZE = 1
        private const val PIXEL_SIZE = 3
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) {
        ...
    }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            val input = createTensor(scaledBitmap)
            val model = session ?: throw Exception("Model is not set")
            val inputName = model.inputNames.iterator().next()
            val output = model.run(Collections.singletonMap(inputName, input))
            val boxes = PostProcessor.nms(output)
            _uiState.value = _uiState.value.copy(boxes = boxes)
        }
    }
    ...
}

結語

ONNX Runtime 使用起來相當地簡單。本文章中比較困難的地方是,在執行 YOLOv8 模型後,我們必須要自己實作 NMS。除了 ONNX Runtime 之外,我們也可以在 Android 使用 PyTorch 來執行模型,請參考以下文章。

參考

2 comments
  1. NMS的程式码中有贴入div的标签。opencv好像也是可以进行onnx模型的推理

    1. 感謝指正!剛剛看了一下,原來是 EnlighterJS plugin 的問題,所以我也無從修正,以下是正確的 mns() 函式的程式碼。
      fun nms(
      result: OrtSession.Result,
      confidenceThreshold: Float = CONFIDENCE_THRESHOLD,
      iouThreshold: Float = IOU_THRESHOLD,
      maxNms: Int = MAX_NMS,
      scaleX: Float = 1f,
      scaleY: Float = 1f,
      ): List {
      val outputArray = (result.get(0).value) as Array<*>
      val outputs = outputArray[0] as Array
      val results = mutableListOf()

      for (col in 0 until outputs[0].size) {
      var score = 0f
      var cls = 0
      for (row in 4 until outputs.size) {
      if (outputs[row][col] > score) {
      score = outputs[row][col]
      cls = row
      }
      }
      cls -= 4

      if (score > confidenceThreshold) {
      val x = outputs[0][col]
      val y = outputs[1][col]
      val w = outputs[2][col]
      val h = outputs[3][col]

      val left = x – w / 2
      val top = y – h / 2
      val right = x + w / 2
      val bottom = y + h / 2

      val rect = RectF(scaleX * left, top * scaleY, scaleX * right, scaleY * bottom)
      val boxes = BoundingBox(rect, score, cls)
      results.add(boxes)
      }
      }

      return nms(results, iouThreshold, maxNms)
      }

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *

You May Also Like
Photo by Hans-Jurgen Mager on Unsplash
Read More

Kotlin Coroutine 教學

Kotlin 的 coroutine 是用來取代 thread。它不會阻塞 thread,而且還可以被取消。Coroutine core 會幫你管理 thread 的數量,讓你不需要自行管理,這也可以避免不小心建立過多的 thread。
Read More