Implementing Gradient Descent (BGD), Stochastic Gradient Descent (SGD), and Mini-Batch GD in Python

Surf in Data · April 3, 2022

Types of Gradient Descent

  • BGD (Batch Gradient Descent):
    Computes the gradient over the entire training set.
    For example, with 100 features the number of weights in a neural network grows very quickly, and running full-batch GD over all of them consumes a lot of computing resources. SGD and Mini-Batch GD were introduced to overcome this (a unified sketch of all three update rules follows this list).
  • SGD (Stochastic Gradient Descent):
    Computes the gradient from a single randomly chosen training example.
  • Mini-Batch GD:
    Computes the gradient from a randomly chosen subset of a fixed size (the batch size).
    (Most deep learning frameworks use Mini-Batch GD.)
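All three variants share the same update rule and differ only in which rows feed the gradient estimate. A minimal sketch of that idea, assuming a hypothetical helper grad_step and a 2-feature toy dataset (neither is part of the original post):

import numpy as np

def grad_step(X, y, w, b, lr):
    # one gradient-descent step on the rows in X, y (MSE loss)
    residual = X @ w + b - y
    w = w - 2 * lr * (X * residual[:, None]).mean(axis=0)
    b = b - 2 * lr * residual.mean()
    return w, b

rng = np.random.default_rng(0)
X = rng.random((100, 2))                         # toy data: 100 samples, 2 features
y = 0.3 * X[:, 0] + 0.5 * X[:, 1] + 0.8
w, b = np.zeros(2), 0.0

w, b = grad_step(X, y, w, b, 0.5)                # BGD: all 100 rows
i = rng.integers(len(X))
w, b = grad_step(X[i:i+1], y[i:i+1], w, b, 0.5)  # SGD: one random row
idx = rng.choice(len(X), size=10, replace=False)
w, b = grad_step(X[idx], y[idx], w, b, 0.5)      # Mini-Batch: 10 random rows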

BGD (Batch Gradient Descent):

Because every update is computed over the whole dataset, there are many values to evaluate, so each step takes a long time and consumes a lot of memory.
Convergence is stable, but precisely because the trajectory is so steady, BGD has trouble escaping once it falls into a local minimum, so the local-optimum problem is more likely to occur.

import numpy as np
import matplotlib.pyplot as plt

# generate feature values
np.random.seed(1)
x1 = np.random.rand(100)
x2 = np.random.rand(100)
x3 = np.random.rand(100)

# define the target polynomial (the values we want to recover)
y = 0.3*x1 + 0.5*x2 + 0.6*x3 + 0.8

# initialize weights randomly
w1 = np.random.uniform(low=-1.0, high=1.0)
w2 = np.random.uniform(low=-1.0, high=1.0)
w3 = np.random.uniform(low=-1.0, high=1.0)

# initialize bias randomly
bias = np.random.uniform(low=-1.0, high=1.0)
print("Target polynomial: Y=0.3X1+0.5X2+0.6X3+0.8")
print(f"BGD starting polynomial: Y={w1}X1+{w2}X2+{w3}X3+{bias}")
Target polynomial: Y=0.3X1+0.5X2+0.6X3+0.8
BGD starting polynomial: Y=0.6237173954410795X1+0.7499232899117962X2+0.3768265047718866X3+0.13898882549075142
num_epoch=5000
learning_rate=0.5

for epoch in range(num_epoch):
    # predictions
    predict = w1*x1 + w2*x2 + w3*x3 + bias
    
    # mean squared error
    error = ((predict - y)**2).mean()
    
    # weight updates
    w1 = w1 - 2*learning_rate*((predict - y)*x1).mean()
    w2 = w2 - 2*learning_rate*((predict - y)*x2).mean()
    w3 = w3 - 2*learning_rate*((predict - y)*x3).mean()
    bias = bias - 2*learning_rate * (predict - y).mean()
    
    if epoch%10 == 0:        
        print("epoch", epoch, "w1= ", w1 , "w2= ", w2, "w3= ", w3,"bias= ", bias, "error= ", error)
        
    if error < 0.000001:
        break
        
print("final: ","w1= ", w1 , "w2= ", w2, "w3= ", w3,"bias= ", bias, "error= ", error)
epoch 0 w1=  0.8414802196312783 w2=  0.9725817745318708 w3=  0.6743666063413661 bias=  0.645337083138217 error=  0.2749637656622825
epoch 10 w1=  0.5410384318880622 w2=  0.6906298976033656 w3=  0.6475971496205929 bias=  0.6064607103115554 error=  0.012933780944740672
epoch 20 w1=  0.42467042529868154 w2=  0.590616625261648 w3=  0.6484913842648246 bias=  0.6692268664152422 error=  0.0026111574548648886
epoch 30 w1=  0.36880921739562816 w2=  0.5466175641221471 w3=  0.6407537027257891 bias=  0.7183887305929277 error=  0.000837885298081488
epoch 40 w1=  0.33932715587127393 w2=  0.5251359261576123 w3=  0.6302433637013142 bias=  0.7496688619232196 error=  0.0002975820833988491
epoch 50 w1=  0.32300882470565545 w2=  0.51402961959147 w3=  0.6209672609151602 bias=  0.7689431978108142 error=  0.00011044923603828749
epoch 60 w1=  0.31369463359644606 w2=  0.5080484013802875 w3=  0.613979300841062 bias=  0.7807925833350058 error=  4.192064490918281e-05
epoch 70 w1=  0.30825567768545603 w2=  0.5047184225609904 w3=  0.6091022234991438 bias=  0.7880971115600803 error=  1.6088697806350452e-05
epoch 80 w1=  0.3050238365898425 w2=  0.5028128094865654 w3=  0.6058387645496128 bias=  0.7926128075135321 error=  6.208145385296592e-06
epoch 90 w1=  0.3030780445621644 w2=  0.501697916722552 w3=  0.6037092142590842 bias=  0.7954105521283038 error=  2.401769405330152e-06
epoch 100 w1=  0.30189508417942634 w2=  0.5010343471657573 w3=  0.6023412876830412 bias=  0.7971466373100539 error=  9.303361135327641e-07
final:  w1=  0.30189508417942634 w2=  0.5010343471657573 w3=  0.6023412876830412 bias=  0.7971466373100539 error=  9.303361135327641e-07

Training terminated at epoch 100, once the error dropped below 0.000001.

print("구하고자하는 다항식: Y=0.3X1+0.5X2+0.6X3+0.8")
print(f"BGD 최종 다항식: Y={w1}X1+{w2}X2+{w3}X3+{bias}")
구하고자하는 다항식: Y=0.3X1+0.5X2+0.6X3+0.8
BGD 최종 다항식: Y=0.30189508417942634X1+0.5010343471657573X2+0.6023412876830412X3+0.7971466373100539

Gradient descent has brought the weights very close to the target polynomial.
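matplotlib is imported at the top but never used. As a minimal sketch of how convergence could be visualized, here is the same BGD loop re-run from fresh parameters while collecting each epoch's error (the errors list and the w1e/w2e/w3e/be variables are additions for illustration):

errors = []
w1e = w2e = w3e = be = 0.0          # fresh parameters for the illustration
for epoch in range(200):
    predict = w1e*x1 + w2e*x2 + w3e*x3 + be
    errors.append(((predict - y)**2).mean())
    w1e -= 2*learning_rate*((predict - y)*x1).mean()
    w2e -= 2*learning_rate*((predict - y)*x2).mean()
    w3e -= 2*learning_rate*((predict - y)*x3).mean()
    be  -= 2*learning_rate*(predict - y).mean()

plt.plot(errors)
plt.xlabel("epoch")
plt.ylabel("MSE")
plt.yscale("log")                   # convergence is clearer on a log scale
plt.show()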

SGD (Stochastic Gradient Descent)

Stochastic gradient descent differs from batch gradient descent, which uses the entire dataset for every weight update: at each step SGD picks exactly one sample at random and performs the gradient-descent update on that single sample.

  • Characteristics:
    Updating the weights from one example at a time makes each step very fast.
    Because only one example is processed per step, it can handle large datasets.
    With enough iterations it works well, but the updates are very noisy.
    It is fast and uses little memory, but the loss fluctuates wildly during training, so SGD escapes local minima easily yet has trouble settling exactly at the global minimum. (A sketch of the conventional sample-by-index variant follows this list.)
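Note that the loop below samples x1, x2, and x3 independently and then recomputes y for the synthetic point; this works only because y here is a known function of the features. The more conventional way to pick one training example is to draw a stored row by index. A minimal sketch of that variant (the idx variable is illustrative, not in the original code):

# pick one stored training example by index
idx = np.random.randint(len(x1))
x1_sgd, x2_sgd, x3_sgd, y_sgd = x1[idx], x2[idx], x3[idx], y[idx]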
# generate feature values
np.random.seed(2)
x1 = np.random.rand(100)
x2 = np.random.rand(100)
x3 = np.random.rand(100)

# define the target polynomial (the values we want to recover)
y = 0.3*x1 + 0.5*x2 + 0.6*x3 + 0.8

# initialize weights randomly
w1 = np.random.uniform(low=-1.0, high=1.0)
w2 = np.random.uniform(low=-1.0, high=1.0)
w3 = np.random.uniform(low=-1.0, high=1.0)

# initialize bias randomly
bias = np.random.uniform(low=-1.0, high=1.0)
print("Target polynomial: Y=0.3X1+0.5X2+0.6X3+0.8")
print(f"SGD starting polynomial: Y={w1}X1{w2}X2{w3}X3{bias}")
Target polynomial: Y=0.3X1+0.5X2+0.6X3+0.8
SGD starting polynomial: Y=-0.9417013643049765X1-0.2542142319765477X2-0.8082533706144557X3-0.606071359299923
num_epoch=5000
learning_rate=0.5


for epoch in range(num_epoch):
    
    # draw one sample per epoch
    # (each feature is sampled independently and y is recomputed,
    #  which works here because y is a known function of the features)
    x1_sgd = np.random.choice(x1)
    x2_sgd = np.random.choice(x2)
    x3_sgd = np.random.choice(x3)
    y_sgd = 0.3*x1_sgd + 0.5*x2_sgd + 0.6*x3_sgd + 0.8
    
    # prediction for the single sample
    predict_sgd = w1*x1_sgd + w2*x2_sgd + w3*x3_sgd + bias
    
    # weight updates
    w1 = w1 - 2*learning_rate*((predict_sgd - y_sgd)*x1_sgd)
    w2 = w2 - 2*learning_rate*((predict_sgd - y_sgd)*x2_sgd)
    w3 = w3 - 2*learning_rate*((predict_sgd - y_sgd)*x3_sgd)
    bias = bias - 2*learning_rate * (predict_sgd - y_sgd)
    
    # the error must be computed over the whole dataset
    predict = w1*x1 + w2*x2 + w3*x3 + bias
    error = ((y - predict)**2).mean()
    
    if epoch%1000 == 0:
        print("epoch ", epoch,"w1= ", w1, "w2= ", w2, "w3= ", w3,"bias= ", bias, "error= ", error)
        
    if error < 0.000001:
        break
print("final: ","w1= ", w1 , "w2= ", w2, "w3= ", w3,"bias= ", bias, "error= ", error)
epoch  0 w1=  1.411727729464638 w2=  0.2440718850889772 w3=  0.3457711551694638 bias=  2.4245142537958806 error=  4.2629318791864605
epoch  1000 w1=  0.327931719852115 w2=  0.46594023518377287 w3=  0.7199058179077448 bias=  0.7634963214793907 error=  0.15905679357559088
epoch  2000 w1=  0.3000038502144915 w2=  0.49999743045635253 w3=  0.5999991602902495 bias=  0.7999993951459968 error=  0.10748637933110514
epoch  3000 w1=  0.2999999907162555 w2=  0.5000000227172846 w3=  0.60000034490028 bias=  0.8000002725147066 error=  0.10748720663244732
epoch  4000 w1=  0.30000000004459915 w2=  0.5000000000016254 w3=  0.5999999999642428 bias=  0.8000000000421799 error=  0.10748685940927881
final:  w1=  0.29999999999841404 w2=  0.4999999999985717 w3=  0.5999999999990189 bias=  0.8000000000016906 error=  0.10748685939390738
print("구하고자하는 다항식: Y=0.3X1+0.5X2+0.6X3+0.8")
print(f"SGD 최종 다항식:Y={w1}X1+{w2}X+2{w3}+X3+{bias}")
구하고자하는 다항식: Y=0.3X1+0.5X2+0.6X3+0.8
SGD 최종 다항식:Y=0.29999999999841404X1+0.4999999999985717X+20.5999999999990189+X3+0.8000000000016906

Mini-Batch GD:

Mini-Batch GD is what most deep learning frameworks adopt.
In practice, what is commonly called SGD is usually SGD with mini-batches.
Raising the batch size tends to make training faster, but the batch size is tied to both model performance and the available computing resources.

For example, if there are 100 training samples and the batch size is 10, we get 10 mini-batches. One gradient-descent step is performed per mini-batch, so each epoch runs 10 iterations, i.e. 10 gradient-descent updates. With x epochs, training therefore performs 10x updates in total, touching 100x individual samples (see the counting sketch below).
The diagram below showed a single epoch; think of that process repeating 10 times in total, once per epoch.

  • One epoch, step by step (figure omitted)
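As a quick check on that counting, a minimal sketch (the variable names are illustrative, not from the original code):

n_samples = 100
batch_size = 10
num_epoch = 10

iterations_per_epoch = n_samples // batch_size      # 10 mini-batches per epoch
total_updates = iterations_per_epoch * num_epoch    # 10 * 10 = 100 weight updates
samples_seen = total_updates * batch_size           # 1000 samples processed in total
print(iterations_per_epoch, total_updates, samples_seen)   # 10 100 1000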
# generate feature values
np.random.seed(3)
x1 = np.random.rand(100)
x2 = np.random.rand(100)
x3 = np.random.rand(100)

# define the target polynomial (the values we want to recover)
y = 0.3*x1 + 0.5*x2 + 0.6*x3 + 0.8

# initialize weights randomly
w1 = np.random.uniform(low=-1.0, high=1.0)
w2 = np.random.uniform(low=-1.0, high=1.0)
w3 = np.random.uniform(low=-1.0, high=1.0)

# initialize bias randomly
bias = np.random.uniform(low=-1.0, high=1.0)
print("Target polynomial: Y=0.3X1+0.5X2+0.6X3+0.8")
print(f"Mini-Batch starting polynomial: Y={w1}X1{w2}X2+{w3}X3+{bias}")
Target polynomial: Y=0.3X1+0.5X2+0.6X3+0.8
Mini-Batch starting polynomial: Y=0.9653749852911324X1-0.8318386378466374X2+0.2624242248671609X3+0.15394979076289572
num_epoch=10
learning_rate=0.5

batch_size = 10
batch_number = len(x1) // batch_size

for epoch in range(num_epoch):
    print("epoch: ", epoch)
    print("-"*50)
    start = 0
    end = batch_size
    for iteration in range(batch_number):
    
        # each iteration takes the next batch_size=10 slice of the dataset
        # (batches are taken in order; no shuffling between epochs)
        x1_batch = x1[start: end]
        x2_batch = x2[start: end]
        x3_batch = x3[start: end]
        y_batch = 0.3*x1_batch + 0.5*x2_batch + 0.6*x3_batch + 0.8
        
        start += batch_size
        end += batch_size

        # predictions for the selected batch
        predict_batch = w1*x1_batch + w2*x2_batch + w3*x3_batch + bias

        # weight updates
        w1 = w1 - 2*learning_rate*((predict_batch - y_batch)*x1_batch).mean()
        w2 = w2 - 2*learning_rate*((predict_batch - y_batch)*x2_batch).mean()
        w3 = w3 - 2*learning_rate*((predict_batch - y_batch)*x3_batch).mean()
        bias = bias - 2*learning_rate * (predict_batch - y_batch).mean()

        # the error must be computed over the whole dataset
        predict = w1*x1 + w2*x2 + w3*x3 + bias
        error = ((y - predict)**2).mean()

        print("iteration ", iteration,"w1= ", w1, "w2= ", w2, "w3= ", w3,"bias= ", bias, "error= ", error)

    if error < 0.000001:
        break
print("final: ","w1= ", w1 , "w2= ", w2, "w3= ", w3,"bias= ", bias, "error= ", error)
epoch:  0
--------------------------------------------------
iteration  0 w1=  1.5213631310572862 w2=  0.09487689227100216 w3=  0.747107122844944 bias=  1.4530564406334936 error=  1.3161185764743437
iteration  1 w1=  1.0458988054279694 w2=  -0.22184149919843177 w3=  0.4477582393063155 bias=  0.41885895236392034 error=  0.29697682175683815
iteration  2 w1=  1.1956115498453805 w2=  0.0341397227800988 w3=  0.6532672791920748 bias=  0.8419753256124854 error=  0.13363176006101157
iteration  3 w1=  0.9432079516490681 w2=  -0.11588904028745653 w3=  0.47832488947898955 bias=  0.49429021169728954 error=  0.196094266179689
iteration  4 w1=  1.1153004439333605 w2=  0.14513285244713683 w3=  0.6263585922966504 bias=  0.9090625548193434 error=  0.16034485257861225
iteration  5 w1=  0.9267873746146847 w2=  -0.0018166971919065245 w3=  0.4413727303829379 bias=  0.5873212546734535 error=  0.10654049376689798
iteration  6 w1=  1.0280244644813714 w2=  0.25510881023609167 w3=  0.6099609239385889 bias=  0.9312463824393413 error=  0.1669583196666408
iteration  7 w1=  0.8701421102612867 w2=  0.12907604455089328 w3=  0.47449358784941764 bias=  0.6434268230076291 error=  0.05013452055679447
iteration  8 w1=  0.9010382914855323 w2=  0.2272115492335933 w3=  0.5408588637903134 bias=  0.765150928334827 error=  0.03608006058956387
iteration  9 w1=  0.7711389680777373 w2=  0.15885260841582477 w3=  0.49157595362075324 bias=  0.5925257980283598 error=  0.06742319763327899
epoch:  1
--------------------------------------------------
iteration  0 w1=  0.8454517116174212 w2=  0.33449180008805135 w3=  0.5970857273507987 bias=  0.8342818318695143 error=  0.06480647800893607
iteration  1 w1=  0.7395650210507694 w2=  0.2883284449486623 w3=  0.5487277283473019 bias=  0.6447621093861559 error=  0.02198627972390295
iteration  2 w1=  0.7425049885493483 w2=  0.33898513228611765 w3=  0.5828272510183393 bias=  0.7104373079883916 error=  0.015502914619954946
iteration  3 w1=  0.6705024346626554 w2=  0.30388280491314956 w3=  0.5435641109007101 bias=  0.6350560271848237 error=  0.02515350599680771
iteration  4 w1=  0.7215220419283798 w2=  0.38519951804857344 w3=  0.5914777708663186 bias=  0.7695072360646619 error=  0.02416439156750164
iteration  5 w1=  0.652783511773364 w2=  0.3362351860339089 w3=  0.5376546362701057 bias=  0.6675488873573168 error=  0.016216368705809073
iteration  6 w1=  0.6797055925447867 w2=  0.42592935115397057 w3=  0.5948609166326014 bias=  0.7861054194581176 error=  0.026015761109141448
iteration  7 w1=  0.6200378161666853 w2=  0.38668911472700324 w3=  0.5539861010003626 bias=  0.697104791465922 error=  0.00858431758571909
iteration  8 w1=  0.6037665682455545 w2=  0.40061181855540756 w3=  0.5566327739252477 bias=  0.7101150941638711 error=  0.007083468254761236
iteration  9 w1=  0.5731524794016878 w2=  0.3912607198435063 w3=  0.5590058790722068 bias=  0.6820405479751077 error=  0.009696578534389105
epoch:  2
--------------------------------------------------
iteration  0 w1=  0.5848780889347298 w2=  0.44350057501223666 w3=  0.5965650268487085 bias=  0.7556566657476836 error=  0.009279205526315915
iteration  1 w1=  0.550653084972987 w2=  0.4358996985838014 w3=  0.5871194930680333 bias=  0.7088681086806727 error=  0.004485949716737526
iteration  2 w1=  0.5383703654836287 w2=  0.4464329747286647 w3=  0.5915872625598915 bias=  0.7158045577381684 error=  0.003905803908577837
iteration  3 w1=  0.5094407769703895 w2=  0.433417941374085 w3=  0.5793577439296732 bias=  0.6947022237193463 error=  0.005591404183120755
iteration  4 w1=  0.5301019237754273 w2=  0.46713654095957124 w3=  0.6001306257465684 bias=  0.7537827655953246 error=  0.005763936337064908
iteration  5 w1=  0.49744148277830363 w2=  0.44351008085367544 w3=  0.5774407620218887 bias=  0.7092190761358498 error=  0.004088750584043604
iteration  6 w1=  0.5079451686719422 w2=  0.48403636699116354 w3=  0.6039161051044445 bias=  0.7637494960239167 error=  0.006120141043288353
iteration  7 w1=  0.4802489252049987 w2=  0.46748853309663957 w3=  0.587490550882961 bias=  0.7283878019895988 error=  0.0022766916146469613
iteration  8 w1=  0.4629037227380849 w2=  0.4648899963498438 w3=  0.5806864200426577 bias=  0.7226220683761692 error=  0.002555391078503997
iteration  9 w1=  0.45768951094434085 w2=  0.4672669245681476 w3=  0.5899792644737364 bias=  0.7260616680253269 error=  0.0021050842067980306
epoch:  3
--------------------------------------------------
iteration  0 w1=  0.4562111444068244 w2=  0.4835588028887802 w3=  0.6049550458293669 bias=  0.7505707945118762 error=  0.001992000521655735
iteration  1 w1=  0.4448777465482362 w2=  0.48480875367560566 w3=  0.6047762108133216 bias=  0.7428311295431409 error=  0.0014379461205907646
iteration  2 w1=  0.43276072122572096 w2=  0.4841459366720257 w3=  0.6016285486848757 bias=  0.7358146696333652 error=  0.0012365521189060672
iteration  3 w1=  0.42041731257393 w2=  0.4791516929562892 w3=  0.5978764360080935 bias=  0.7317362712598842 error=  0.0014754461667947857
iteration  4 w1=  0.4294025627929673 w2=  0.4943762206496298 w3=  0.6071804607797029 bias=  0.7599261154076489 error=  0.001613712170699518
iteration  5 w1=  0.4124265928937958 w2=  0.481590591923182 w3=  0.5958206206610791 bias=  0.7379224418902441 error=  0.0012392630898304684
iteration  6 w1=  0.4172837187544379 w2=  0.5021300040309805 w3=  0.6094395715167995 bias=  0.7660388376415377 error=  0.0016943419609048128
iteration  7 w1=  0.4033866392535884 w2=  0.4942996970922932 w3=  0.6019340506643328 bias=  0.7506864590984135 error=  0.0007126826777845445
iteration  8 w1=  0.3901317101435054 w2=  0.4887505807927393 w3=  0.5945407320625102 bias=  0.7421821825665166 error=  0.0010757551348840471
iteration  9 w1=  0.39266732646350955 w2=  0.4936305594051148 w3=  0.6032041844228768 bias=  0.7527886970359909 error=  0.000590138273258922
epoch:  4
--------------------------------------------------
iteration  0 w1=  0.3884579324640889 w2=  0.4973662888746008 w3=  0.6090335477855306 bias=  0.7596243307134971 error=  0.0005451184000153989
iteration  1 w1=  0.38547675966710376 w2=  0.5005528999020429 w3=  0.6110244340495393 bias=  0.7634019178642328 error=  0.000587536497691235
iteration  2 w1=  0.37594970260265836 w2=  0.4970297462806601 w3=  0.6063333304275551 bias=  0.7540412219206675 error=  0.000448851412352798
iteration  3 w1=  0.3707832048621911 w2=  0.4954148884533983 w3=  0.6055458491664456 bias=  0.7555584042541912 error=  0.00043783255510137224
iteration  4 w1=  0.37473332775943796 w2=  0.502574428345669 w3=  0.6095888562606845 bias=  0.7695162940649686 error=  0.0005022741167247502
iteration  5 w1=  0.36542532715872583 w2=  0.49524623165484083 w3=  0.6032362599374157 bias=  0.7578098978436048 error=  0.00042200285397762623
iteration  6 w1=  0.36794114090507807 w2=  0.5064998015295508 w3=  0.6106791652332865 bias=  0.773380513298299 error=  0.0005250146645868474
iteration  7 w1=  0.3605973785959103 w2=  0.5025154640472701 w3=  0.6069142493559617 bias=  0.766349821595406 error=  0.00024808074778589996
iteration  8 w1=  0.3512614219316348 w2=  0.4974418676068017 w3=  0.6008502792046014 bias=  0.7589299774708002 error=  0.00047795135974336824
iteration  9 w1=  0.3556104566872599 w2=  0.5021531947186889 w3=  0.6075334566350277 bias=  0.7695365479955073 error=  0.000210074888854553
epoch:  5
--------------------------------------------------
iteration  0 w1=  0.3515432119770179 w2=  0.5016290220644641 w3=  0.6095567394973935 bias=  0.7700297620195486 error=  0.00018147387356110003
iteration  1 w1=  0.3515250400639361 w2=  0.5047289209338258 w3=  0.6116803203742802 bias=  0.7762728473238462 error=  0.0002662727260368899
iteration  2 w1=  0.3445751721162462 w2=  0.5010513819829917 w3=  0.6073728159008093 bias=  0.7679844441825787 error=  0.0001788498385891955
iteration  3 w1=  0.34255272097523387 w2=  0.5008136570020639 w3=  0.6075678898024862 bias=  0.7709155873160385 error=  0.0001452602516688597
iteration  4 w1=  0.34426825073560974 w2=  0.5042840663532899 w3=  0.6092084035920631 bias=  0.7780274523443125 error=  0.00017096506852182426
iteration  5 w1=  0.33894472361842964 w2=  0.4999101922203242 w3=  0.6053778114244737 bias=  0.7714299994858942 error=  0.00015687329535805851
iteration  6 w1=  0.34036714003966095 w2=  0.5064496089738819 w3=  0.6096385187094396 bias=  0.7805227217489586 error=  0.0001785614467830967
iteration  7 w1=  0.33631457805971504 w2=  0.5042987143775987 w3=  0.607596511390503 bias=  0.7771556608735026 error=  9.344587106717807e-05
iteration  8 w1=  0.3299358502730216 w2=  0.5003894983297937 w3=  0.6030658482122583 bias=  0.7715007441043807 error=  0.0002164525709642465
iteration  9 w1=  0.334066167126592 w2=  0.5042107846968212 w3=  0.6078821386305095 bias=  0.7801642018003087 error=  8.851893868571426e-05
epoch:  6
--------------------------------------------------
iteration  0 w1=  0.33082810177129 w2=  0.5025228494302167 w3=  0.6083664101765803 bias=  0.77867197453337 error=  6.954785944409986e-05
iteration  1 w1=  0.3317008615799174 w2=  0.5050181334284027 w3=  0.6101007131437801 bias=  0.7844693779267784 error=  0.00012466032632983375
iteration  2 w1=  0.32680077495197735 w2=  0.5019656692873344 w3=  0.6066706295938251 bias=  0.7780033470026537 error=  7.562049938821151e-05
iteration  3 w1=  0.3261277312616293 w2=  0.5022220786361963 w3=  0.6071175468120947 bias=  0.780858465431688 error=  5.3290314776405875e-05
iteration  4 w1=  0.3268463705803736 w2=  0.5039492347948482 w3=  0.6076923775953631 bias=  0.7845806234781898 error=  6.276823259776251e-05
iteration  5 w1=  0.3236937916800033 w2=  0.5012576262619786 w3=  0.6052707182050833 bias=  0.7806911696110325 error=  6.224002805246899e-05
iteration  6 w1=  0.3245502523114995 w2=  0.5052252419664454 w3=  0.6078020675759243 bias=  0.7862143645860737 error=  6.554889193844796e-05
iteration  7 w1=  0.3222284677393136 w2=  0.5040056804642138 w3=  0.6066237164786794 bias=  0.7845353883359513 error=  3.733242026749468e-05
iteration  8 w1=  0.3179229911689899 w2=  0.5011842995918198 w3=  0.6033802343627149 bias=  0.7804594607226464 error=  9.867657135799378e-05
iteration  9 w1=  0.3212646503453425 w2=  0.5040604617092274 w3=  0.6067461817727249 bias=  0.7869835053599603 error=  4.052047722438344e-05
epoch:  7
--------------------------------------------------
iteration  0 w1=  0.3188731250865072 w2=  0.5023188760267179 w3=  0.6066586835759467 bias=  0.7851447267007798 error=  2.903263189727366e-05
iteration  1 w1=  0.31986590949610144 w2=  0.5041800712160189 w3=  0.6079499557444156 bias=  0.7897593294296319 error=  5.865908373390358e-05
iteration  2 w1=  0.3164714880572164 w2=  0.5018624074165162 w3=  0.6053908516108462 bias=  0.7849991535668726 error=  3.3106795336507595e-05
iteration  3 w1=  0.3163483101988259 w2=  0.5022397495278806 w3=  0.6058336320441922 bias=  0.7873368961832908 error=  2.11567650035473e-05
iteration  4 w1=  0.31662655361501446 w2=  0.5031207382335314 w3=  0.6059606186297893 bias=  0.7893375254958909 error=  2.448031795083903e-05
iteration  5 w1=  0.31470552678992214 w2=  0.5014256399398072 w3=  0.6043868908469904 bias=  0.7869643468958777 error=  2.585220695349197e-05
iteration  6 w1=  0.31524305885205045 w2=  0.5039078381631902 w3=  0.6059358620681858 bias=  0.7904168939433605 error=  2.5542387709219513e-05
iteration  7 w1=  0.3138693186551669 w2=  0.5031882971121157 w3=  0.6052239626918545 bias=  0.7895480140363914 error=  1.5556718624868696e-05
iteration  8 w1=  0.31097631565781514 w2=  0.5012127903095488 w3=  0.602953789068724 bias=  0.7866895173710183 error=  4.5071296897906504e-05
iteration  9 w1=  0.3134925532904848 w2=  0.5032936389635341 w3=  0.6052702832509895 bias=  0.7914039083954553 error=  1.90875470009389e-05
epoch:  8
--------------------------------------------------
iteration  0 w1=  0.3117894260979735 w2=  0.5018399877147404 w3=  0.6050137981257616 bias=  0.789775941082389 error=  1.2690770237444523e-05
iteration  1 w1=  0.31264746500887536 w2=  0.5031759962014125 w3=  0.605935504174801 bias=  0.7932094773856826 error=  2.7498273425219755e-05
iteration  2 w1=  0.3103192333817083 w2=  0.5014930522679112 w3=  0.6040929444797333 bias=  0.789813379165579 error=  1.4774705173648586e-05
iteration  3 w1=  0.310395853407843 w2=  0.5018478907952293 w3=  0.6044565424855871 bias=  0.7915860880756415 error=  8.86988152044655e-06
iteration  4 w1=  0.3104849830909564 w2=  0.5023082690768582 w3=  0.6044141913454989 bias=  0.792690701921532 error=  9.991198785534969e-06
iteration  5 w1=  0.3092870667777352 w2=  0.5012220192583352 w3=  0.6033755563760084 bias=  0.791204847436213 error=  1.1073557918761763e-05
iteration  6 w1=  0.30963331159080054 w2=  0.5028086329055839 w3=  0.6043456206971335 bias=  0.7934081023377555 error=  1.040599724733096e-05
iteration  7 w1=  0.30879804416515483 w2=  0.5023705153568349 w3=  0.6039013216177412 bias=  0.7929426217557481 error=  6.672792838059752e-06
iteration  8 w1=  0.30685648332255117 w2=  0.5010085724841129 w3=  0.6023329316891328 bias=  0.79096742847865 error=  2.0588295083731064e-05
iteration  9 w1=  0.30867886872966577 w2=  0.502479259656466 w3=  0.6039141330985085 bias=  0.7942944957732712 error=  9.029444241137552e-06
epoch:  9
--------------------------------------------------
iteration  0 w1=  0.30748990003939447 w2=  0.5013675044232039 w3=  0.6036460812036658 bias=  0.7930129930944654 error=  5.676360982203682e-06
iteration  1 w1=  0.3081578363403339 w2=  0.5023059040352221 w3=  0.6042896893228368 bias=  0.7954781641555595 error=  1.2815608370282393e-05
iteration  2 w1=  0.3065703020618537 w2=  0.501114519923964 w3=  0.6029911597797012 bias=  0.7931001301927927 error=  6.659907103543492e-06
iteration  3 w1=  0.30669953990510673 w2=  0.5014031097944962 w3=  0.6032668131524673 bias=  0.794389835485759 error=  3.847851540098933e-06
iteration  4 w1=  0.30671217573157766 w2=  0.501649693417465 w3=  0.603175368516964 bias=  0.7950163956955345 error=  4.212814912377578e-06
iteration  5 w1=  0.3059513173089545 w2=  0.500944474495695 w3=  0.6024839651203531 bias=  0.7940680600844799 error=  4.83940517688416e-06
iteration  6 w1=  0.30617794109627466 w2=  0.5019740028212347 w3=  0.6031025150781237 bias=  0.7954951603800975 error=  4.377522047276901e-06
iteration  7 w1=  0.3056584491271767 w2=  0.5017005968052399 w3=  0.6028188939181655 bias=  0.795237536892478 error=  2.917807452544335e-06
iteration  8 w1=  0.30435524000158676 w2=  0.5007696139616787 w3=  0.6017439678314062 bias=  0.7938844063179068 error=  9.39875362343407e-06
iteration  9 w1=  0.3056446228412554 w2=  0.501793816331406 w3=  0.6028182745151217 bias=  0.7961983142068405 error=  4.252022118491616e-06
final:  w1=  0.3056446228412554 w2=  0.501793816331406 w3=  0.6028182745151217 bias=  0.7961983142068405 error=  4.252022118491616e-06
print("구하고자하는 다항식: Y=0.3X1+0.5X2+0.6X3+0.8")
print(f"Mini-Batch 최종 다항식: Y={w1}X1+{w2}X2+{w3}X3+{bias}")
구하고자하는 다항식: Y=0.3X1+0.5X2+0.6X3+0.8
Mini-Batch 최종 다항식: Y=0.3056446228412554X1+0.501793816331406X2+0.6028182745151217X3+0.7961983142068405

Within just 10 epochs, the weights have come close to the target values.
