mask rcnn 하이퍼 파리미터

우수민·2021년 3월 29일
0
post-custom-banner

Mask R-CNN의 경우 config파일에 있듯 조절할 수 있는 파라미터가 많이 존재한다. 그 중에서도 아래에 있는 내용들을 시도해볼만하다고 생각한다. (코드부분은 Mask R-CNN에 주석 처리 된 내용)

1. Back Bone

: 사용할 수 있는 옵션으로는 ResNet50, ResNet101, and ResNext 101이 있는데 ResNet50과 ResNet101, ResNext 101은 당연하게도 정확도와 속도에 trade-off 관계가 있어서 필요에 따라 선택하면 될 것같다. 또한 pre-trained weights(coco나 imagenet)이 없고 기본 매개변수(learning rate나 epoch)가 잘 조정된 경우 더 정확한 경향이 있다고 한다.

    # Backbone network architecture
    # Supported values are: resnet50, resnet101.
    # You can also provide a callable that should have the signature
    # of model.resnet_graph. If you do so, you need to supply a callable
    # to COMPUTE_BACKBONE_SHAPE as well
    BACKBONE = "resnet101"

2. Train_ROIs_Per_Image

: 이는 rois의 최대수를 의미한다. 초반에는 인스턴스의 수를 정확히 알 수 없어 기본값으로 시작하며, 추후엔 인스턴스수를 조절한다면 훈련 시간을 줄일 수 있다.

    # Number of ROIs per image to feed to classifier/mask heads
    # The Mask RCNN paper uses 512 but often the RPN doesn't generate
    # enough positive proposals to fill this and keep a positive:negative
    # ratio of 1:3. You can increase the number of proposals by adjusting
    # the RPN NMS threshold.
    TRAIN_ROIS_PER_IMAGE = 200

3. Max_GT_Instances

: 이는 하나의 이미지에서 발견 될 수 있는 인스턴스의 최대값이다. 만약에 custom data에 활용할 경우 많은 object가 존재하지 않는다면 이미지당 최대 인스턴스를 줄일 수 있다.

    # Maximum number of ground truth instances to use in one image
    MAX_GT_INSTANCES = 100

4. Detection_Min_Confidence

: 이는 인스턴스를 분류할때 해당 정확도 이상으로만 Detection한다는 것을 의미한다. 모든것이 발견이는 것이 중요하고 false postive일 경우엔 threshold를 줄이는 것이 좋다. 반대로 정확성이 중요한 경우엔 threshold를 높여 높은 신뢰도로 높여 예측을 하는 것이 좋다.

    # Minimum probability value to accept a detected instance
    # ROIs below this threshold are skipped
    DETECTION_MIN_CONFIDENCE = 0.7

5. Image_Min_Dim and Image_Max_Dim :

단순히 이미지의 크기를 셋팅하는 부분이다. 디폴트는 10241024이지만 512512로 하는 경우 메모리 사용을 줄이고 훈련 속도를 향상 시킬수 있다. 이상적인 접근 방식은 빠른 가중치 업데이를 위해 작은 이미지 크기에서 초기 모델을 훈련시키고 최종 단계에서 큰 사이즈를 사용하는 것이다.

    # Input image resizing
    # Generally, use the "square" resizing mode for training and predicting
    # and it should work well in most cases. In this mode, images are scaled
    # up such that the small side is = IMAGE_MIN_DIM, but ensuring that the
    # scaling doesn't make the long side > IMAGE_MAX_DIM. Then the image is
    # padded with zeros to make it a square so multiple images can be put
    # in one batch.
    # Available resizing modes:
    # none:   No resizing or padding. Return the image unchanged.
    # square: Resize and pad with zeros to get a square image
    #         of size [max_dim, max_dim].
    # pad64:  Pads width and height with zeros to make them multiples of 64.
    #         If IMAGE_MIN_DIM or IMAGE_MIN_SCALE are not None, then it scales
    #         up before padding. IMAGE_MAX_DIM is ignored in this mode.
    #         The multiple of 64 is needed to ensure smooth scaling of feature
    #         maps up and down the 6 levels of the FPN pyramid (2**6=64).
    # crop:   Picks random crops from the image. First, scales the image based
    #         on IMAGE_MIN_DIM and IMAGE_MIN_SCALE, then picks a random crop of
    #         size IMAGE_MIN_DIM x IMAGE_MIN_DIM. Can be used in training only.
    #         IMAGE_MAX_DIM is not used in this mode.
        IMAGE_RESIZE_MODE = "square"
        IMAGE_MIN_DIM = 800
        IMAGE_MAX_DIM = 1024

6. Loss Weights

    # Loss weights for more precise optimization.
    # Can be used for R-CNN training setup.
    LOSS_WEIGHTS = {
        "rpn_class_loss": 1.,
        "rpn_bbox_loss": 1.,
        "mrcnn_class_loss": 1.,
        "mrcnn_bbox_loss": 1.,
        "mrcnn_mask_loss": 1.
    }

6-1. rpn_class_loss

: 이는 모델이 각 단계에 할당해야 하는 가중치에 해당한다.

6-2. Rpn_class_loss

: 이는 anchor boxes의 부적절한 분류(객체의 유무) 를 RPN에 의한 loss에 해당한다. 최종 출력에서 모델이 여러 객체를 탐지 못할 경우 이를 올려야 한다. 이 값이 증가할 때 RPN이 탐지할 수 있다.

6-3. rpn_bbox_loss

: 이는 localization 정확도를 의미한다. 물체는 감지하지만 경계상자가 부정확할 경우 조정하는 가중치이다.

6-4. mrcnn_class_loss

: 이는 이미지에서 물체가 감지되었지만 잘못 분류 된 경우 증가한다.

6-5. mrcnn_bbox_loss

: 이는 객체의 정확한 분류가 이루어지면 증가하지만 lcalization이 정확하지 않다.

6-6. mrcnn_mask_loss

: 이는 생성된 마스크를 의미하며 픽셀 수준에서 식별이 중요하다면 가중치를 증가 시켜야 한다.

참고 링크 : https://medium.com/analytics-vidhya/taming-the-hyper-parameters-of-mask-rcnn-3742cb3f0e1b

profile
데이터 분석하고 있습니다
post-custom-banner

0개의 댓글