YOLOv8 pose estimation 안드로이드 -2

알로에·2023년 4월 8일
6

YOLOv8-Pose Estimation

목록 보기
2/2

이전 글에서 yolov8n pose 모델을 안드로이드에 적용하고 추론의 결과까지 받아왔다. 이제 그 배열 값에 대해 구조를 살펴보고 적절히 처리해서 화면에 표현하려 한다.

✔ 1. 출력에 대한 고찰

이 부분은 내가 고민한 부분에 대한 글이다. 코드를 짜는데 중요하지 않으니 궁금하지 않다면 ✔ 2. 번으로 넘어가도 무방하다.

우선 yolov8 pose 출력에 대한 글이 없다...
내가 알 수 있는 방법은 model의 출력을 보고 유추할 뿐이였다.

아래는 https://netron.app/ 에서 확인한 모델 출력 부분이다.

[1 56 8400] 인걸 확인하고 8400개의 후보군이 있다는 건, pose model도 NMS처리가 필요하다는 것을 알 수 있었다. 다만, pose model이 가지고 있는 56개의 요소는 아직 알 지 못했다.

아래 사진은 모델의 요약본이다.

56개의 요소 중에서 kpt_shape가 있을 것 같았다. 17 * 3 의 형태를 보아하니 56개의 요소 중 키포인트가 51개를 차지하고 있을 것 같았다. (확신 하지 못했음. v8 공식 사이트에 글이 없음.) 그럼 만약 56개 중 51개가 키포인트라면 키 포인트는 뭐고 그 외 5개의 요소는 무엇인지 궁금했다.

구글에 pose keypoint를 검색하면

이런 사진들이 나오게 된다. 그리고 좀 더 글을 찾아보게 되면 key point의 형태는 주로 x y confidence (확률 값) 임을 알 수 있다.
즉 yolov8 pose model이 가지는 56개의 요소 중 51개는 17개의 관절이며, 각각의 [x y conf] 값 임을 알 수 있었다. (물론 확정은 아님.)
그럼 나머지 5개의 요소는 알 지 못한채, 출력을 디버깅해서 56개의 값을 유추하기로 했다.
56개의 출력은 아래와 같은 형태였다.

0 ~ 4 ==> 0 ~ 639 사이의 큰 값
5 ==> 0 ~ 1 사이의 작은 값
6, 7 ==> 0 ~ 639 사이의 큰 값
8 ==> 0 ~ 1 사이의 작은 값
9, 10 ==> 0 ~ 639 사이의 큰 값
11 ==> 0 ~ 1 사이의 작은 값
... 반복

위에서 예측한대로 0 ~ 5 는 알지 못하지만 6 ~ 56의 요소는
각 key point(관절)에 대한 x y conf 값임을 알 수 있었다.
x,y는 입력 사진이 640이니 639 까지의 큰 값이 나오며,
confidence 확률은 0 ~ 1 사이이므로 작은 값이 나오게 된다.

여기까지 알았을 때, 처음엔 행복했지만 시간이 지날수록 머리속에 물음표만 가득해졌다.

  1. 0 ~ 4 는 무언가의 좌표값인 것 같고 5번의 거기에 해당하는 확률값인 것같은데?
  2. NMS를 하려면, 어떤 값의 conf를 기준으로 중복되는 IoU를 제거해야하는데, 그럼 각 각의 keypoint 마다 IoU를 제거해야하나?
  3. 그럼 keypoint의 xy좌표만으로 box를 그릴 수 있나?

등등 너무 막막했다. 그렇게 고민하는 중에 ultralytics의 pose에 대한 predict 예제를 해보기로 했다.
https://docs.ultralytics.com/tasks/pose/#predict <- 해당 사이트

아래에 Predict 에 대한 자세한 설명이 있는 페이지에 넘어가면,

이런 입력변수로 넣을 수 있는 요소를 알 수 있다.

입력 변수 중 'save'는 출력 결과를 저장하는 사진임을 확인하고,
파이썬 파일을 하나 만들어서 테스트를 했다.

from ultralytics import YOLO

model = YOLO(model='yolov8n-pose.pt')

model.predict('https://ultralytics.com/images/bus.jpg',save=True)

그리고 저장된 사진은 아래와 같았다.

사진을 보는순간 5개의 값이 뭘 의미하는지 바로 알았다.
pose estimation도 결국 object-detection을 하고, 각 관절을 측정하는 것이였다.
0 ~ 5 까지의 값은 각각 사람에 대한 바운딩 박스의 x y w h conf 값이였다. 정리하면 아래의 형태와 같다.

출력 [1 56 8400] 은 8400개의 후보군이 있으며, 각각 56개의 요소는
0 ~ 3은 사람의 바운딩 박스, 4는 그 박스의 확률이며
5 ~ 55는 사람의 keypoints에 대한 x y 확률값이 순서대로 있는 형태임을 알 수 있었다. 이때 까지만 해도 이 각각 keypoints 가 무엇인지 몰라서 나중에 kpts 0만 화면에 보여주고, 1만 화면에 보여주는 식으로 하나씩 확인해봤고 아래와 같다.

