JS로 만드는 AI : 6

KHW·2021년 1월 7일
0

데이터분석

목록 보기
7/13
post-custom-banner

전체코드

// 1. 과거의 데이터를 준비합니다.
var tf = require('@tensorflow/tfjs');
var 온도 = [20,21,22,23];
var 판매량 = [40,42,44,46];
var 원인 = tf.tensor(온도);
var 결과 = tf.tensor(판매량);

// 2. 모델의 모양을 만듭니다.
var X = tf.input({ shape: [1] });
var Y = tf.layers.dense({ units: 1 }).apply(X);
var model = tf.model({ inputs: X, outputs: Y });
var compileParam = { optimizer: tf.train.adam(), loss: tf.losses.meanSquaredError }
model.compile(compileParam);

// 3. 데이터로 모델을 학습시킵니다.
var fitParam = { epochs: 4000,
    callbacks:{
        onEpochEnd:
            function(epoch, logs){
                console.log('epoch', epoch, logs, 'RMSE=>', Math.sqrt(logs.loss));
            }
    }
}

// var fitParam = { epochs: 100, callbacks:{onEpochEnd:function(epoch, logs){console.log('epoch', epoch, logs);}}} // loss 추가 예제
model.fit(원인, 결과, fitParam).then(function (result) {

    // 4. 모델을 이용합니다.
    // 4.1 기존의 데이터를 이용
    var 예측한결과 = model.predict(원인);
    예측한결과.print();

});

주요 개념

  1. MSE : 평균 제곱 오차 (log.loss)
    원인과 결과 데이터를 통한 모델을 예측했을때 실제 예측값과 결과간의 차이를 제곱하여 평균 낸 것
  2. RMSE : 평균 제곱근 오차 (Math.sqrt(log.loss))
    위의 내용에서 제곱근을 추가한 것

● log의 결과값의 경우는 {loss: 숫자}형태로 나타난다. 따라서 Math.sqrt를 계산하기 위해서는 정확한 숫자의 형태인 log.loss를 사용해야한다.

이를 통해 시간이 지날때마다 epoch를 확인 할 수 있고 RMSE와 같은값들의 변화도 확인 할 수 있다.

출처 : 생활코딩

profile
나의 하루를 가능한 기억하고 즐기고 후회하지말자
post-custom-banner

0개의 댓글