칼만필터 파이썬 코드로 구현하기(수식 사용)

hh_mon__a·2024년 6월 20일
0

칼만필터

목록 보기
3/3

칼만필터 정의 및 알고리즘

  • 앞서 칼만필터 정의 및 알고리즘 내용을 살펴봄. 미리 칼만필터에 대해 알고 이 글을 보는 것을 추천함.
  • 그것에 대한 수식을 패키지를 사용하지 않고 구현함.
  • 실제 온도에 대해 칼만필터를 적용함.

1. 데이터 불러오기

# 필요데이터
true_temperature = np.array(df.col[:600])
observed_temperature = true_temperature + np.random.normal(0, 0.5, len(true_temperature))
  • true_temperature는 실제 온도로 지정
  • measurements는 실제온도에서 노이즈를 추가해줌
    • 현실 시계에서 측정된 데이터는 항상 일정한 양의 오차나 노이즈를 포함하고 있기 때문임
    • 센서로 측정한 데이터는 다양한 요인(센서의 정확도, 환경적인 영향 등)으로 인해 실제값과 차이가 날 수 있기 때문에 이러한 오차를 모델링하기 위해 노이즈를 추가
    • 모델의 강건성 테스트: 노이즈가 포함된 데이터를 사용함으로써 모델의 강건성을 테스트할 수 있음. 좋은 모델은 노이즈가 있는 데이터에서도 정확한 예측을 할 수 있어야 함.

2. 칼만필터 구현

def kalman_filter(z, Q, R, A, H):
    ts_length = len(z)
    dim_state = Q.shape[0]
    
    xhatminus = np.zeros((ts_length, dim_state)) # 예측된 상태 추정값
    xhat = np.zeros((ts_length, dim_state)) # 필터링된 상태 추정값
    Pminus = np.zeros((ts_length, dim_state, dim_state))
    P = np.zeros((ts_length, dim_state, dim_state))
    K = np.zeros((ts_length, dim_state))  # Kalman gain
    
    # 초기 추정
    xhat[0, :] = z[0]
    xhatminus[0, :] = z[0]
    P[0, :, :] = np.eye(dim_state)
    
    # 시간 갱신
    for k in range(1, ts_length):
        # 예측 단계(Prediction step)
        xhatminus[k, :] = A @ xhat[k-1, :]
        Pminus[k, :, :] = A @ P[k-1, :, :] @ A.T + Q
        
        # 보정 단계(Correction step)
        K[k, :] = Pminus[k, :, :] @ H.T @ np.linalg.inv(H @ Pminus[k, :, :] @ H.T + R)
        xhat[k, :] = xhatminus[k, :] + K[k, :] @ (z[k] - H @ xhatminus[k, :])
        P[k, :, :] = (np.eye(dim_state) - K[k, :][:, np.newaxis] @ H[np.newaxis, :]) @ Pminus[k, :, :]
    
    return xhat, xhatminus, P, Pminus
  • 초기값은 데이터 중 첫번째 값으로 진행
  • 필터링된 상태 추정값(xhat)
  • 예측된 상태 추정값(xhatminus)
  • 필터링된 오차 공분산(P)
  • 칼만필터 정의 및 알고리즘에서 나온 수식을 파이썬 코드로 진행

3. 파라미터 설정

# 노이즈 파라미터
R = 1  # 측정 분산
Q = np.diag([1])  # 과정 분산

# 동적 파라미터
A = np.array([[1]])  # 상태 전이 행렬
H = np.array([[1]])  # 관측 행렬

4. 칼만필터 적용 및 결과 확인

# 칼만 필터 적용
xhat, xhatminus, P = kalman_filter(observed_temperature, Q, R, A, H)

# 결과 시각화
plt.figure(figsize=(20, 6))
plt.plot(true_temperature, label='True Temperature')
plt.plot(observed_temperature, label='Observed Temperature', alpha=0.5)
plt.plot(xhat, label='Kalman Filter')
plt.plot(xhatminus, label='Kalman Predict')
plt.legend()
plt.show()

5. 성능 확인

# 예측 성능 비교
def mean_squared_error(true, pred):
    return np.mean((true - pred) ** 2)

mean_squared_error(true_temperature, xhatminus.reshape(-1,))
# 0.3710657788923199

전체 코드

칼만필터 깃허브

  • 전체 코드 및 실행 코드는 저의 깃허브 참고해주시면 됩니다!

참고자료

참고자료(R로 구현)

profile
데이터분석/데이터사이언스/코딩

0개의 댓글