0번 == 코
1번 == 오른쪽 눈
2번 == 왼쪽 눈
3번 == 오른쪽 귀
4번 == 왼쪽 귀
5번 == 오른쪽 어깨
6번 == 왼쪽 어깨
7번 == 오른쪽 팔꿈치
8번 == 왼쪽 팔꿈치
9번 == 오른쪽 손목
10번 == 왼쪽 손목
11번 == 오른쪽 골반
12번 == 왼쪽 골반
13번 == 오른쪽 무릎
14번 == 왼쪽 무릎
15번 == 오른쪽 발
16번 == 왼쪽 발

✔ 2. 출력 정리

[1 56 8400]은 아래와 같은 형태이다.

각 keypoints는 순서대로 코, 오른쪽 눈, 왼쪽 눈, 오른쪽 귀, 왼쪽 귀, 오른쪽 어깨, 왼쪽 어깨, 오른쪽 팔꿈치, 왼쪽 팔꿈치, 오른쪽 손목, 왼쪽 손목, 오른쪽 골반, 왼쪽 골반, 오른쪽 무릎, 왼쪽 무릎, 오른쪽 발, 왼쪽 발 순서이다.

✔ 3. conf 임계값 설정

출력을 알았으니 이제 다시 안드로이드 코드로 돌아가서 출력에 대해 conf에 의한 제거, nms 처리를 하고 화면에 표출하면 된다.

사람의 확률(요소 5번)이 아주 낮은 경우에도 화면에 표출하면, 수많은 사람에 대한 바운딩 박스와 keypoint들이 화면에 보여질 것이다.
따라서 사람에 대한 임계값을 넘기지 못한 경우는 미리 제거한다. 이후 겹치는 박스를 제거하면, 최종적으로 화면에 보여질 요소가 나오게 된다.

DataProcess의 내부 메서드 
fun outputsToNPMSPredictions(outputs: Array<*>): ArrayList<FloatArray> {
        val confidenceThreshold = 0.4f
        val rows: Int
        val cols: Int
        val results = ArrayList<FloatArray>()

        (outputs[0] as Array<*>).also {
            rows = it.size
            cols = (it[0] as FloatArray).size
        }

        //배열 형태를 [56 8400] -> [8400 56] 으로 변환
        val output = Array(cols) { FloatArray(rows) }
        for (i in 0 until rows) {
            for (j in 0 until cols) {
                output[j][i] = (((outputs[0] as Array<*>)[i]) as FloatArray)[j]
            }
        }


        for (i in 0 until cols) {
            // 바운딩 박스의 특정 확률을 넘긴 경우에만 xywh -> xy xy 형태로 변환 후 nms 처리
            if (output[i][4] > confidenceThreshold) {
                val xPos = output[i][0]
                val yPos = output[i][1]
                val width = output[i][2]
                val height = output[i][3]

                val x1 = max(xPos - width / 2f, 0f)
                val x2 = min(xPos + width / 2f, INPUT_SIZE - 1f)
                val y1 = max(yPos - height / 2f, 0f)
                val y2 = min(yPos + height / 2f, INPUT_SIZE - 1f)

                output[i][0] = x1
                output[i][1] = y1
                output[i][2] = x2
                output[i][3] = y2

                results.add(output[i])
            }
        }
        return nms(results)
    }

모델 출력을 받아서, 임계값 (40%)를 설정하고, 사람에 대한 확률값이 40%를 넘기지 못하면 제거하고, 나머지는 xywh -> xyxy형태로 변환한 뒤에 nms처리를 하여 반환한다.

✔ 4. NMS (비 최대 억제)

출력에 대한 NMS (비 최대 억제) 처리는 사람의 바운딩 박스와 확률값으로 처리한다. 사람의 확률이 가장 큰 바운딩 박스에 대해 그 외 바운딩 박스 중에서 겹치는 박스의 비율이 50%를 넘기면 해당 박스를 제거한다. (IoU thresholds = 0.5f)

