[Python] Streamlit 기본 문법 (text, data, cache, session)

tu11p·2024년 8월 8일
0

Python

목록 보기
1/1
post-thumbnail

Streamlit?

데이터 과학자와 AI/ML 엔지니어가 단 몇 줄의 코드만으로 동적 데이터 앱을 제공할 수 있는 오픈 소스 Python 프레임워크

  • 복잡한 프론트 코딩 없이 간단/간결하게 페이지 만들어 AI 모델을 서빙할 수 있음
  • 세밀하고 복잡한 UI/UX/디자인 서비스 구현 어려움
  • 사용자 많을 경우 대응 어려움

0. 데이터 출력 write

a = 10
b = 3
st.write('a*b-a = ', a*b-a)

1. Text 다루기

텍스트 Text

st.text('This is text')
st.markdown("This is Markdown text")

마크다운 markdown

st.markdown("This is Markdown text")
#볼드
st.markdown("**This is BOLD Markdown text**")
#이탤릭
st.markdown("*This is ITALIC Markdown text*")
st.markdown("_This is ITALIC Markdown text_")
#볼드&이탤릭
st.markdown("**_This is BOLD & ITALIC Markdown text_**")

페이지 제목 title

st.title('This is title')
st.markdown("# This is Markdown title")

헤더 header

st.header('This is header')
st.markdown("## This is Markdown header")

sub헤더 subheader

st.subheader('This is subheader')
st.markdown("### This is Markdown subheader")

캡션 caption

st.caption('This is caption')
## markdown 사용 불가 - 이미지, 테이블, 차트 캡션 용도로 사용되므로

LaTeX 수식 latex

st.latex(r'''\sqrt[n]{x}''')

코드 스니펫 code

text='''print('hello world!')'''
st.code(text)

글머리 기호(ul, li) by markdown

st.markdown('- 1st \n'
			'  - 2nd \n' 		#공백 2칸
			'    - 3rd \n') #공백 4칸

숫자 리스트(ol, li) by markdown

st.markdown('- 1st \n'
			'   - 1st \n' 		#공백 3칸
			'      - 1st \n') #공백 6칸

2. 다양한 Data + Media 다루기 (Dataframe, 이미지, 동영상 … )

Dataframe

data = {
    'Name': ['Alice', 'Bob', 'Charlie', 'David', 'Eva'],
    'Age': [25, 30, 35, 40, 45],
    'City': ['New York', 'Los Angeles', 'Chicago', 'Houston', 'Phoenix']
}

df = pd.DataFrame(data)

'''
# static 방법
st.write("DataFrame using st.write:")
st.write(df)
'''

# 동적 방법
st.write("DataFrame using st.dataframe:")
st.dataframe(df)

오디오 audio

st.audio('audio.mp3')

동영상 video

st.video('video.mp4')

이미지 image

st.image('경로', caption='캡션')

표 table

data = {
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [28, 34, 22],
    'Job': ['Engineer', 'Doctor', 'Artist']
}
table = pd.DataFrame(data)

st.table(table)
  • st.table은 static table을 표시하는 가장 기본적인 방식이라고 한다. 대부분의 동적 테이블은 st.dataframe으로 구현하며, 사용자가 편집하는 dataframe은 st.data_editor을 사용한다.

3. Input 위젯 다루기

버튼 button

st.button("Reset button", type="primary")
if st.button("switch"):
    st.write("change text")
else:
    st.write("reset")

다운로드 버튼 download_button

#dataframe을 csv로 변환하여 다운로드

@st.cache_data     #convert_df 결과를 캐싱하여 나중에도 사용
def convert_df(df):
    return df.to_csv().encode("utf-8")

csv = convert_df(my_large_df)

st.**download_button**(
    label="Download data as CSV",
    data=csv,
    file_name="large_df.csv",
    mime="text/csv",
)
st.page_link("app.py", label="Home", icon="🏠")

입력 폼 form

체크박스 checkbox

토글 toggle

라디오 버튼 radio

드롭다운 selectbox

텍스트 입력받기 text_input

양식 제출 버튼 form_submit_button

with st.form(key='form 식별 값'):
    st.write("모든 input field를 채우세요")

    checkbox_val = st.checkbox('checkbox 입니다')
    toggle_val = st.toggle('toggle 입니다')
    radio_val = st.radio('radio 입니다:', ['Option 1', 'Option 2', 'Option 3'])
    selectbox_val = st.selectbox('selectbox 입니다', ['Red', 'Green', 'Blue'])
    text_input_val = st.text_input('text input 입니다')

    # Submit button
    submit_button = st.form_submit_button(label='제출')

#input field 검증
if submit_button:
    if checkbox_val and toggle_val and radio_val and selectbox_val and text_input_val:
        st.success('제출되었습니다')
    else:
        st.error('모든 field를 채우세요')

★☆채팅 입력 받기☆★ chat_input, chat_massage

prompt = st.chat_input("메시지를 입력하세요.") #placeholder
if prompt:
    with st.chat_message("user"): #사용자 메시지 컨테이너
        st.write(prompt)
    with st.chat_message("ai", avatar="🤖"): #인공지능 메시지 컨테이너
        st.write("이것은 인공지능 응답입니다.")
    • 채팅 관련 3rd-party component

      pip install streamlit-chat 
      
      from streamlit_chat import message
      
      message("My message") 
      message("Hello bot!", is_user=True)  # align's the message to the right
      
      input_text = st.text_input("You: ","Hello, how are you?", key="input")

4. 캐싱, 세션 관리

데이터 캐싱 cache_data

  • 데이터 한 번 로딩되면 그 데이터는 캐싱해서 빠르게 로드 가능
