Open Neural Network Exchange (ONNX) is an open model format backed by several major vendors, and ONNX Runtime is a library, developed by Microsoft, for executing ONNX models. It supports multiple platforms, including Android. This article introduces how to use ONNX Runtime on Android by building a baby penguin detection app.
The complete code for this article can be found in .
Export YOLOv8 Models to ONNX Format
We will use the YOLOv8 model created in the following article. This model can be used to detect baby penguins in pictures or videos.
The trained YOLOv8 model is in PyTorch format, so we use the following command to export it to ONNX format. After the export finishes, it outputs a best.onnx file, which we rename to baby_penguin.onnx.
YOLOv8Example % .venv/bin/yolo mode=export model=./runs/train/weights/best.pt format=onnx simplify=True
Adding ONNX Runtime Dependencies to the Project
To use ONNX Runtime on Android, we must first add its dependencies to the project as follows.
dependencies {
    implementation("com.microsoft.onnxruntime:onnxruntime-android:0.9.0")
    implementation("com.microsoft.onnxruntime:onnxruntime-extensions-android:0.9.0")
}
Loading Models
First, place baby_penguin.onnx in the assets folder of the project and load the model into an InputStream. Call OrtEnvironment.getEnvironment() to get an OrtEnvironment, then read the model bytes from the InputStream and pass them to createSession() to create an OrtSession.
class AppViewModel : ViewModel() {
    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) {
        viewModelScope.launch(Dispatchers.IO) {
            ortEnv.use {
                assets.open("baby_penguin.onnx").use {
                    BufferedInputStream(it).use { bis ->
                        // Read the model bytes and create an OrtSession from them
                        session = ortEnv.createSession(bis.readBytes())
                    }
                }
            }

            // Load the image
            assets.open("image.jpg").use {
                _uiState.value = AppUiState(Bitmap.createBitmap(BitmapFactory.decodeStream(it)))
            }
        }
    }
}
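If you want to confirm that the exported model really expects a 1 x 3 x 640 x 640 input, you can inspect the session metadata right after creating it. The helper below is a sketch that is not part of the original code; it only logs what OrtSession reports through inputInfo and outputInfo.

// Sketch: log the model's input and output metadata (names, element types, shapes).
private fun logModelInfo(session: OrtSession) {
    session.inputInfo.forEach { (name, node) ->
        Log.d("AppViewModel", "input  $name: ${node.info}")
    }
    session.outputInfo.forEach { (name, node) ->
        Log.d("AppViewModel", "output $name: ${node.info}")
    }
}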
Executing Models
Before executing the model, we must first convert the image into an input format accepted by the model. The input image size of the baby penguin detection model is 640 x 640, so we scale the image to 640 x 640.
The default format of Android's Bitmap is ARGB_8888, which means a pixel is composed of 4 unsigned bytes, each with a value of 0 ~ 255, in the order alpha, red, green, and blue. However, YOLOv8 expects each pixel as 3 floats in the order red, green, and blue, each with a value of 0 ~ 1, with the channels stored as planes rather than interleaved. Therefore, bitmapToFlatArray() drops the alpha channel, normalizes each channel to 0 ~ 1, and rearranges the interleaved ARGB pixels into planar RGB.
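As a minimal illustration of this layout (the complete conversion is in bitmapToFlatArray() below), a single ARGB_8888 pixel can be unpacked and normalized like this; the pixel value used here is just an example, not data from the app:

// Example ARGB_8888 pixel laid out as 0xAARRGGBB (alpha = FF, red = 33, green = 66, blue = 99).
val pixelValue = 0xFF336699.toInt()
val r = (pixelValue shr 16 and 0xFF) / 255f   // red   -> 0.2 (0x33 / 255)
val g = (pixelValue shr 8 and 0xFF) / 255f    // green -> 0.4 (0x66 / 255)
val b = (pixelValue and 0xFF) / 255f          // blue  -> 0.6 (0x99 / 255)
// The alpha byte (pixelValue shr 24 and 0xFF) is not used by the model.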
Then, we call OnnxTensor.createTensor() to convert the Bitmap into a 1 x 3 x 640 x 640 ONNX tensor, and call OrtSession.run() to execute the model.
class AppViewModel : ViewModel() {
    companion object {
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640
        private const val BATCH_SIZE = 1
        private const val PIXEL_SIZE = 3
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) { ... }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            // Scale the image to the model's input size
            val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            val input = createTensor(scaledBitmap)
            val model = session ?: throw Exception("Model is not set")
            val inputName = model.inputNames.iterator().next()
            val output = model.run(Collections.singletonMap(inputName, input))
            ...
        }
    }

    private fun createTensor(bitmap: Bitmap): OnnxTensor {
        ortEnv.use {
            // Shape: 1 x 3 x 640 x 640
            return OnnxTensor.createTensor(
                ortEnv,
                FloatBuffer.wrap(bitmapToFlatArray(bitmap)),
                longArrayOf(
                    BATCH_SIZE.toLong(),
                    PIXEL_SIZE.toLong(),
                    bitmap.width.toLong(),
                    bitmap.height.toLong(),
                )
            )
        }
    }

    private fun bitmapToFlatArray(bitmap: Bitmap): FloatArray {
        val input = FloatArray(BATCH_SIZE * PIXEL_SIZE * bitmap.width * bitmap.height)
        val pixels = bitmap.width * bitmap.height
        val bitmapArray = IntArray(pixels)
        bitmap.getPixels(bitmapArray, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        for (i in 0..<bitmap.width) {
            for (j in 0..<bitmap.height) {
                val idx = bitmap.height * i + j
                val pixelValue = bitmapArray[idx]
                // Split the interleaved ARGB pixel into planar R, G, B floats in 0 ~ 1
                input[idx] = (pixelValue shr 16 and 0xFF) / 255f
                input[idx + pixels] = (pixelValue shr 8 and 0xFF) / 255f
                input[idx + pixels * 2] = (pixelValue and 0xFF) / 255f
            }
        }
        return input
    }
}
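One thing the snippet above does not show is resource management: OnnxTensor and OrtSession.Result hold native memory and are AutoCloseable. A minimal sketch of wrapping the inference step in use blocks (an adjustment to the code above, not part of the original project):

// Sketch: release the native buffers as soon as the output has been read.
createTensor(scaledBitmap).use { input ->
    model.run(Collections.singletonMap(inputName, input)).use { output ->
        // ... read the output here, before the Result is closed ...
    }
}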
The output of the model is a 1 x 5 x 8400 ONNX tensor. The second dimension represents a bounding box, and its size is 5: the x, y, width, height, and confidence score of the box. The third dimension represents the number of bounding boxes, so the output contains 8400 bounding boxes.
We can get the underlying array from OrtSession.Result as follows. The shape of this array matches the 1 x 5 x 8400 shape described above.
val outputArray = (result.get(0).value) as Array<*>
val outputs = outputArray[0] as Array<FloatArray>
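If you want to double-check the shape at runtime, the first output value is an OnnxTensor whose TensorInfo reports it. A small sketch, assuming the OrtSession.Result variable is named result as above:

// Sketch: print the shape reported by ONNX Runtime, e.g. [1, 5, 8400].
val tensor = result.get(0) as OnnxTensor
Log.d("AppViewModel", "output shape: ${tensor.info.shape.contentToString()}")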
NMS
When you execute the YOLOv8 model through the YOLOv8 framework, the bounding boxes you get back have already been filtered by non-maximum suppression (NMS). However, when we execute the model directly with ONNX Runtime, the bounding boxes have not been filtered, which is why there are 8400 of them. Therefore, we must implement NMS ourselves to filter the bounding boxes.
If you are not familiar with the NMS algorithm, please refer to the following article.
The following is the implementation of the NMS algorithm. Not only does it filter bounding boxes with NMS, it also converts them from xywh format to xyxy format, that is, the upper-left and lower-right points of the bounding box.
data class BoundingBox(val boundingBox: RectF, var score: Float, var clazz: Int)

object PostProcessor {
    private const val CONFIDENCE_THRESHOLD = 0.25f
    private const val IOU_THRESHOLD = 0.45f
    private const val MAX_NMS = 30000

    fun nms(
        result: OrtSession.Result,
        confidenceThreshold: Float = CONFIDENCE_THRESHOLD,
        iouThreshold: Float = IOU_THRESHOLD,
        maxNms: Int = MAX_NMS,
        scaleX: Float = 1f,
        scaleY: Float = 1f,
    ): List<BoundingBox> {
        val outputArray = (result.get(0).value) as Array<*>
        val outputs = outputArray[0] as Array<FloatArray>
        val results = mutableListOf<BoundingBox>()
        for (col in 0 until outputs[0].size) {
            // Find the class with the highest confidence score for this box
            var score = 0f
            var cls = 0
            for (row in 4 until outputs.size) {
                if (outputs[row][col] > score) {
                    score = outputs[row][col]
                    cls = row
                }
            }
            cls -= 4
            if (score > confidenceThreshold) {
                // Convert xywh to xyxy
                val x = outputs[0][col]
                val y = outputs[1][col]
                val w = outputs[2][col]
                val h = outputs[3][col]
                val left = x - w / 2
                val top = y - h / 2
                val right = x + w / 2
                val bottom = y + h / 2
                val rect = RectF(scaleX * left, top * scaleY, scaleX * right, scaleY * bottom)
                results.add(BoundingBox(rect, score, cls))
            }
        }
        return nms(results, iouThreshold, maxNms)
    }

    private fun nms(boxes: List<BoundingBox>, iouThreshold: Float, limit: Int): List<BoundingBox> {
        val selected = mutableListOf<BoundingBox>()
        // Sort by descending confidence so the best box in each cluster is kept
        val sortedBoxes = boxes.sortedWith { o1, o2 -> o2.score.compareTo(o1.score) }
        val active = BooleanArray(sortedBoxes.size)
        Arrays.fill(active, true)
        var numActive = active.size
        var done = false
        var i = 0
        while (i < sortedBoxes.size && !done) {
            if (active[i]) {
                val boxA = sortedBoxes[i]
                selected.add(boxA)
                if (selected.size >= limit) break
                for (j in i + 1 until sortedBoxes.size) {
                    if (active[j]) {
                        val boxB = sortedBoxes[j]
                        if (calculateIou(boxA.boundingBox, boxB.boundingBox) > iouThreshold) {
                            // Suppress boxes that overlap too much with the selected box
                            active[j] = false
                            numActive -= 1
                            if (numActive <= 0) {
                                done = true
                                break
                            }
                        }
                    }
                }
            }
            i++
        }
        return selected
    }

    private fun calculateIou(a: RectF, b: RectF): Float {
        val areaA = (a.right - a.left) * (a.bottom - a.top)
        if (areaA <= 0.0) return 0.0f
        val areaB = (b.right - b.left) * (b.bottom - b.top)
        if (areaB <= 0.0) return 0.0f
        val intersectionMinX = max(a.left, b.left)
        val intersectionMinY = max(a.top, b.top)
        val intersectionMaxX = min(a.right, b.right)
        val intersectionMaxY = min(a.bottom, b.bottom)
        val intersectionArea = max(intersectionMaxY - intersectionMinY, 0.0f) *
            max(intersectionMaxX - intersectionMinX, 0.0f)
        return intersectionArea / (areaA + areaB - intersectionArea)
    }
}
With the NMS algorithm in place, after calling OrtSession.run(), we can call nms() to filter the bounding boxes, as follows.
class AppViewModel : ViewModel() {
    companion object {
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640
        private const val BATCH_SIZE = 1
        private const val PIXEL_SIZE = 3
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private var session: OrtSession? = null

    fun load(assets: AssetManager) { ... }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            val scaledBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            val input = createTensor(scaledBitmap)
            val model = session ?: throw Exception("Model is not set")
            val inputName = model.inputNames.iterator().next()
            val output = model.run(Collections.singletonMap(inputName, input))
            // Filter the 8400 raw bounding boxes with NMS
            val boxes = PostProcessor.nms(output)
            _uiState.value = _uiState.value.copy(boxes = boxes)
        }
    }

    ...
}
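Note that nms() also accepts scaleX and scaleY parameters. Because inference runs on a 640 x 640 copy of the image, the boxes come back in that coordinate space; a sketch of mapping them back to the original bitmap passed to infer() could look like this:

// Sketch: map boxes from the 640 x 640 model space back to the original image size.
val scaleX = bitmap.width / IMAGE_WIDTH.toFloat()
val scaleY = bitmap.height / IMAGE_HEIGHT.toFloat()
val boxes = PostProcessor.nms(output, scaleX = scaleX, scaleY = scaleY)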
Conclusion
ONNX Runtime is quite simple to use. The more difficult part of this article is that, after executing the YOLOv8 model, we must implement NMS ourselves. In addition to ONNX Runtime, we can also use PyTorch on Android to execute the model; please refer to the following article.
Reference
- Object detection and pose estimation on mobile with YOLOv8, ONNX Runtime.
- ONNX Runtime Inference Example – Object Detection, GitHub.