private fun nms(results: ArrayList<FloatArray>): ArrayList<FloatArray> {
        val list = ArrayList<FloatArray>()
        //results 안에 있는 conf 값 중에서 제일 높은 애를 기준으로 NMS 가 겹치는 애들을 제거
        val pq = PriorityQueue<FloatArray>(5) { o1, o2 ->
            o1[4].compareTo(o2[4])
        }

        pq.addAll(results)

        while (pq.isNotEmpty()) {
            // 큐 안에 속한 최대 확률값을 가진 FloatArray 저장
            val detections = pq.toTypedArray()
            val max = detections[0]
            list.add(max)
            pq.clear()

            // 교집합 비율 확인하고 50% 넘기면 제거
            for (k in 1 until detections.size) {
                val detection = detections[k]
                val rectF = RectF(detection[0], detection[1], detection[2], detection[3])
                val maxRectF = RectF(max[0], max[1], max[2], max[3])
                val iouThreshold = 0.5f
                if (boxIOU(maxRectF, rectF) < iouThreshold) {
                    pq.add(detection)
                }
            }
        }
        return list
    }


    // 겹치는 비율 (교집합/합집합)
    private fun boxIOU(a: RectF, b: RectF): Float {
        return boxIntersection(a, b) / boxUnion(a, b)
    }

    //교집합
    private fun boxIntersection(a: RectF, b: RectF): Float {
        // x1, x2 == 각 rect 객체의 중심 x or y값, w1, w2 == 각 rect 객체의 넓이 or 높이
        val w = overlap(
            (a.left + a.right) / 2f, a.right - a.left,
            (b.left + b.right) / 2f, b.right - b.left
        )

        val h = overlap(
            (a.top + a.bottom) / 2f, a.bottom - a.top,
            (b.top + b.bottom) / 2f, b.bottom - b.top
        )
        return if (w < 0 || h < 0) 0f else w * h
    }

    //합집합
    private fun boxUnion(a: RectF, b: RectF): Float {
        val i = boxIntersection(a, b)
        return (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i
    }

    //서로 겹치는 길이
    private fun overlap(x1: Float, w1: Float, x2: Float, w2: Float): Float {
        val l1 = x1 - w1 / 2
        val l2 = x2 - w2 / 2
        val left = max(l1, l2)
        val r1 = x1 + w1 / 2
        val r2 = x2 + w2 / 2
        val right = min(r1, r2)
        return right - left
    }

이렇게 정의한 코드를 메인 액티비티에서 추가하면 된다.

//이전 글에서 작성한 메인 액티비티의 imageProcess 메서드에 추가 
private fun imageProcess(imageProxy: ImageProxy) {

        val bitmap = dataProcess.imageToBitmap(imageProxy)
        val buffer = dataProcess.bitmapToFloatBuffer(bitmap)
        val inputName = session.inputNames.iterator().next()
        //모델의 요구 입력값 [1 3 640 640] [배치 사이즈, 픽셀(RGB), 너비, 높이], 모델마다 크기는 다를 수 있음.
        val shape = longArrayOf(
            DataProcess.BATCH_SIZE.toLong(),
            DataProcess.PIXEL_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong()
        )
        val inputTensor = OnnxTensor.createTensor(ortEnvironment, buffer, shape)
        val resultTensor = session.run(Collections.singletonMap(inputName, inputTensor))
        val outputs = resultTensor[0].value as Array<*>
        //추가된 부분
        val results = dataProcess.outputsToNPMSPredictions(outputs)
    }

이렇게 완성된 최종 결과를 화면에 적절히 표출하면 된다.

✔ 5. 화면에 표출

  1. 화면에 표출할 View를 상속하는 poseView 객체를 생성한다.
class PoseView(context: Context, attributeSet: AttributeSet) : View(context, attributeSet) {
}
  1. activity_main.xml에 view를 추가한다.
<com.example.yolov8n_pose.PoseView
        android:id="@+id/poseView"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        app:layout_constraintBottom_toTopOf="parent"
        app:layout_constraintEnd_toStartOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

아래 사진과 같이 하면 된다. 이름은 com.example.~~~.PoseView지만,
프로젝트 명에 따라 PoseView 앞의 부분은 알아서 수정하면 된다.

  1. 메인 액티비티 클래스에서 해당 뷰를 추가한다.
private lateinit var poseView: PoseView

override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        previewView = findViewById(R.id.previewView)
        //poseview 추가 
        poseView = findViewById(R.id.poseView)

        //자동꺼짐 해제
        window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)

        //권한 허용
        setPermissions()

        //모델 불러오기
        load()

        //카메라 켜기
        setCamera()
    }
  1. poseView 클래스 내부
class PoseView(context: Context, attributeSet: AttributeSet) : View(context, attributeSet) {

    private var list: ArrayList<FloatArray>? = null
    private val pointPaint = Paint().apply {
        color = Color.RED
        style = Paint.Style.STROKE
        strokeWidth = 10f
    }
    private val linePaint = Paint().apply {
        color = Color.YELLOW
        style = Paint.Style.FILL
        strokeWidth = 5f
    }

    fun setList(list: ArrayList<FloatArray>) {
        this.list = list
    }
    
     @SuppressLint("DrawAllocation")
    override fun onDraw(canvas: Canvas?) {
        super.onDraw(canvas)
    }
}

list == 위에서 nms 처리가 완료된, 화면에 보여줄 배열값을 전달받는 객체이다. setList 메서드에서 해당 list를 받아올 예정이다.
화면에는 각 포인트에 대한 점, 그리고 점끼리 잇는 선이있다. 색상은 중요하지 않고, 나는 점은 빨간색, 선은 노란색으로 표현했다.
아래 onDraw메서드에 함수를 추가하여 점과 선에 대한 내용을 캔버스에 그리면 완성이다.

  1. drawPointsAndLines 메서드 정의
  @SuppressLint("DrawAllocation")
    override fun onDraw(canvas: Canvas?) {
        //점, 선 그리기
        drawPointsAndLines(canvas)
        super.onDraw(canvas)
    }

