Executing YOLOv8 Models on Android Using ONNX Runtime

Photo by Sam Hojati on Unsplash

Open Neural Network Exchange (ONNX) is an open model format defined by several major companies. ONNX Runtime, developed by Microsoft, is a library that executes ONNX models on multiple platforms, including Android. This article shows how to use ONNX Runtime on Android by building an app that detects baby penguins.

The complete code for this chapter can be found in .

Export YOLOv8 Models to ONNX Format

We will use the YOLOv8 model created in the following article. This model can be used to detect baby penguins in pictures or videos.

The YOLOv8 model is in PyTorch format, so we first convert it to ONNX format with the following command. It writes a best.onnx file, which we rename to baby_penguin.onnx.

YOLOv8Example % .venv/bin/yolo mode=export model=./runs/train/weights/best.pt format=onnx simplify=True

Adding ONNX Runtime Dependencies into the Project

To use ONNX Runtime on Android, we first add it to the app module's Gradle dependencies, as follows.

dependencies {
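    // onnxruntime-android and onnxruntime-extensions-android are released separately, with their own version numbers
    implementation("com.microsoft.onnxruntime:onnxruntime-android:1.16.3")
    implementation("com.microsoft.onnxruntime:onnxruntime-extensions-android:0.9.0")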
}

Loading Models

First, place baby_penguin.onnx in the project's assets folder. Call OrtEnvironment.getEnvironment() to get the OrtEnvironment, read the model from the assets into a byte array, and pass those bytes to OrtEnvironment.createSession() to create an OrtSession.

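import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch

class AppViewModel : ViewModel() {
    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    // The environment is a process-wide singleton shared by all sessions, so it is not closed here
    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) {
        viewModelScope.launch(Dispatchers.IO) {
            // Read the whole model from the assets and create a session from its bytes
            val modelBytes = assets.open("baby_penguin.onnx").use { it.readBytes() }
            session = ortEnv.createSession(modelBytes)

            // Load the image
            assets.open("image.jpg").use {
                _uiState.value = AppUiState(Bitmap.createBitmap(BitmapFactory.decodeStream(it)))
            }
        }
    }

    override fun onCleared() {
        // Release the native session together with the ViewModel
        session?.close()
        super.onCleared()
    }
}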

Executing Models

Before executing the model, we must first convert the image into an input format accepted by the model. The input image size of the baby penguin detection model is 640 x 640, so we scale the image to 640 x 640.

The default format of Android's Bitmap is ARGB_8888: each pixel has 4 channels of one unsigned byte each, with values from 0 to 255, and Bitmap.getPixels() returns each pixel packed into an Int in alpha, red, green, blue order. YOLOv8, however, expects each pixel as 3 floats in red, green, blue order, each scaled to 0 ~ 1, and stored channel by channel (all red values, then all green values, then all blue values). Therefore, bitmapToFlatArray() drops the alpha channel, splits the remaining channels into three planes, and divides each value by 255.

Then, we call OnnxTensor.createTensor() to wrap the float array in a 1 x 3 x 640 x 640 ONNX tensor (batch, channels, height, width).

Finally, we call OrtSession.run() to execute the model.

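// Other imports (ViewModel, coroutines, OrtEnvironment, OrtSession, ...) omitted for brevity
import ai.onnxruntime.OnnxTensor
import java.nio.FloatBuffer
import java.util.Collections

class AppViewModel : ViewModel() {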
    companion object {
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640

        private const val BATCH_SIZE = 1
        private const val PIXEL_SIZE = 3
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) {
        ...
    }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            val input = createTensor(scaledBitmap)
            val model = session ?: throw Exception("Model is not set")
            val inputName = model.inputNames.iterator().next()
            val output = model.run(Collections.singletonMap(inputName, input))
            ...
        }
    }

    private fun createTensor(bitmap: Bitmap): OnnxTensor {
        // ortEnv is shared by every session, so it is not wrapped in use {} (that would close it)
        return OnnxTensor.createTensor(
            ortEnv,
            FloatBuffer.wrap(bitmapToFlatArray(bitmap)),
            // NCHW order: batch, channels, height, width
            longArrayOf(
                BATCH_SIZE.toLong(),
                PIXEL_SIZE.toLong(),
                bitmap.height.toLong(),
                bitmap.width.toLong(),
            )
        )
    }

    private fun bitmapToFlatArray(bitmap: Bitmap): FloatArray {
        val pixels = bitmap.width * bitmap.height
        val input = FloatArray(BATCH_SIZE * PIXEL_SIZE * pixels)

        // getPixels() fills the array row by row with packed ARGB Ints
        val bitmapArray = IntArray(pixels)
        bitmap.getPixels(bitmapArray, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        for (idx in 0 until pixels) {
            val pixelValue = bitmapArray[idx]
            // Planar RGB (all R, then all G, then all B) scaled to 0..1; alpha is dropped
            input[idx] = (pixelValue shr 16 and 0xFF) / 255f
            input[idx + pixels] = (pixelValue shr 8 and 0xFF) / 255f
            input[idx + pixels * 2] = (pixelValue and 0xFF) / 255f
        }

        return input
    }
}
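The same conversion can be checked on the JVM without a device. The sketch below is a hypothetical, Android-free version of bitmapToFlatArray(): argbToPlanarRgb() and nchwIndex() are names made up for this illustration, and it assumes the pixels arrive as row-major packed ARGB Ints, which is what Bitmap.getPixels() returns. nchwIndex() also shows why the tensor shape lists height before width: element [0][c][y][x] of a 1 x 3 x H x W tensor sits at offset c * H * W + y * W + x.

```kotlin
// Hypothetical, Android-free mirror of bitmapToFlatArray().
// `argb` holds row-major packed ARGB pixels, as Bitmap.getPixels() returns them.
fun argbToPlanarRgb(argb: IntArray, width: Int, height: Int): FloatArray {
    val plane = width * height
    require(argb.size == plane) { "expected $plane pixels, got ${argb.size}" }
    val out = FloatArray(3 * plane)
    for (idx in 0 until plane) {
        val p = argb[idx]
        out[idx] = (p shr 16 and 0xFF) / 255f          // R plane (channel 0)
        out[plane + idx] = (p shr 8 and 0xFF) / 255f   // G plane (channel 1)
        out[2 * plane + idx] = (p and 0xFF) / 255f     // B plane (channel 2)
    }
    return out
}

// Hypothetical helper: flat offset of element [0][c][y][x] in a 1 x 3 x height x width tensor.
fun nchwIndex(c: Int, y: Int, x: Int, width: Int, height: Int): Int =
    c * height * width + y * width + x
```

For a 2 x 1 image holding one red and one blue pixel, the red plane comes first, so the blue value of the second pixel lands at offset 5, the last slot of the array.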

The model's output is a 1 x 5 x 8400 ONNX tensor, as shown below. The second dimension describes a bounding box: its first 4 values are the center x, center y, width, and height of the box, followed by one confidence score per class. This model detects a single class (baby penguins), so the dimension has size 5. The third dimension indexes the candidate boxes, so the output contains 8400 bounding boxes.

PyTorch Output Tensor, 1 x 5 x 8400.
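The 8400 is not arbitrary. The YOLOv8 detection head predicts on three feature maps downsampled by 8, 16, and 32, so a 640 x 640 input gives 80 x 80 + 40 x 40 + 20 x 20 = 8400 candidate boxes. A small sketch of that arithmetic, together with the row count for a given number of classes, assuming the default three-scale head (candidateCount() and outputRows() are hypothetical helpers):

```kotlin
// Hypothetical helper: number of candidate boxes for a square input, assuming the default
// YOLOv8 detection head with three feature maps at strides 8, 16 and 32.
fun candidateCount(inputSize: Int, strides: List<Int> = listOf(8, 16, 32)): Int =
    strides.sumOf { stride -> (inputSize / stride) * (inputSize / stride) }

// Hypothetical helper: size of the second output dimension,
// 4 box values (center x, center y, width, height) plus one score per class.
fun outputRows(numClasses: Int): Int = 4 + numClasses
```

With one class this gives the 1 x 5 x 8400 shape above; a model trained on the 80 COCO classes would instead produce 1 x 84 x 8400.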

We can read the underlying array from OrtSession.Result as follows. Its shape is the one shown in the figure above.

val outputArray = (result.get(0).value) as Array<*>
val outputs = outputArray[0] as Array<FloatArray>
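To make the layout concrete, here is a sketch of how one column of outputs turns into a detection. It is a plain-Kotlin illustration rather than the article's PostProcessor: Detection and decodeColumn() are hypothetical names, RectF is replaced by four Float fields so it runs off-device, and 0.25 is assumed as the confidence threshold (the same default PostProcessor uses below).

```kotlin
// Hypothetical plain-Kotlin stand-in for a detected box (no android.graphics.RectF).
data class Detection(
    val left: Float, val top: Float, val right: Float, val bottom: Float,
    val score: Float, val cls: Int,
)

// Hypothetical helper: decodes column `col` of a [4 + numClasses][N] array like `outputs`.
// Returns null when the best class score does not exceed the threshold.
fun decodeColumn(outputs: Array<FloatArray>, col: Int, threshold: Float = 0.25f): Detection? {
    var best = 0f
    var cls = -1
    for (row in 4 until outputs.size) {   // rows 4 and above: one score per class
        if (outputs[row][col] > best) {
            best = outputs[row][col]
            cls = row - 4
        }
    }
    if (cls < 0 || best <= threshold) return null
    val cx = outputs[0][col]
    val cy = outputs[1][col]
    val w = outputs[2][col]
    val h = outputs[3][col]
    // Center-based xywh -> xyxy (top-left and bottom-right corners)
    return Detection(cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2, best, cls)
}
```

For example, a column with center (100, 200), size 40 x 20 and score 0.9 decodes to the corners (80, 190) and (120, 210), while a column scoring 0.1 falls below the threshold and is discarded.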

NMS

When you run a YOLOv8 model with the Ultralytics YOLOv8 framework, the bounding boxes it returns have already been filtered by non-maximum suppression (NMS). When we run the exported model directly with ONNX Runtime, however, we get all 8400 candidate boxes without NMS applied. Therefore, we must implement NMS ourselves to filter the bounding boxes.

If you are not familiar with the NMS algorithm, please refer to the following article.

The following is the implementation of the NMS algorithm. Besides filtering bounding boxes with NMS, it discards boxes whose confidence does not exceed the threshold and converts each box from xywh format (center x, center y, width, height) to xyxy format, that is, the top-left and bottom-right corners of the box.

import ai.onnxruntime.OrtSession
import android.graphics.RectF
import kotlin.math.max
import kotlin.math.min

data class BoundingBox(val boundingBox: RectF, var score: Float, var clazz: Int)

object PostProcessor {
    private const val CONFIDENCE_THRESHOLD = 0.25f
    private const val IOU_THRESHOLD = 0.45f
    private const val MAX_NMS = 30000

    fun nms(
        result: OrtSession.Result,
        confidenceThreshold: Float = CONFIDENCE_THRESHOLD,
        iouThreshold: Float = IOU_THRESHOLD,
        maxNms: Int = MAX_NMS,
        scaleX: Float = 1f,
        scaleY: Float = 1f,
    ): List<BoundingBox> {
        // Output shape is [1][4 + numClasses][8400]; take the first (only) batch entry
        val outputArray = (result.get(0).value) as Array<*>
        @Suppress("UNCHECKED_CAST")
        val outputs = outputArray[0] as Array<FloatArray>
        val results = mutableListOf<BoundingBox>()

        // Each column is one candidate box
        for (col in 0 until outputs[0].size) {
            // Rows 4 and above hold one score per class; keep the best one
            var score = 0f
            var cls = 0
            for (row in 4 until outputs.size) {
                if (outputs[row][col] > score) {
                    score = outputs[row][col]
                    cls = row
                }
            }
            cls -= 4

            if (score > confidenceThreshold) {
                // Rows 0..3: center x, center y, width, height
                val x = outputs[0][col]
                val y = outputs[1][col]
                val w = outputs[2][col]
                val h = outputs[3][col]

                // Convert xywh to xyxy (top-left and bottom-right corners)
                val left = x - w / 2
                val top = y - h / 2
                val right = x + w / 2
                val bottom = y + h / 2

                val rect = RectF(scaleX * left, scaleY * top, scaleX * right, scaleY * bottom)
                results.add(BoundingBox(rect, score, cls))
            }
        }

        return nms(results, iouThreshold, maxNms)
    }

    private fun nms(boxes: List<BoundingBox>, iouThreshold: Float, limit: Int): List<BoundingBox> {
        val selected = mutableListOf<BoundingBox>()
        // NMS must visit boxes from the highest score to the lowest
        val sortedBoxes = boxes.sortedByDescending { it.score }
        val active = BooleanArray(sortedBoxes.size) { true }

        for (i in sortedBoxes.indices) {
            if (!active[i]) continue
            val boxA = sortedBoxes[i]
            selected.add(boxA)
            if (selected.size >= limit) break

            // Suppress every remaining box that overlaps boxA too much
            for (j in i + 1 until sortedBoxes.size) {
                if (active[j] &&
                    calculateIou(boxA.boundingBox, sortedBoxes[j].boundingBox) > iouThreshold
                ) {
                    active[j] = false
                }
            }
        }
        return selected
    }

    private fun calculateIou(a: RectF, b: RectF): Float {
        val areaA = (a.right - a.left) * (a.bottom - a.top)
        if (areaA <= 0f) return 0f
        val areaB = (b.right - b.left) * (b.bottom - b.top)
        if (areaB <= 0f) return 0f

        val intersectionMinX = max(a.left, b.left)
        val intersectionMinY = max(a.top, b.top)
        val intersectionMaxX = min(a.right, b.right)
        val intersectionMaxY = min(a.bottom, b.bottom)
        val intersectionArea = max(intersectionMaxY - intersectionMinY, 0f) *
            max(intersectionMaxX - intersectionMinX, 0f)
        return intersectionArea / (areaA + areaB - intersectionArea)
    }
}
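NMS is easy to get subtly wrong (visiting boxes in ascending rather than descending score order, for instance, keeps the weakest box of each cluster), so it helps to exercise the logic on toy data. The sketch below is a simplified, JVM-only variant, not the PostProcessor above: Box, iou(), and greedyNms() are hypothetical names, and it uses a remove-from-list loop instead of an array of active flags. Like PostProcessor.nms(), it is class-agnostic and defaults to an IoU threshold of 0.45.

```kotlin
// Hypothetical Android-free box: corners in pixels plus a confidence score.
data class Box(val left: Float, val top: Float, val right: Float, val bottom: Float, val score: Float)

// Intersection over union of two axis-aligned boxes.
fun iou(a: Box, b: Box): Float {
    val areaA = (a.right - a.left) * (a.bottom - a.top)
    val areaB = (b.right - b.left) * (b.bottom - b.top)
    if (areaA <= 0f || areaB <= 0f) return 0f
    val interW = maxOf(minOf(a.right, b.right) - maxOf(a.left, b.left), 0f)
    val interH = maxOf(minOf(a.bottom, b.bottom) - maxOf(a.top, b.top), 0f)
    val inter = interW * interH
    return inter / (areaA + areaB - inter)
}

// Greedy, class-agnostic NMS: keep the highest-scoring box, drop every remaining box
// that overlaps it by more than iouThreshold, and repeat until nothing is left.
fun greedyNms(boxes: List<Box>, iouThreshold: Float = 0.45f): List<Box> {
    val remaining = boxes.sortedByDescending { it.score }.toMutableList()
    val kept = mutableListOf<Box>()
    while (remaining.isNotEmpty()) {
        val best = remaining.removeAt(0)
        kept.add(best)
        remaining.removeAll { iou(best, it) > iouThreshold }
    }
    return kept
}
```

Two 10 x 10 boxes offset by one pixel overlap with an IoU of 81 / 119 (about 0.68), so the lower-scoring one is suppressed, while a distant third box survives.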

With NMS in place, we pass the result of OrtSession.run() to PostProcessor.nms() to filter the bounding boxes, as follows.

class AppViewModel : ViewModel() {
    companion object {
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640

        private const val BATCH_SIZE = 1
        private const val PIXEL_SIZE = 3
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) {
        ...
    }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            val model = session ?: throw Exception("Model is not set")
            val inputName = model.inputNames.iterator().next()
            // The input tensor and the result both hold native memory, so close them when done
            createTensor(scaledBitmap).use { input ->
                model.run(Collections.singletonMap(inputName, input)).use { output ->
                    val boxes = PostProcessor.nms(output)
                    _uiState.value = _uiState.value.copy(boxes = boxes)
                }
            }
        }
    }
    ...
}
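As written, infer() calls nms() with the default scaleX and scaleY of 1, so the boxes are in the 640 x 640 coordinate space of the model input. If they are to be drawn on the original image, the scale factors are the ratio between the original size and the input size. A minimal sketch, assuming the image was stretched with createScaledBitmap() as above (no letterbox padding); scaleFactors() and toImageSpace() are hypothetical helpers:

```kotlin
// Hypothetical helper: ratios that map model-input coordinates back to the original image.
// Assumes a plain stretch to the input size (createScaledBitmap), with no letterbox padding.
fun scaleFactors(
    originalWidth: Int,
    originalHeight: Int,
    inputWidth: Int = 640,
    inputHeight: Int = 640,
): Pair<Float, Float> =
    Pair(originalWidth.toFloat() / inputWidth, originalHeight.toFloat() / inputHeight)

// Hypothetical helper: maps one (x, y) point from model-input space to image space.
fun toImageSpace(x: Float, y: Float, scale: Pair<Float, Float>): Pair<Float, Float> =
    Pair(x * scale.first, y * scale.second)
```

Inside infer(), the same idea would read PostProcessor.nms(output, scaleX = bitmap.width / IMAGE_WIDTH.toFloat(), scaleY = bitmap.height / IMAGE_HEIGHT.toFloat()).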

Conclusion

ONNX Runtime is quite simple to use. The harder part of this article is that, after running the YOLOv8 model, we have to implement NMS ourselves. Besides ONNX Runtime, we can also run the model with PyTorch on Android; please refer to the following article.
