수치해석

Noah·2024년 8월 3일

알고리즘

목록 보기
4/20

수치해석

수치해석 알고리즘이란 함수나 방정식의 해의 근삿값을 이분법이나 뉴턴-랩슨법 등을 활용해 찾는 알고리즘입니다. 인간의 힘으로 해를 구하기 힘든 함수나 방정식의 해를 구할 때 사용됩니다.

이분법

어떠한 범위 안에서 중간값을 기준으로 방정식이라면 0보다 크다면 중간값보다 작은 범위에서 다시 탐색, 작다면 중간값보다 큰 범위에서 다시 탐색하는 방식으로 작동합니다.

아래는 간단히 작성한 이분법의 코드입니다.

import sys
from decimal import Decimal

input = sys.stdin.readline
a, b, c, d, e = map(Decimal, input().split())

def func(x):
    return a*(x**3) + b*(x**2) + c*x + d - e

start, end = Decimal(-10**10), Decimal(10**10) # 시작 범위
for _ in range(1000):
    mid = (start+end) / Decimal(2.0)
    if func(mid) < 0:
        start = mid
    else:
        end = mid

print(mid)
입력 : 3 2 1 0 27 (3x^3+2x^2+x = 27)
출력 : 1.832252860066851278868305314

아래는 https://ko.numberempire.com/equationsolver.php 에서 연산을 실행한 결과인데, 해를 성공적으로 구해냈음을 알 수 있습니다.

코드 (BOJ 13705)

import sys
import decimal
import math

pi = "3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461284756482337867831652712019091456485669234603486104543266482133936072602491412737245870066063155881748815209209628292540917153643678925903600113305305488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798609437027705392171762931767523846748184676694051320005681271452635608277857713427577896091736371787214684409012249534301465495853710507922796892589235420199561121290219608640344181598136297747713099605187072113499999983729780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083814206171776691473035982534904287554687311595628638823537875937519577818577805321712268066130019278766111959092164201989"
input = sys.stdin.readline
decimal.getcontext().prec = 50
a, b, c = map(decimal.Decimal, input().split())
fac = [decimal.Decimal(1), decimal.Decimal(1)]
for i in range(2, 75):
    fac.append(fac[-1]*decimal.Decimal(i))

def sin(x):
    result = x
    idx = 0
    for i in range(3, 75, 2): # 테일러 급수로 sin 함수 구현
        term = pow(x, decimal.Decimal(i))/fac[i]
        if abs(term) < decimal.Decimal('1e-50'):
            break
        if idx%2 == 0:
            result -= term
        else:
            result += term
        idx += 1
    return result
    
def f(x):
    return a * x + b * decimal.Decimal(sin(x % (decimal.Decimal(2)*decimal.Decimal(pi)))) # 사인파 특성상 2*pi 를 넘어가면 반복되기 때문에 나머지 연산

start, end = decimal.Decimal(-10**10), decimal.Decimal(10**10)
for i in range(1000):
    mid = (start+end)/decimal.Decimal(2)
    res = f(mid)
    if res > c:
        end = mid
    else:
        start = mid
res = mid.quantize(decimal.Decimal('0.000001'), decimal.ROUND_HALF_EVEN) # 반올림
print(res)

뉴턴-랩슨법

xn+1=xf(xn)f(xn)x_{n+1} = x-\displaystyle\frac{f(x_n)}{f'(x_n)}

다음과 같은 방식으로 해를 찾는 알고리즘입니다. 주로 비선형 방정식을 푸는데 사용되고, 해가 여러개인 경우에는 초기값에서 가장 가까운 해를 찾아줍니다.

import sys
from decimal import Decimal

input = sys.stdin.readline
a, b, c, d, e = map(Decimal, input().split())

def func(x):
    return a*(x**3) + b*(x**2) + c*x + d - e

def df(x): # 도함수
    return 3*a*(x**2) + 2*b*x + c

x = Decimal(2)
for _ in range(1000):
    x = x - func(x)/df(x)

print(x)
입력 : 3 2 1 0 27
출력 : 1.832252860066851278868305314

코드 (BOJ 13705)

import sys
import decimal
import math

pi = "3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461284756482337867831652712019091456485669234603486104543266482133936072602491412737245870066063155881748815209209628292540917153643678925903600113305305488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798609437027705392171762931767523846748184676694051320005681271452635608277857713427577896091736371787214684409012249534301465495853710507922796892589235420199561121290219608640344181598136297747713099605187072113499999983729780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083814206171776691473035982534904287554687311595628638823537875937519577818577805321712268066130019278766111959092164201989"
input = sys.stdin.readline

decimal.getcontext().prec = 50
a, b, c = map(decimal.Decimal, input().split())

fac = [decimal.Decimal(1), decimal.Decimal(1)]

for i in range(2, 75):
    fac.append(fac[-1]*decimal.Decimal(i))

def cos(x):
    result = 1
    idx = 0
    for i in range(2, 75, 2):
        term = pow(x, decimal.Decimal(i))/fac[i]
        if abs(term) < decimal.Decimal('1e-50'):
            break
        if idx % 2 == 0:
            result -= term
        else:
            result += term
        idx += 1
    return result

def sin(x):
    result = x
    idx = 0
    for i in range(3, 75, 2):
        term = pow(x, decimal.Decimal(i))/fac[i]
        if abs(term) < decimal.Decimal('1e-50'):
            break
        if idx % 2 == 0:
            result -= term
        else:
            result += term
        idx += 1
    return result

def f(x):
    return a * x + b * decimal.Decimal(sin(x % (decimal.Decimal(2)*decimal.Decimal(pi)))) - c

def df(x):
    return a + b * decimal.Decimal(cos(x % (decimal.Decimal(2)*decimal.Decimal(pi))))

x = decimal.Decimal(c/a)
for i in range(1000):
    x = x - f(x)/df(x)
res = x.quantize(decimal.Decimal('0.000001'), decimal.ROUND_HALF_EVEN)
print(res)
profile
부산소프트웨어마이스터고 4기 | 자세한 내용은 홈페이지(노션)의 테크 블로그에서 확인할 수 있습니다.

0개의 댓글