위에서 적은 onDraw 메서드에 drawPointsAndLines 메서드를 추가한다.

 private fun drawPointsAndLines(canvas: Canvas?) {
        val scaleX = width / DataProcess.INPUT_SIZE.toFloat()
        val scaleY = scaleX * 9f / 16f
        val realY = width * 9f / 16f
        val diffY = realY - height

        val kPointsThreshold = 0.35f
            list?.forEach {
            val points = FloatArray(34)
            for ((a, i) in (points.indices step 2).withIndex()) {
                if (it[i + 7 + a] > kPointsThreshold) {
                    points[i] = it[i + 5 + a] * scaleX
                    points[i + 1] = it[i + 6 + a] * scaleY - (diffY / 2f)
                }
            }
            drawPoint(canvas, points)
            drawLines(canvas, points)
        }
    }

    private fun drawPoint(canvas: Canvas?, points: FloatArray) {
        for (i in points.indices step 2) {
            val xPos = points[i]
            val yPos = points[i + 1]
            if (xPos > 0 && yPos > 0) {
                canvas?.drawPoint(xPos, yPos, pointPaint)
            }
        }
    }

    private fun drawLines(canvas: Canvas?, points: FloatArray) {
        // 점과 점사이에 직선 그리기
        // keypoint 순서
        // 0번 == 코
        // 1번 == 오른쪽 눈
        // 2번 == 왼쪽 눈
        // 3번 == 오른쪽 귀
        // 4번 == 왼쪽 귀
        // 5번 == 오른쪽 어깨
        // 6번 == 왼쪽 어깨
        // 7번 == 오른쪽 팔꿈치
        // 8번 == 왼쪽 팔꿈치
        // 9번 == 오른쪽 손목
        // 10번 == 왼쪽 손목
        // 11번 == 오른쪽 골반
        // 12번 == 왼쪽 골반
        // 13번 == 오른쪽 무릎
        // 14번 == 왼쪽 무릎
        // 15번 == 오른쪽 발
        // 16번 == 왼쪽 발

        // 코, 오른쪽 눈 연결
        var startX = points[0]
        var startY = points[1]
        var stopX = points[2]
        var stopY = points[3]
        drawLine(startX, startY, stopX, stopY, canvas)
        // 코, 왼쪽 눈 연결
        startX = points[0]
        startY = points[1]
        stopX = points[4]
        stopY = points[5]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 눈 귀 연결
        startX = points[2]
        startY = points[3]
        stopX = points[8]
        stopY = points[9]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 눈 귀 연결
        startX = points[4]
        startY = points[5]
        stopX = points[8]
        stopY = points[9]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 귀 어깨 연결
        startX = points[6]
        startY = points[7]
        stopX = points[10]
        stopY = points[11]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 귀 어깨 연결
        startX = points[8]
        startY = points[9]
        stopX = points[12]
        stopY = points[13]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 어깨 팔꿈치 연결
        startX = points[10]
        startY = points[11]
        stopX = points[14]
        stopY = points[15]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 어깨 팔꿈치 연결
        startX = points[12]
        startY = points[13]
        stopX = points[16]
        stopY = points[17]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 어깨 골반 연결
        startX = points[10]
        startY = points[11]
        stopX = points[22]
        stopY = points[23]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 어깨 골반 연결
        startX = points[12]
        startY = points[13]
        stopX = points[24]
        stopY = points[25]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 팔꿈치 손목 연결
        startX = points[14]
        startY = points[15]
        stopX = points[18]
        stopY = points[19]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 팔꿈치 손목 연결
        startX = points[16]
        startY = points[17]
        stopX = points[20]
        stopY = points[21]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 골반 무릎 연결
        startX = points[22]
        startY = points[23]
        stopX = points[26]
        stopY = points[27]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 골반 무릎 연결
        startX = points[24]
        startY = points[25]
        stopX = points[28]
        stopY = points[29]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 무릎 발 연결
        startX = points[26]
        startY = points[27]
        stopX = points[30]
        stopY = points[31]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 무릎 발 연결
        startX = points[28]
        startY = points[29]
        stopX = points[32]
        stopY = points[33]
        drawLine(startX, startY, stopX, stopY, canvas)
        //어깨 좌우 연결
        startX = points[10]
        startY = points[11]
        stopX = points[12]
        stopY = points[13]
        drawLine(startX, startY, stopX, stopY, canvas)
        //골반 좌우 연결
        startX = points[22]
        startY = points[23]
        stopX = points[24]
        stopY = points[25]
        drawLine(startX, startY, stopX, stopY, canvas)
    }

    private fun drawLine(
        startX: Float,
        startY: Float,
        stopX: Float,
        stopY: Float,
        canvas: Canvas?
    ) {
        if (startX > 0 && startY > 0 && stopX > 0 && stopY > 0) {
            canvas?.drawLine(startX, startY, stopX, stopY, linePaint)
        }
    }

