PyTorch 是一個由 Meta 所開發的機器學習 library。YOLOv8 內部也是使用 Pytorch。除了 Python 環境之外,我們現在也可以在非 Python 的環境中使用 PyTorch。本文章將藉由建立一個企鵝偵測的 app 來介紹如何在 Android 上使用 PyTorch。
Table of Contents
將 YOLOv8 模型輸出成 TorchScript 格式
我們將使用以下文章中建立的 YOLOv8 模型。此模型可以用來偵測圖片或是影片中的企鵝寶寶。
YOLOv8 模型是 PyTorch 格式。因為我們無法在 Android 上直接執行 PyTorch 格式的模型,所以我們要利用以下的指令,將它轉換成 TorchScript 格式。執行後,它會輸出一個 best.torchscript 檔案。我們將它重新命名為 baby_penguin.torchscript。
YOLOv8Example % .venv/bin/yolo mode=export model=./runs/train/weights/best.pt format=torchscript optimize=True
引入 PyTorch 依賴至專案
要在 Android 上執行 TorchScript,我們必須先將 PyTorch Android Lite 引入至專案,如下。
dependencies { implementation("org.pytorch:pytorch_android_lite:2.1.0") implementation("org.pytorch:pytorch_android_torchvision_lite:2.1.0") }
載入模型
首先,先將 baby_penguin.torchscript 放置專案中的 assets
資料夾。然後,呼叫 LitePyTorchAndroid.loadModuleFromAsset()
來載入模型。
class AppViewModel : ViewModel() { private val _uiState = MutableStateFlow(AppUiState()) val uiState: StateFlow<AppUiState> = _uiState.asStateFlow() private var module: Module? = null fun load(assets: AssetManager) { viewModelScope.launch(Dispatchers.IO) { // Load the module module = LitePyTorchAndroid.loadModuleFromAsset(assets, "baby_penguin.torchscript") // Load the image here assets.open("image.jpg").use { _uiState.value = AppUiState(Bitmap.createBitmap(BitmapFactory.decodeStream(it))) } } } }
執行模型
在執行模型之前,我們必須要先將圖片轉換成模型接受的輸入格式。企鵝寶寶偵測模型的輸入圖片大小為 640 x 640,因此我們將圖片縮放成 640 x 640。然後,我們呼叫 TensorImageUtils.bitmapToFloat32Tensor()
將 Bitmap 轉換成一個 1 x 3 x 640 x 640 的 PyTorch tensor。
呼叫 Module.forward()
來執行模型。
class AppViewModel : ViewModel() { companion object { private const val IMAGE_WIDTH = 640 private const val IMAGE_HEIGHT = 640 } private val _uiState = MutableStateFlow(AppUiState()) val uiState: StateFlow<AppUiState> = _uiState.asStateFlow() private var module: Module? = null fun load(assets: AssetManager) { ... } fun infer(bitmap: Bitmap) { viewModelScope.launch(Dispatchers.Default) { val image = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false) val input = TensorImageUtils.bitmapToFloat32Tensor( image, floatArrayOf(0.0f, 0.0f, 0.0f), floatArrayOf(1.0f, 1.0f, 1.0f), ) val output = module?.forward(IValue.from(input))?.toTensor() ?: throw Exception("Module is not loaded") ... } } }
模型輸出的結果是一個 1 x 5 x 8400 的 PyTorch tensor,如下圖。第二維度代表 bounding box,其大小是 5,分別是一個 bounding box 的 x、y、寬、高、和 confidence score。第三維度表示 bounding box 的個數,所以該輸出的結果包含 8400 個 bounding boxes。
呼叫 Tensor.dataAsFloatArray
來將資料轉換成一個一維的 array,如下圖。這樣個排列非常不利於我們取得某個 bounding box 的 5 個數值。
因此我們裡用下面的程式碼,將一維的 array 轉換成 8400 x 5 的二維 array。這樣就方便我們取得任何一個 bounding box。
val array = tensor.dataAsFloatArray val rows = tensor.shape()[1].toInt() val cols = tensor.shape()[2].toInt() val outputs = Array(rows) { row -> array.sliceArray((row * cols) until ((row + 1) * cols)) }
NMS
當你用 YOLOv8 的 framework 來執行 YOLOv8 模型後,你會得到已用 NMS 過濾過的 bounding boxes。然而,當我們用 PyTorch 直接執行 YOLOv8 模型後,我們得到的是未用 NMS 過濾過的 bounding boxes,所以 bounding boxes 的數量才會達到 8400 個。因此,我們必須要自己實作 NMS 來過濾 bounding boxes。
如果你不熟悉 NMS 演算法的話,請參考以下文章。
以下是 NMS 演算法的實作。它不但會依據 NMS 演算法來過濾 bounding boxes,它還會將 bounding boxes 從 xywh
格式轉換成 xyxy
格式,也就是 bounding box 的左上點和右下點。
data class BoundingBox(val boundingBox: RectF, var score: Float, var clazz: Int) object PostProcessor { private const val CONFIDENCE_THRESHOLD = 0.25f private const val IOU_THRESHOLD = 0.45f private const val MAX_NMS = 30000 fun nms( tensor: Tensor, confidenceThreshold: Float = CONFIDENCE_THRESHOLD, iouThreshold: Float = IOU_THRESHOLD, maxNms: Int = MAX_NMS, scaleX: Float = 1f, scaleY: Float = 1f, ): List<BoundingBox> { val array = tensor.dataAsFloatArray val rows = tensor.shape()[1].toInt() val cols = tensor.shape()[2].toInt() val outputs = Array(rows) { row -> array.sliceArray((row * cols) until ((row + 1) * cols)) } val results = mutableListOf<BoundingBox>() for (col in 0 until outputs[0].size) { var score = 0f var cls = 0 for (row in 4 until outputs.size) { if (outputs> score) { score = outputscls = row } } cls -= 4 if (score > confidenceThreshold) { val x = outputs[0]val y = outputs[1]val w = outputs[2]val h = outputs[3]val left = x - w / 2 val top = y - h / 2 val right = x + w / 2 val bottom = y + h / 2 val rect = RectF(scaleX * left, top * scaleY, scaleX * right, scaleY * bottom) val result = BoundingBox(rect, score, cls) results.add(result) } } return nms(results, iouThreshold, maxNms) } private fun nms(boxes: List<BoundingBox>, iouThreshold: Float, limit: Int): List<BoundingBox> { val selected = mutableListOf<BoundingBox>() val sortedBoxes = boxes.sortedWith { o1, o2 -> o1.score.compareTo(o2.score) } val active = BooleanArray(sortedBoxes.size) Arrays.fill(active, true) var numActive = active.size var done = false var i = 0 while (i < sortedBoxes.size && !done) { if (active[i]) { val boxA = sortedBoxes[i] selected.add(boxA) if (selected.size >= limit) break for (j in i + 1 until sortedBoxes.size) { if (active[j]) { val boxB = sortedBoxes[j] if (calculateIou(boxA.boundingBox, boxB.boundingBox) > iouThreshold) { active[j] = false numActive -= 1 if (numActive <= 0) { done = true break } } } } } i++ } return selected } private fun calculateIou(a: RectF, b: RectF): Float { val areaA = (a.right - a.left) * (a.bottom - a.top) if (areaA <= 0.0) return 0.0f val areaB = (b.right - b.left) * (b.bottom - b.top) if (areaB <= 0.0) return 0.0f val intersectionMinX = max(a.left, b.left) val intersectionMinY = max(a.top, b.top) val intersectionMaxX = min(a.right, b.right) val intersectionMaxY = min(a.bottom, b.bottom) val intersectionArea = max( intersectionMaxY - intersectionMinY, 0.0f ) * max(intersectionMaxX - intersectionMinX, 0.0f) return intersectionArea / (areaA + areaB - intersectionArea) } }
有了 NMS 演算法後,我們在呼叫 Module.forward()
之後,就可以呼叫 nms()
來過濾 bounding boxes,如下。
class AppViewModel : ViewModel() { companion object { private const val IMAGE_WIDTH = 640 private const val IMAGE_HEIGHT = 640 } private val _uiState = MutableStateFlow(AppUiState()) val uiState: StateFlow<AppUiState> = _uiState.asStateFlow() private var module: Module? = null fun load(assets: AssetManager) { ... } fun infer(bitmap: Bitmap) { viewModelScope.launch(Dispatchers.Default) { val image = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false) val input = TensorImageUtils.bitmapToFloat32Tensor( image, floatArrayOf(0.0f, 0.0f, 0.0f), floatArrayOf(1.0f, 1.0f, 1.0f), ) val output = module?.forward(IValue.from(input))?.toTensor() ?: throw Exception("Module is not loaded") val boxes = PostProcessor.nms(tensor = output) _uiState.value = _uiState.value.copy(boxes = boxes) } } }
結語
PyTorch Android Lite 使用起來相當地簡單。本文章中比較困難的地方是,在執行 YOLOv8 模型後,我們必須要自己實作 NMS。除了 PyTorch 之外,我們也可以在 Android 使用 ONNX Runtime 來執行模型,請參考以下文章。
參考
- Efficient mobile interpreter in Android and iOS, PyTorch.
- Android demo app – Object Detection, Github.