PyTorch is a machine learning library developed by Meta, and YOLOv8 also uses PyTorch internally. Besides Python, PyTorch can now also be used in non-Python environments. This article shows how to use PyTorch on Android by building a penguin detection app.
The complete code for this chapter can be found in .
Exporting YOLOv8 Models to TorchScript Format
We will use the YOLOv8 model created in the following article. This model can be used to detect baby penguins in pictures or videos.
YOLOv8 models are saved in PyTorch format, which cannot be executed directly on Android. We therefore use the following command to convert the model into TorchScript format; optimize=True additionally optimizes the exported model for mobile devices. The command outputs a best.torchscript file, which we rename to baby_penguin.torchscript.
```shell
YOLOv8Example % .venv/bin/yolo mode=export model=./runs/train/weights/best.pt format=torchscript optimize=True
```
Adding PyTorch Dependencies into the Project
To execute TorchScript on Android, we first add the PyTorch Android Lite libraries to the dependencies of the app module's Gradle build file, as follows.
```kotlin
dependencies {
    implementation("org.pytorch:pytorch_android_lite:2.1.0")
    implementation("org.pytorch:pytorch_android_torchvision_lite:2.1.0")
}
```
Loading Models
First, place baby_penguin.torchscript in the project's assets folder. Then, call LitePyTorchAndroid.loadModuleFromAsset() to load the model.
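```kotlin
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
import org.pytorch.LitePyTorchAndroid
import org.pytorch.Module

class AppViewModel : ViewModel() {
    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var module: Module? = null

    fun load(assets: AssetManager) {
        // Loading from disk is I/O work, so run it off the main thread
        viewModelScope.launch(Dispatchers.IO) {
            // Load the TorchScript model from the assets folder
            module = LitePyTorchAndroid.loadModuleFromAsset(assets, "baby_penguin.torchscript")

            // Load the sample image to run detection on
            assets.open("image.jpg").use {
                _uiState.value = AppUiState(Bitmap.createBitmap(BitmapFactory.decodeStream(it)))
            }
        }
    }
}
```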
Executing Models
Before executing the model, we must convert the image into the input format the model accepts. The baby penguin detection model takes a 640 x 640 input image, so we first scale the image to 640 x 640. Then, we call TensorImageUtils.bitmapToFloat32Tensor() to convert the Bitmap into a 1 x 3 x 640 x 640 PyTorch tensor. Finally, we call Module.forward() to execute the model.
```kotlin
class AppViewModel : ViewModel() {
    companion object {
        // Input size expected by the baby penguin detection model
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var module: Module? = null

    fun load(assets: AssetManager) { ... }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            // Resize the image to the model's 640 x 640 input size
            val image = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            // Convert to a 1 x 3 x 640 x 640 float tensor; mean 0 and std 1 keep values in [0, 1]
            val input = TensorImageUtils.bitmapToFloat32Tensor(
                image,
                floatArrayOf(0.0f, 0.0f, 0.0f),
                floatArrayOf(1.0f, 1.0f, 1.0f),
            )
            // Run the model
            val output = module?.forward(IValue.from(input))?.toTensor()
                ?: throw Exception("Module is not loaded")
            ...
        }
    }
}
```
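To see where the 1 x 3 x 640 x 640 shape comes from, the sketch below reproduces in plain Kotlin the kind of conversion that TensorImageUtils.bitmapToFloat32Tensor() performs. It is an illustration under assumptions, not PyTorch's source: it assumes each 8-bit channel is divided by 255, normalized as (value - mean) / std, and written plane by plane (all red values, then all green, then all blue). The function argbToChwFloats and its parameters are hypothetical. With mean 0 and std 1, as passed in the code above, the values simply stay in the range [0, 1].

```kotlin
// Hypothetical helper (not part of PyTorch): converts ARGB pixels into a
// channel-planar (CHW) float array, normalizing each channel as
// (value / 255 - mean) / std.
fun argbToChwFloats(
    pixels: IntArray, // ARGB_8888 pixels in row-major order, length = width * height
    width: Int,
    height: Int,
    mean: FloatArray = floatArrayOf(0f, 0f, 0f),
    std: FloatArray = floatArrayOf(1f, 1f, 1f),
): FloatArray {
    val count = width * height
    require(pixels.size == count) { "expected $count pixels, got ${pixels.size}" }
    val out = FloatArray(3 * count)
    for (i in 0 until count) {
        val c = pixels[i]
        val r = ((c shr 16) and 0xFF) / 255f
        val g = ((c shr 8) and 0xFF) / 255f
        val b = (c and 0xFF) / 255f
        // Plane 0 holds all R values, plane 1 all G values, plane 2 all B values
        out[i] = (r - mean[0]) / std[0]
        out[count + i] = (g - mean[1]) / std[1]
        out[2 * count + i] = (b - mean[2]) / std[2]
    }
    return out
}
```

For example, a 2 x 1 image containing one pure red pixel (0xFFFF0000) and one pure blue pixel (0xFF0000FF) yields [1.0, 0.0, 0.0, 0.0, 0.0, 1.0]: the red plane, then the green plane, then the blue plane. For a 640 x 640 bitmap the result has 3 x 640 x 640 = 1,228,800 values, which, with a leading batch dimension of 1, is the 1 x 3 x 640 x 640 layout the model expects.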
The model outputs a 1 x 5 x 8400 PyTorch tensor. The first dimension is the batch size. The second dimension holds the attributes of a bounding box; its size is 5, namely the center x, center y, width, height, and confidence score of the box. The third dimension indexes the bounding boxes, so the output contains 8400 bounding boxes.
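The numbers 5 and 8400 follow from the model's structure. Assuming the standard YOLOv8 detection head, which predicts on three feature maps with strides 8, 16, and 32, a 640 x 640 input yields 80 x 80 + 40 x 40 + 20 x 20 = 8400 candidate boxes. Each candidate carries 4 box values plus one score per class, so this one-class model has 4 + 1 = 5 values per box. The helper names below are hypothetical; the sketch only does this arithmetic.

```kotlin
// Hypothetical helper. Assumption: the default YOLOv8 detection head
// predicts on three feature maps with strides 8, 16 and 32.
fun candidateCount(inputSize: Int, strides: List<Int> = listOf(8, 16, 32)): Int =
    strides.sumOf { stride -> (inputSize / stride) * (inputSize / stride) }

// Hypothetical helper: 4 box values (cx, cy, w, h) plus one score per class.
fun valuesPerCandidate(numClasses: Int): Int = 4 + numClasses
```

candidateCount(640) gives 8400 and valuesPerCandidate(1) gives 5. The same reasoning explains why an 80-class model would produce a 1 x 84 x 8400 output.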
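Calling Tensor.dataAsFloatArray flattens the data into a one-dimensional array: all 8400 center x values come first, followed by all 8400 center y values, and so on. The 5 values of any single bounding box are therefore 8400 elements apart, which makes it inconvenient to read the values of a particular bounding box.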
Therefore, we use the following code to convert the one-dimensional array into a two-dimensional 5 x 8400 array, in which outputs[a][i] is attribute a of bounding box i. This makes it easy to access any bounding box.
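```kotlin
val array = tensor.dataAsFloatArray
val rows = tensor.shape()[1].toInt() // 5 attributes: cx, cy, w, h, score
val cols = tensor.shape()[2].toInt() // 8400 bounding boxes
// outputs[attribute][box]: one FloatArray of 8400 values per attribute
val outputs = Array(rows) { row ->
    array.sliceArray((row * cols) until ((row + 1) * cols))
}
```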
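Because dataAsFloatArray stores the tensor row by row, attribute a of box i sits at index a * 8400 + i of the flat array, and outputs[a][i] in the code above reads exactly that element. If you would rather have one row per box (an 8400 x 5 layout), a transpose like the following would work. It is a minimal sketch with hypothetical helper names, independent of the PyTorch API.

```kotlin
// Hypothetical helper: reads attribute `attr` (0 = cx, 1 = cy, 2 = w, 3 = h,
// 4 = score) of box `box` from a flat, attribute-major array.
fun valueAt(flat: FloatArray, numBoxes: Int, attr: Int, box: Int): Float =
    flat[attr * numBoxes + box]

// Hypothetical helper: transposes the attribute-major layout into one
// FloatArray per box, i.e. numBoxes rows of numAttrs values each.
fun toBoxRows(flat: FloatArray, numAttrs: Int, numBoxes: Int): Array<FloatArray> {
    require(flat.size == numAttrs * numBoxes) { "size mismatch" }
    return Array(numBoxes) { box ->
        FloatArray(numAttrs) { attr -> valueAt(flat, numBoxes, attr, box) }
    }
}
```

For this model, valueAt(array, 8400, 4, i) would be the confidence score of box i, and toBoxRows(array, 5, 8400)[i] would hold all five values of box i. The transpose costs an extra copy of the 42,000 values, which is why reading through the 5 x 8400 layout directly is also a reasonable choice.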
NMS
When you run a YOLOv8 model with the Ultralytics YOLOv8 framework, the bounding boxes you get have already been filtered by NMS (non-maximum suppression). However, when we execute a YOLOv8 model directly with PyTorch, the bounding boxes are not filtered by NMS, so we get all 8400 candidate boxes. Therefore, we must implement NMS ourselves to filter them.
If you are not familiar with the NMS algorithm, please refer to the following article.
The following is our implementation of the NMS algorithm. It first discards boxes whose confidence is below a threshold, then filters the remaining boxes with NMS. Along the way, it also converts the bounding boxes from xywh format (center point, width, and height) to xyxy format (the upper-left and lower-right corners of the bounding box).
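The format conversion itself is simple arithmetic: YOLOv8 reports the center (x, y) and the size (w, h) of each box, so the corners lie half the width and half the height away from the center. A minimal sketch, using a hypothetical Corners class in place of android.graphics.RectF so that it runs without Android:

```kotlin
// Hypothetical plain data class standing in for android.graphics.RectF.
data class Corners(val left: Float, val top: Float, val right: Float, val bottom: Float)

// Hypothetical helper: converts a center-based box (cx, cy, w, h)
// into its upper-left and lower-right corners.
fun xywhToXyxy(cx: Float, cy: Float, w: Float, h: Float): Corners =
    Corners(cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)
```

For example, a box centered at (320, 320) with width 100 and height 50 becomes Corners(270, 295, 370, 345).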
```kotlin
import android.graphics.RectF
import org.pytorch.Tensor
import kotlin.math.max
import kotlin.math.min

data class BoundingBox(val boundingBox: RectF, var score: Float, var clazz: Int)

object PostProcessor {
    private const val CONFIDENCE_THRESHOLD = 0.25f
    private const val IOU_THRESHOLD = 0.45f
    private const val MAX_NMS = 30000

    fun nms(
        tensor: Tensor,
        confidenceThreshold: Float = CONFIDENCE_THRESHOLD,
        iouThreshold: Float = IOU_THRESHOLD,
        maxNms: Int = MAX_NMS,
        scaleX: Float = 1f,
        scaleY: Float = 1f,
    ): List<BoundingBox> {
        // Reshape the flat output into outputs[attribute][box]
        val array = tensor.dataAsFloatArray
        val rows = tensor.shape()[1].toInt()
        val cols = tensor.shape()[2].toInt()
        val outputs = Array(rows) { row -> array.sliceArray((row * cols) until ((row + 1) * cols)) }

        val results = mutableListOf<BoundingBox>()
        for (col in 0 until outputs[0].size) {
            // Rows 4 and up hold one score per class; keep the best one
            var score = 0f
            var cls = 0
            for (row in 4 until outputs.size) {
                if (outputs[row][col] > score) {
                    score = outputs[row][col]
                    cls = row
                }
            }
            cls -= 4

            if (score > confidenceThreshold) {
                // Convert center-based xywh to corner-based xyxy
                val x = outputs[0][col]
                val y = outputs[1][col]
                val w = outputs[2][col]
                val h = outputs[3][col]
                val left = x - w / 2
                val top = y - h / 2
                val right = x + w / 2
                val bottom = y + h / 2
                val rect = RectF(scaleX * left, scaleY * top, scaleX * right, scaleY * bottom)
                results.add(BoundingBox(rect, score, cls))
            }
        }
        return nms(results, iouThreshold, maxNms)
    }

    private fun nms(boxes: List<BoundingBox>, iouThreshold: Float, limit: Int): List<BoundingBox> {
        val selected = mutableListOf<BoundingBox>()
        // NMS must visit boxes from the highest score to the lowest
        val sortedBoxes = boxes.sortedByDescending { it.score }
        val active = BooleanArray(sortedBoxes.size) { true }
        var numActive = active.size
        var done = false
        var i = 0
        while (i < sortedBoxes.size && !done) {
            if (active[i]) {
                val boxA = sortedBoxes[i]
                selected.add(boxA)
                if (selected.size >= limit) break
                // Suppress lower-scoring boxes that overlap boxA too much
                for (j in i + 1 until sortedBoxes.size) {
                    if (active[j]) {
                        val boxB = sortedBoxes[j]
                        if (calculateIou(boxA.boundingBox, boxB.boundingBox) > iouThreshold) {
                            active[j] = false
                            numActive -= 1
                            if (numActive <= 0) {
                                done = true
                                break
                            }
                        }
                    }
                }
            }
            i++
        }
        return selected
    }

    private fun calculateIou(a: RectF, b: RectF): Float {
        val areaA = (a.right - a.left) * (a.bottom - a.top)
        if (areaA <= 0f) return 0f
        val areaB = (b.right - b.left) * (b.bottom - b.top)
        if (areaB <= 0f) return 0f
        val intersectionMinX = max(a.left, b.left)
        val intersectionMinY = max(a.top, b.top)
        val intersectionMaxX = min(a.right, b.right)
        val intersectionMaxY = min(a.bottom, b.bottom)
        val intersectionArea = max(intersectionMaxY - intersectionMinY, 0f) *
            max(intersectionMaxX - intersectionMinX, 0f)
        return intersectionArea / (areaA + areaB - intersectionArea)
    }
}
```
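To check the NMS logic away from Android, here is a minimal pure-Kotlin sketch of greedy NMS. It uses a hypothetical Box class in place of BoundingBox and RectF and omits the maxNms limit, so treat it as an illustration of the technique rather than a drop-in replacement. Boxes are visited from the highest score down; otherwise a weak duplicate could suppress the strongest detection of an object.

```kotlin
import kotlin.math.max
import kotlin.math.min

// Hypothetical stand-in for BoundingBox, using plain floats instead of RectF.
data class Box(val left: Float, val top: Float, val right: Float, val bottom: Float, val score: Float)

fun area(b: Box): Float = max(b.right - b.left, 0f) * max(b.bottom - b.top, 0f)

// Intersection over union of two axis-aligned boxes.
fun iou(a: Box, b: Box): Float {
    val w = max(min(a.right, b.right) - max(a.left, b.left), 0f)
    val h = max(min(a.bottom, b.bottom) - max(a.top, b.top), 0f)
    val inter = w * h
    val union = area(a) + area(b) - inter
    return if (union <= 0f) 0f else inter / union
}

// Greedy NMS: keep the highest-scoring box, drop every remaining box that
// overlaps it by more than iouThreshold, and repeat until none are left.
fun greedyNms(boxes: List<Box>, iouThreshold: Float): List<Box> {
    val remaining = boxes.sortedByDescending { it.score }.toMutableList()
    val kept = mutableListOf<Box>()
    while (remaining.isNotEmpty()) {
        val best = remaining.removeAt(0)
        kept.add(best)
        remaining.removeAll { iou(best, it) > iouThreshold }
    }
    return kept
}
```

With a threshold of 0.45, two 10 x 10 boxes offset by one pixel overlap with IoU 81 / 119 (about 0.68), so only the higher-scoring one survives, while a distant third box is kept. The loop is quadratic in the number of boxes, which is acceptable here because the confidence threshold usually leaves only a small fraction of the 8400 candidates.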
With the NMS algorithm in place, we call nms() after Module.forward() to filter the bounding boxes, as follows.
```kotlin
class AppViewModel : ViewModel() {
    companion object {
        private const val IMAGE_WIDTH = 640
        private const val IMAGE_HEIGHT = 640
    }

    private val _uiState = MutableStateFlow(AppUiState())
    val uiState: StateFlow<AppUiState> = _uiState.asStateFlow()

    private var module: Module? = null

    fun load(assets: AssetManager) { ... }

    fun infer(bitmap: Bitmap) {
        viewModelScope.launch(Dispatchers.Default) {
            val image = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false)
            val input = TensorImageUtils.bitmapToFloat32Tensor(
                image,
                floatArrayOf(0.0f, 0.0f, 0.0f),
                floatArrayOf(1.0f, 1.0f, 1.0f),
            )
            val output = module?.forward(IValue.from(input))?.toTensor()
                ?: throw Exception("Module is not loaded")

            // Filter the 8400 candidate boxes down to the final detections
            val boxes = PostProcessor.nms(tensor = output)
            _uiState.value = _uiState.value.copy(boxes = boxes)
        }
    }
}
```
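As written, infer() calls nms() without scaleX and scaleY, so the returned boxes are in the coordinate system of the 640 x 640 model input. If you want to draw them on the original image instead, one option, assuming the plain stretch resize done by Bitmap.createScaledBitmap() above (no letterboxing), is to compute one scale factor per axis. The helper scaleFactors is hypothetical.

```kotlin
// Hypothetical helper: factors that map coordinates from the model's input
// size back to the original image. Assumes the image was stretched to the
// input size without letterboxing.
fun scaleFactors(
    originalWidth: Int,
    originalHeight: Int,
    inputWidth: Int = 640,
    inputHeight: Int = 640,
): Pair<Float, Float> =
    Pair(originalWidth.toFloat() / inputWidth, originalHeight.toFloat() / inputHeight)
```

Inside infer(), this could be used as val (sx, sy) = scaleFactors(bitmap.width, bitmap.height) followed by PostProcessor.nms(tensor = output, scaleX = sx, scaleY = sy), relying on the scaleX and scaleY parameters that nms() already accepts. For a 1280 x 960 photo the factors are 2.0 and 1.5.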
Conclusion
PyTorch Android Lite is quite simple to use. The harder part of this article is that, after executing the YOLOv8 model, we must implement NMS ourselves. Besides PyTorch, we can also use ONNX Runtime to execute the model on Android; please refer to the following article.
Reference
- Efficient mobile interpreter in Android and iOS, PyTorch.
- Android demo app – Object Detection, Github.