drawPointsAndLines 메서드는 아래 역할을 한다.

  1. 화면에 비율에 맞게 점의 위치를 변환시킨다. 모델의 출력에 대한 point 좌표값은 [640 640] 화면에 대한 좌표값이다. 따라서 실제 화면의 가로 세로 비율에 맞게 적절히 좌표값을 변환시킨다.
    가로에 비해 세로는 왜 저렇게 좌표이동을 하는지 이유를 이해하려면, cameraX라이브러리의 FILL_CENTER를 이해해야 한다. 아마 다음 글이 될 것 같다.

  2. 각 point 에 대한 확률값에 제한을 두어서 해당 확률값을 넘기지 못하면 화면에 표출하지 않을 예정이다. 예를 들어, 사람에 대한 확률 값은 80%지만 사람의 코에 대한 확률 값은 5%라면, 코에 대한 좌표값은 화면에 표출하지 않겠다는 코드이다.

drawPoint 메서드는 좌표값이 0,0이 아니라면, 캔버스에 점에 대한 좌표값을 찍는다.
drawLines 메서드는 캔버스에 점끼리 선을 잇는 코드이다.
drawLine 메서드에서 좌표값이 0,0이 아닌 경우에만 연결하는 코드이다.

0,0인 경우에 그림을 그리지 않는 이유

val points = FloatArray(34)

이런식으로 선언하고, keypoints의 확률값에 대해 일정 값을 넘기지 못한다면, points 배열에 keypoints의 좌표값은 담기지 못한다. 그러면 디폴트값인 0이 담기게 되는데, 이 경우 화면에 디폴트 값인 0,0을 담지 않기 위해 좌표 값이 0,0인 경우를 무시하는 코드를 추가한 것이다.

  1. 메인 액티비티에서 화면을 불러오는 코드를 추가한다.
 private fun imageProcess(imageProxy: ImageProxy) {

        val bitmap = dataProcess.imageToBitmap(imageProxy)
        val buffer = dataProcess.bitmapToFloatBuffer(bitmap)
        val inputName = session.inputNames.iterator().next()
        //모델의 요구 입력값 [1 3 640 640] [배치 사이즈, 픽셀(RGB), 너비, 높이], 모델마다 크기는 다를 수 있음.
        val shape = longArrayOf(
            DataProcess.BATCH_SIZE.toLong(),
            DataProcess.PIXEL_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong()
        )
        val inputTensor = OnnxTensor.createTensor(ortEnvironment, buffer, shape)
        val resultTensor = session.run(Collections.singletonMap(inputName, inputTensor))
        val outputs = resultTensor[0].value as Array<*>
        val results = dataProcess.outputsToNPMSPredictions(outputs)
		//추가한 부분 
        poseView.setList(results)
        poseView.invalidate()
    }

출력 결과는 아래와 같다.

실시간으로 하다보니 손이 흔들려서 조금씩 변하기는 한다...
아래는 assets에 사진을 넣고 사진에 대한 추론을 했을 경우에 대한 사진이다.

object detection과 속도 차이도 별로 안나서, 실시간으로 해도 무방할 듯 하다.

아래는 전체 코드이다.

//메인 액티비티
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.content.pm.PackageManager
import android.os.Bundle
import android.view.WindowManager
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import androidx.camera.core.*
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView
import androidx.core.app.ActivityCompat
import java.util.*
import java.util.concurrent.Executors

class MainActivity : AppCompatActivity() {
    private lateinit var previewView: PreviewView
    private lateinit var poseView: PoseView
    private lateinit var ortEnvironment: OrtEnvironment
    private lateinit var session: OrtSession

    private val dataProcess = DataProcess(context = this)

    companion object {
        const val PERMISSION = 1
    }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        previewView = findViewById(R.id.previewView)
        poseView = findViewById(R.id.poseView)