@st.cache_data
def load_data():
    time.sleep(5)  # 5초 딜레이
    data = pd.DataFrame({
        'col1': range(1000),
        'col2': range(1000, 2000)
    })
    return data

data = load_data()

st.write(data)

리소스 캐싱 cache_resource

  • AI 모델, DB 등의 리소스(반환 결과 등)를 효율적으로 사용하여 응답 시간 단축, 성능 최적화
#사이킷런 캘리포니아 집 값 예측 모델 학습 예제
def load_data():
    california = fetch_california_housing()
    X = pd.DataFrame(california.data, columns=california.feature_names)
    y = pd.Series(california.target, name='target')
    return X, y

@st.cache_resource
def train_model(X, y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    return model, mse

def main():
    st.title("California Housing Price")

    X, y = load_data()

    model, mse = train_model(X, y)
    st.write(f"Train MSE: {mse:.2f}")

    st.header("Input Features")
    MedInc = st.number_input("MedInc", float(X['MedInc'].min()), float(X['MedInc'].max()), float(X['MedInc'].mean()))
    HouseAge = st.number_input("HouseAge", float(X['HouseAge'].min()), float(X['HouseAge'].max()), float(X['HouseAge'].mean()))
    AveRooms = st.number_input("AveRooms", float(X['AveRooms'].min()), float(X['AveRooms'].max()), float(X['AveRooms'].mean()))
    AveBedrms = st.number_input("AveBedrms", float(X['AveBedrms'].min()), float(X['AveBedrms'].max()), float(X['AveBedrms'].mean()))
    Population = st.number_input("Population", float(X['Population'].min()), float(X['Population'].max()), float(X['Population'].mean()))
    AveOccup = st.number_input("AveOccup", float(X['AveOccup'].min()), float(X['AveOccup'].max()), float(X['AveOccup'].mean()))
    Latitude = st.number_input("Latitude", float(X['Latitude'].min()), float(X['Latitude'].max()), float(X['Latitude'].mean()))
    Longitude = st.number_input("Longitude", float(X['Longitude'].min()), float(X['Longitude'].max()), float(X['Longitude'].mean()))

    input_data = pd.DataFrame({
        'MedInc': [MedInc],
        'HouseAge': [HouseAge],
        'AveRooms': [AveRooms],
        'AveBedrms': [AveBedrms],
        'Population': [Population],
        'AveOccup': [AveOccup],
        'Latitude': [Latitude],
        'Longitude': [Longitude]
    })

    st.write("Input Features")
    st.write(input_data)

    if st.button("Predict"):
        prediction = model.predict(input_data)
        st.write(f"Predicted House Price: ${prediction[0]*100000:.2f}")

if __name__ == "__main__":
    main()

세션 관리 session_state

  • 사용자 입력이나 데이터가 새로고침 후에도 유지되도록 세션 관리
def load_data():
    california = fetch_california_housing()
    X = pd.DataFrame(california.data, columns=california.feature_names)
    y = pd.Series(california.target, name='target')
    return X, y

@st.cache_resource
def train_model(X, y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    return model, mse

def main():
    # 세션 초기화
    if 'page_refresh_count' not in st.session_state:
        st.session_state.page_refresh_count = 0
    if 'predict_attempt_count' not in st.session_state:
        st.session_state.predict_attempt_count = 0
    
    st.session_state.page_refresh_count += 1

    st.title("California Housing Price")
    

    X, y = load_data()

    model, mse = train_model(X, y)
    st.write(f"Train MSE: {mse:.2f}")

    st.header("Input Features")
    MedInc = st.number_input("MedInc", float(X['MedInc'].min()), float(X['MedInc'].max()), float(X['MedInc'].mean()))
    HouseAge = st.number_input("HouseAge", float(X['HouseAge'].min()), float(X['HouseAge'].max()), float(X['HouseAge'].mean()))
    AveRooms = st.number_input("AveRooms", float(X['AveRooms'].min()), float(X['AveRooms'].max()), float(X['AveRooms'].mean()))
    AveBedrms = st.number_input("AveBedrms", float(X['AveBedrms'].min()), float(X['AveBedrms'].max()), float(X['AveBedrms'].mean()))
    Population = st.number_input("Population", float(X['Population'].min()), float(X['Population'].max()), float(X['Population'].mean()))
    AveOccup = st.number_input("AveOccup", float(X['AveOccup'].min()), float(X['AveOccup'].max()), float(X['AveOccup'].mean()))
    Latitude = st.number_input("Latitude", float(X['Latitude'].min()), float(X['Latitude'].max()), float(X['Latitude'].mean()))
    Longitude = st.number_input("Longitude", float(X['Longitude'].min()), float(X['Longitude'].max()), float(X['Longitude'].mean()))

    input_data = pd.DataFrame({
        'MedInc': [MedInc],
        'HouseAge': [HouseAge],
        'AveRooms': [AveRooms],
        'AveBedrms': [AveBedrms],
        'Population': [Population],
        'AveOccup': [AveOccup],
        'Latitude': [Latitude],
        'Longitude': [Longitude]
    })

    st.write("Input Features")
    st.write(input_data)

    if st.button("Predict"):
        st.session_state.predict_attempt_count += 1
        prediction = model.predict(input_data)
        st.write(f"Predicted House Price: ${prediction[0]*100000:.2f}")
        st.write(f"Predict Attempt Count: {st.session_state.predict_attempt_count}")
    
    st.write(f"Page Refresh Count: {st.session_state.page_refresh_count}")

if __name__ == "__main__":
    main()

0개의 댓글