[streamlit] Mask Classification Model로 prototype 만들기

MinI0123·2023년 4월 26일
0
post-thumbnail

Streamlit Library

Streamlit

Streamlit은 파이썬 open source app framework이다. Data science와 machine learning 분야에서 웹을 쉽고 빠르게 만들고 배포까지 할 수 있게 해주는 library라고 할 수 있다. 다양한 위젯이 구현되어 있어 필요에 따라 가져다 사용하면 된다.

사용법

  1. 설치
conda install -c conda-forge streamlit
  1. 코드 작성
    streamlit api를 사용하여 원하는 웹 화면을 구성한다.
# app.py
import streamlit as st

def main():
    st.title("Mask Classification Model")

    st.write("hello!")

main()
  1. 실행
streamlit run app.py

터미널에서 위와 같이 app.py를 실행시키면 웹 브라우저에서 작성한 웹 페이지를 확인할 수 있다.

동작 원리


공식 문서에 따르면 Streamlit은 화면 업데이트가 필요할 때마다 python script 전체를 다시 실행하는 방식을 사용한다고 한다. 화면 업데이트가 필요한 경우는 2가지로 다음과 같다.

  1. app의 소스코드가 변경된 경우

  2. 사용자가 app의 위젯과 상호작용한 경우

매번 python script를 다시 실행하기 때문에 변수에 값을 저장해놔도 초기화 되는 문제가 발생한다. 예를 들어 다음과 같이 누르면 값을 증가/감소시키는 버튼을 만든다고 해보자.

# app.py 
import streamlit as st

def main():
    num = 0
    st.title("button test")

    up_button = st.button("up")
    if up_button:
        num += 1
    down_button = st.button("down")
    if down_button:
        num -= 1

    st.write(num)

main()

버튼을 누르기 전에는 값이 0이다. up버튼을 누르면 값이 1 증가하고 down 버튼을 누르면 값이 1감소하기를 원한다. 그러나 버튼을 누르면 매번 값이 0으로 초기화되서 실행되므로 값은 1과 -1만 나오게 된다. 이와 같은 문제를 해결하기 위해서는 Session state를 사용한다.

Session state는 매 실행에서 공유되는 일종의 전역변수라고 생각하면 된다. 다음과 같이 사용할 수 있다.

# session_state
st.session_state.num = 1
st.write(st.session_state.num)

버튼을 누르면 값을 변경하는 위의 예제를 session state를 사용한 방식으로 변경해보자. streamlit의 widget은 callback 함수를 지정할 수 있다. callback함수를 지정한다고 해서 상호작용시 callback 함수만 호출되는 것은 아니고 callback 함수 호출 뒤에 script 전체가 다시 실행된다. 따라서 callback 함수에서는 session state를 변경해주는 것이 좋다. Callback과 session state를 사용하여 수정한 코드는 다음과 같다.

# app.py

import streamlit as st

def up_button_callback():
    st.session_state.num += 1

def down_button_callback():
    st.session_state.num -= 1

def main():
    if 'num' not in st.session_state:
        st.session_state.num = 0

    st.title("button test")

    up_button = st.button("up", on_click=up_button_callback)
    down_button = st.button("down", on_click=down_button_callback)

    st.write(st.session_state.num)

main()

Mask Classification Prototype

Streamlit을 사용하여 Mask classification 모델을 서빙하는 간단한 웹을 만들어 보자.

Code

Result

원하는 사진을 업로드하여 Mask Classification 결과를 확인 할 수 있다. Mask Classification Competition에 최종 제출한 모델은 6개의 모델을 ensemeble 해야 하는데 inference 시간을 고려하여 그냥 단일 모델로 inference를 하였다.

업로드 하는 사진의 사이즈도 모두 다르기 때문에 이미지의 특정 위치를 crop하는 MyCrop(Mask Classification 최종 제출 참고)보다는 center crop이 더 좋은 성능을 보여 center crop을 대신 사용하였다.

0개의 댓글