        //자동꺼짐 해제
        window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)

        //권한 허용
        setPermissions()

        //모델 불러오기
        load()

        //카메라 켜기
        setCamera()
    }

    private fun setCamera() {
        val processCameraProvider =
            ProcessCameraProvider.getInstance(this).get()                        // 카메라 제공 객체

        previewView.scaleType =
            PreviewView.ScaleType.FILL_CENTER                                           // 전체화면

        val cameraSelector =
            CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK) // 후면 카메라
                .build()

        val preview =
            Preview.Builder().setTargetAspectRatio(AspectRatio.RATIO_16_9).build()      // 16:9 화면

        preview.setSurfaceProvider(previewView.surfaceProvider)

        val analysis = ImageAnalysis.Builder().setTargetAspectRatio(AspectRatio.RATIO_16_9)
            .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
            .build()   // 최신 화면으로 분석, 비율도 동일

        analysis.setAnalyzer(Executors.newSingleThreadExecutor()) {
            imageProcess(it)
            it.close()
        }
        // 생명주기 메인 액티비티에 귀속
        processCameraProvider.bindToLifecycle(this, cameraSelector, preview, analysis)

    }

    private fun imageProcess(imageProxy: ImageProxy) {

        val bitmap = dataProcess.imageToBitmap(imageProxy)
        val buffer = dataProcess.bitmapToFloatBuffer(bitmap)
        val inputName = session.inputNames.iterator().next()
        //모델의 요구 입력값 [1 3 640 640] [배치 사이즈, 픽셀(RGB), 너비, 높이], 모델마다 크기는 다를 수 있음.
        val shape = longArrayOf(
            DataProcess.BATCH_SIZE.toLong(),
            DataProcess.PIXEL_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong(),
            DataProcess.INPUT_SIZE.toLong()
        )
        val inputTensor = OnnxTensor.createTensor(ortEnvironment, buffer, shape)
        val resultTensor = session.run(Collections.singletonMap(inputName, inputTensor))
        val outputs = resultTensor[0].value as Array<*>
        val results = dataProcess.outputsToNPMSPredictions(outputs)

        poseView.setList(results)
        poseView.invalidate()
    }

    private fun load() {
        dataProcess.loadPoseModel()

        // 추론을 위한 객체 생성
        ortEnvironment = OrtEnvironment.getEnvironment()
        session =
            ortEnvironment.createSession(
                this.filesDir.absolutePath.toString() + "/" + DataProcess.FILE_NAME,
                OrtSession.SessionOptions()
            )
    }

    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<out String>,
        grantResults: IntArray
    ) {
        if (requestCode == PERMISSION) {
            grantResults.forEach {
                if (it != PackageManager.PERMISSION_GRANTED) {
                    Toast.makeText(this, "권한을 허용하지 않으면 사용할 수 없습니다!", Toast.LENGTH_SHORT).show()
                    finish()
                }
            }
        }
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
    }

    private fun setPermissions() {
        val permissions = ArrayList<String>()
        permissions.add(android.Manifest.permission.CAMERA)

        permissions.forEach {
            if (ActivityCompat.checkSelfPermission(this, it) != PackageManager.PERMISSION_GRANTED) {
                ActivityCompat.requestPermissions(this, permissions.toTypedArray(), PERMISSION)
            }
        }
    }
}
//데이터 처리 클래스 
import android.content.Context
import android.graphics.Bitmap
import android.graphics.RectF
import androidx.camera.core.ImageProxy
import java.io.File
import java.io.FileOutputStream
import java.nio.FloatBuffer
import java.util.*
import kotlin.collections.ArrayList
import kotlin.math.max
import kotlin.math.min

class DataProcess(val context: Context) {

    companion object {
        const val BATCH_SIZE = 1
        const val INPUT_SIZE = 640
        const val PIXEL_SIZE = 3
        const val FILE_NAME = "yolov8n-pose.onnx"
    }

