수치해석 알고리즘이란 함수나 방정식의 해의 근삿값을 이분법이나 뉴턴-랩슨법 등을 활용해 찾는 알고리즘입니다. 인간의 힘으로 해를 구하기 힘든 함수나 방정식의 해를 구할 때 사용됩니다.
어떠한 범위 안에서 중간값을 기준으로 방정식이라면 0보다 크다면 중간값보다 작은 범위에서 다시 탐색, 작다면 중간값보다 큰 범위에서 다시 탐색하는 방식으로 작동합니다.
아래는 간단히 작성한 이분법의 코드입니다.
import sys
from decimal import Decimal
input = sys.stdin.readline
a, b, c, d, e = map(Decimal, input().split())
def func(x):
return a*(x**3) + b*(x**2) + c*x + d - e
start, end = Decimal(-10**10), Decimal(10**10) # 시작 범위
for _ in range(1000):
mid = (start+end) / Decimal(2.0)
if func(mid) < 0:
start = mid
else:
end = mid
print(mid)
입력 : 3 2 1 0 27 (3x^3+2x^2+x = 27)
출력 : 1.832252860066851278868305314
아래는 https://ko.numberempire.com/equationsolver.php 에서 연산을 실행한 결과인데, 해를 성공적으로 구해냈음을 알 수 있습니다.

import sys
import decimal
import math
pi = "3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461284756482337867831652712019091456485669234603486104543266482133936072602491412737245870066063155881748815209209628292540917153643678925903600113305305488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798609437027705392171762931767523846748184676694051320005681271452635608277857713427577896091736371787214684409012249534301465495853710507922796892589235420199561121290219608640344181598136297747713099605187072113499999983729780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083814206171776691473035982534904287554687311595628638823537875937519577818577805321712268066130019278766111959092164201989"
input = sys.stdin.readline
decimal.getcontext().prec = 50
a, b, c = map(decimal.Decimal, input().split())
fac = [decimal.Decimal(1), decimal.Decimal(1)]
for i in range(2, 75):
fac.append(fac[-1]*decimal.Decimal(i))
def sin(x):
result = x
idx = 0
for i in range(3, 75, 2): # 테일러 급수로 sin 함수 구현
term = pow(x, decimal.Decimal(i))/fac[i]
if abs(term) < decimal.Decimal('1e-50'):
break
if idx%2 == 0:
result -= term
else:
result += term
idx += 1
return result
def f(x):
return a * x + b * decimal.Decimal(sin(x % (decimal.Decimal(2)*decimal.Decimal(pi)))) # 사인파 특성상 2*pi 를 넘어가면 반복되기 때문에 나머지 연산
start, end = decimal.Decimal(-10**10), decimal.Decimal(10**10)
for i in range(1000):
mid = (start+end)/decimal.Decimal(2)
res = f(mid)
if res > c:
end = mid
else:
start = mid
res = mid.quantize(decimal.Decimal('0.000001'), decimal.ROUND_HALF_EVEN) # 반올림
print(res)
다음과 같은 방식으로 해를 찾는 알고리즘입니다. 주로 비선형 방정식을 푸는데 사용되고, 해가 여러개인 경우에는 초기값에서 가장 가까운 해를 찾아줍니다.
import sys
from decimal import Decimal
input = sys.stdin.readline
a, b, c, d, e = map(Decimal, input().split())
def func(x):
return a*(x**3) + b*(x**2) + c*x + d - e
def df(x): # 도함수
return 3*a*(x**2) + 2*b*x + c
x = Decimal(2)
for _ in range(1000):
x = x - func(x)/df(x)
print(x)
입력 : 3 2 1 0 27
출력 : 1.832252860066851278868305314
import sys
import decimal
import math
pi = "3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461284756482337867831652712019091456485669234603486104543266482133936072602491412737245870066063155881748815209209628292540917153643678925903600113305305488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798609437027705392171762931767523846748184676694051320005681271452635608277857713427577896091736371787214684409012249534301465495853710507922796892589235420199561121290219608640344181598136297747713099605187072113499999983729780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083814206171776691473035982534904287554687311595628638823537875937519577818577805321712268066130019278766111959092164201989"
input = sys.stdin.readline
decimal.getcontext().prec = 50
a, b, c = map(decimal.Decimal, input().split())
fac = [decimal.Decimal(1), decimal.Decimal(1)]
for i in range(2, 75):
fac.append(fac[-1]*decimal.Decimal(i))
def cos(x):
result = 1
idx = 0
for i in range(2, 75, 2):
term = pow(x, decimal.Decimal(i))/fac[i]
if abs(term) < decimal.Decimal('1e-50'):
break
if idx % 2 == 0:
result -= term
else:
result += term
idx += 1
return result
def sin(x):
result = x
idx = 0
for i in range(3, 75, 2):
term = pow(x, decimal.Decimal(i))/fac[i]
if abs(term) < decimal.Decimal('1e-50'):
break
if idx % 2 == 0:
result -= term
else:
result += term
idx += 1
return result
def f(x):
return a * x + b * decimal.Decimal(sin(x % (decimal.Decimal(2)*decimal.Decimal(pi)))) - c
def df(x):
return a + b * decimal.Decimal(cos(x % (decimal.Decimal(2)*decimal.Decimal(pi))))
x = decimal.Decimal(c/a)
for i in range(1000):
x = x - f(x)/df(x)
res = x.quantize(decimal.Decimal('0.000001'), decimal.ROUND_HALF_EVEN)
print(res)