    fun imageToBitmap(imageProxy: ImageProxy): Bitmap {
        val bitmap = imageProxy.toBitmap()
        return Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true)
    }

    fun bitmapToFloatBuffer(bitmap: Bitmap): FloatBuffer {
        val imageSTD = 255f
        val buffer = FloatBuffer.allocate(BATCH_SIZE * PIXEL_SIZE * INPUT_SIZE * INPUT_SIZE)
        buffer.rewind()

        val area = INPUT_SIZE * INPUT_SIZE
        val bitmapData = IntArray(area)
        bitmap.getPixels(
            bitmapData,
            0,
            bitmap.width,
            0,
            0,
            bitmap.width,
            bitmap.height
        ) //배열에 RGB 담기

        //하나씩 받아서 버퍼에 할당
        for (i in 0 until INPUT_SIZE - 1) {
            for (j in 0 until INPUT_SIZE - 1) {
                val idx = INPUT_SIZE * i + j
                val pixelValue = bitmapData[idx]
                // 위에서 부터 차례대로 R 값 추출, G 값 추출, B값 추출 -> 255로 나누어서 0~1 사이로 정규화
                buffer.put(idx, ((pixelValue shr 16 and 0xff) / imageSTD))
                buffer.put(idx + area, ((pixelValue shr 8 and 0xff) / imageSTD))
                buffer.put(idx + area * 2, ((pixelValue and 0xff) / imageSTD))
                //원리 bitmap == ARGB 형태의 32bit, R값의 시작은 16bit (16 ~ 23bit 가 R영역), 따라서 16bit 를 쉬프트
                //그럼 A값이 사라진 RGB 값인 24bit 가 남는다. 이후 255와 AND 연산을 통해 맨 뒤 8bit 인 R값만 가져오고, 255로 나누어 정규화를 한다.
                //다시 8bit 를 쉬프트 하여 R값을 제거한 G,B 값만 남은 곳에 다시 AND 연산, 255 정규화, 다시 반복해서 RGB 값을 buffer 에 담는다.
            }
        }
        buffer.rewind()
        return buffer
    }

    fun loadPoseModel() {
        //onnx 파일 불러오기
        val assetManager = context.assets
        val outputFile = File(context.filesDir.toString() + "/" + FILE_NAME)

        assetManager.open(FILE_NAME).use { inputStream ->
            FileOutputStream(outputFile).use { outputStream ->
                val buffer = ByteArray(4 * 1024)
                var read: Int
                while (inputStream.read(buffer).also { read = it } != -1) {
                    outputStream.write(buffer, 0, read)
                }
            }
        }
    }

    fun outputsToNPMSPredictions(outputs: Array<*>): ArrayList<FloatArray> {
        val confidenceThreshold = 0.4f
        val rows: Int
        val cols: Int
        val results = ArrayList<FloatArray>()

        (outputs[0] as Array<*>).also {
            rows = it.size
            cols = (it[0] as FloatArray).size
        }

        //배열 형태를 [56 8400] -> [8400 56] 으로 변환
        val output = Array(cols) { FloatArray(rows) }
        for (i in 0 until rows) {
            for (j in 0 until cols) {
                output[j][i] = (((outputs[0] as Array<*>)[i]) as FloatArray)[j]
            }
        }


        for (i in 0 until cols) {
            // 바운딩 박스의 특정 확률을 넘긴 경우에만 xywh -> xy xy 형태로 변환 후 nms 처리
            if (output[i][4] > confidenceThreshold) {
                val xPos = output[i][0]
                val yPos = output[i][1]
                val width = output[i][2]
                val height = output[i][3]

                val x1 = max(xPos - width / 2f, 0f)
                val x2 = min(xPos + width / 2f, INPUT_SIZE - 1f)
                val y1 = max(yPos - height / 2f, 0f)
                val y2 = min(yPos + height / 2f, INPUT_SIZE - 1f)

                output[i][0] = x1
                output[i][1] = y1
                output[i][2] = x2
                output[i][3] = y2

                results.add(output[i])
            }
        }
        return nms(results)
    }

    private fun nms(results: ArrayList<FloatArray>): ArrayList<FloatArray> {
        val list = ArrayList<FloatArray>()
        //results 안에 있는 conf 값 중에서 제일 높은 애를 기준으로 NMS 가 겹치는 애들을 제거
        val pq = PriorityQueue<FloatArray>(5) { o1, o2 ->
            o1[4].compareTo(o2[4])
        }

        pq.addAll(results)

        while (pq.isNotEmpty()) {
            // 큐 안에 속한 최대 확률값을 가진 FloatArray 저장
            val detections = pq.toTypedArray()
            val max = detections[0]
            list.add(max)
            pq.clear()

            // 교집합 비율 확인하고 50% 넘기면 제거
            for (k in 1 until detections.size) {
                val detection = detections[k]
                val rectF = RectF(detection[0], detection[1], detection[2], detection[3])
                val maxRectF = RectF(max[0], max[1], max[2], max[3])
                val iouThreshold = 0.5f
                if (boxIOU(maxRectF, rectF) < iouThreshold) {
                    pq.add(detection)
                }
            }
        }
        return list
    }


    // 겹치는 비율 (교집합/합집합)
    private fun boxIOU(a: RectF, b: RectF): Float {
        return boxIntersection(a, b) / boxUnion(a, b)
    }

    //교집합
    private fun boxIntersection(a: RectF, b: RectF): Float {
        // x1, x2 == 각 rect 객체의 중심 x or y값, w1, w2 == 각 rect 객체의 넓이 or 높이
        val w = overlap(
            (a.left + a.right) / 2f, a.right - a.left,
            (b.left + b.right) / 2f, b.right - b.left
        )

        val h = overlap(
            (a.top + a.bottom) / 2f, a.bottom - a.top,
            (b.top + b.bottom) / 2f, b.bottom - b.top
        )
        return if (w < 0 || h < 0) 0f else w * h
    }

    //합집합
    private fun boxUnion(a: RectF, b: RectF): Float {
        val i = boxIntersection(a, b)
        return (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i
    }

    //서로 겹치는 길이
    private fun overlap(x1: Float, w1: Float, x2: Float, w2: Float): Float {
        val l1 = x1 - w1 / 2
        val l2 = x2 - w2 / 2
        val left = max(l1, l2)
        val r1 = x1 + w1 / 2
        val r2 = x2 + w2 / 2
        val right = min(r1, r2)
        return right - left
    }
}
//pose 포인트가 보일 pose View
import android.annotation.SuppressLint
import android.content.Context
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.util.AttributeSet
import android.view.View

class PoseView(context: Context, attributeSet: AttributeSet) : View(context, attributeSet) {

    private var list: ArrayList<FloatArray>? = null
    private val pointPaint = Paint().apply {
        color = Color.RED
        style = Paint.Style.STROKE
        strokeWidth = 10f
    }
    private val linePaint = Paint().apply {
        color = Color.YELLOW
        style = Paint.Style.FILL
        strokeWidth = 5f
    }

    fun setList(list: ArrayList<FloatArray>) {
        this.list = list
    }

    @SuppressLint("DrawAllocation")
    override fun onDraw(canvas: Canvas?) {
        //점, 선 그리기
        drawPointsAndLines(canvas)
        super.onDraw(canvas)
    }

    private fun drawPointsAndLines(canvas: Canvas?) {
        val scaleX = width / DataProcess.INPUT_SIZE.toFloat()
        val scaleY = scaleX * 9f / 16f
        val realY = width * 9f / 16f
        val diffY = realY - height

        val kPointsThreshold = 0.35f
        list?.forEach {
            val points = FloatArray(34)
            for ((a, i) in (points.indices step 2).withIndex()) {
                if (it[i + 7 + a] > kPointsThreshold) {
                    points[i] = it[i + 5 + a] * scaleX
                    points[i + 1] = it[i + 6 + a] * scaleY - (diffY / 2f)
                }
            }
            drawPoint(canvas, points)
            drawLines(canvas, points)
        }
    }

    private fun drawPoint(canvas: Canvas?, points: FloatArray) {
        for (i in points.indices step 2) {
            val xPos = points[i]
            val yPos = points[i + 1]
            if (xPos > 0 && yPos > 0) {
                canvas?.drawPoint(xPos, yPos, pointPaint)
            }
        }
    }

    private fun drawLines(canvas: Canvas?, points: FloatArray) {
        // 점과 점사이에 직선 그리기
        // keypoint 순서
        // 0번 == 코
        // 1번 == 오른쪽 눈
        // 2번 == 왼쪽 눈
        // 3번 == 오른쪽 귀
        // 4번 == 왼쪽 귀
        // 5번 == 오른쪽 어깨
        // 6번 == 왼쪽 어깨
        // 7번 == 오른쪽 팔꿈치
        // 8번 == 왼쪽 팔꿈치
        // 9번 == 오른쪽 손목
        // 10번 == 왼쪽 손목
        // 11번 == 오른쪽 골반
        // 12번 == 왼쪽 골반
        // 13번 == 오른쪽 무릎
        // 14번 == 왼쪽 무릎
        // 15번 == 오른쪽 발
        // 16번 == 왼쪽 발

        // 코, 오른쪽 눈 연결
        var startX = points[0]
        var startY = points[1]
        var stopX = points[2]
        var stopY = points[3]
        drawLine(startX, startY, stopX, stopY, canvas)
        // 코, 왼쪽 눈 연결
        startX = points[0]
        startY = points[1]
        stopX = points[4]
        stopY = points[5]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 눈 귀 연결
        startX = points[2]
        startY = points[3]
        stopX = points[8]
        stopY = points[9]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 눈 귀 연결
        startX = points[4]
        startY = points[5]
        stopX = points[8]
        stopY = points[9]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 귀 어깨 연결
        startX = points[6]
        startY = points[7]
        stopX = points[10]
        stopY = points[11]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 귀 어깨 연결
        startX = points[8]
        startY = points[9]
        stopX = points[12]
        stopY = points[13]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 어깨 팔꿈치 연결
        startX = points[10]
        startY = points[11]
        stopX = points[14]
        stopY = points[15]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 어깨 팔꿈치 연결
        startX = points[12]
        startY = points[13]
        stopX = points[16]
        stopY = points[17]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 어깨 골반 연결
        startX = points[10]
        startY = points[11]
        stopX = points[22]
        stopY = points[23]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 어깨 골반 연결
        startX = points[12]
        startY = points[13]
        stopX = points[24]
        stopY = points[25]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 팔꿈치 손목 연결
        startX = points[14]
        startY = points[15]
        stopX = points[18]
        stopY = points[19]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 팔꿈치 손목 연결
        startX = points[16]
        startY = points[17]
        stopX = points[20]
        stopY = points[21]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 골반 무릎 연결
        startX = points[22]
        startY = points[23]
        stopX = points[26]
        stopY = points[27]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 골반 무릎 연결
        startX = points[24]
        startY = points[25]
        stopX = points[28]
        stopY = points[29]
        drawLine(startX, startY, stopX, stopY, canvas)
        //오른쪽 무릎 발 연결
        startX = points[26]
        startY = points[27]
        stopX = points[30]
        stopY = points[31]
        drawLine(startX, startY, stopX, stopY, canvas)
        //왼쪽 무릎 발 연결
        startX = points[28]
        startY = points[29]
        stopX = points[32]
        stopY = points[33]
        drawLine(startX, startY, stopX, stopY, canvas)
        //어깨 좌우 연결
        startX = points[10]
        startY = points[11]
        stopX = points[12]
        stopY = points[13]
        drawLine(startX, startY, stopX, stopY, canvas)
        //골반 좌우 연결
        startX = points[22]
        startY = points[23]
        stopX = points[24]
        stopY = points[25]
        drawLine(startX, startY, stopX, stopY, canvas)
    }

    private fun drawLine(
        startX: Float,
        startY: Float,
        stopX: Float,
        stopY: Float,
        canvas: Canvas?
    ) {
        if (startX > 0 && startY > 0 && stopX > 0 && stopY > 0) {
            canvas?.drawLine(startX, startY, stopX, stopY, linePaint)
        }
    }

}

아래는 파일의 일부이다. 3개의 클래스를 사용했고, assets안에는 모델파일이 들어있다.

아래는 깃허브 주소다.
https://github.com/Yurve/YOLOv8_Pose_android

지금까지 글을 정리하자면 아래와 같다.
1편: 모델의 onnx변환, 카메라 및 화면 설정, 입력 사진에 대한 적절한 처리
2편: 출력 배열에 대해 적절한 처리 및 화면에 표출

4개의 댓글

comment-user-thumbnail
2023년 4월 10일

오! aloe님의 글이 도움이 많이 되었습니다.
괜찮으시면 후의 내용들도 계속 올려주시면 감사하겠습니다.

1개의 답글
comment-user-thumbnail
2024년 8월 23일

큰 도움 되었습니다 정말 감사합니다.

1개의 답글