이전 회사에 있었을 때, 몇 만 장의 이미지와 그에 대한 라벨링 데이터를 처리하면서 속도가 너무 낮아 애를 쓴 적이 있다. 이때부터 파이썬의 성능 개선 필요성을 느끼고 있었다.
그리고 앞으로 AI Engineer로 일하면서 python을 많이 사용하게 될텐데, 이에 Ray는 업무 효율을 높여주는 데에 도움이 많이 될 것 같다는 생각이 들었다.
파이썬에서는 Threading, Multiprocessing 그리고 Asyncio를 지원한다.
이러한 모듈에 비해 Ray는 직렬화 오버헤드 문제 없이, 간단한 코드를 추가하는 것만으로 병렬 처리를 가능하게 한다. (Ray는 윗 모듈과 완전히 다른 라이브러리가 아닌, 내부에서 threading, asyncio, multiprocessing 모두 사용하여 효율을 최적화시킨 라이브러리로 보인다.)
다음은 매개변수 n을 받아, n초 만큼 쉬고 n^2을 반환하는 함수다.
import time
def function(n):
time.sleep(n)
return n*n
print("start")
start = time.time()
results = []
for n in range(1, 6):
results.append(function(n))
print("결과값: ", results)
end = time.time()
print("실행 시간: ",end - start)
============
start
결과값: [1, 4, 9, 16, 25]
실행 시간: 15.124118328094482
>>>
이 함수에 다음과 같이 @ray.remote 데코레이터만 입히는 것으로 병렬 처리가 가능해진다. 함수에 인수를 넣을 때는 remote안에 넣어 작동하며, remote 함수를 호출하면 Object(Future 객체) Ref(공유 메모리 주소)를 반환한다. 그리고 ray.get을 통해 Object를 실행시킬 수 있다. 주의할 점은 ray.get에 넣은 Object Ref들이 가리키는 작업들이 모두 끝날 때까지 다음 코드를 실행하지 않는다. (참고: Antipattern: Processing results in submission order using ray.get)
import time
import ray
ray.init()
@ray.remote
def ray_function(n):
time.sleep(n)
return n*n
start = time.time()
print("ray start")
obj_refs = []
for n in range(1, 6):
obj_refs.append(ray_function.remote(n))
results = ray.get(obj_refs)
print("결과값: ", results)
print("실행 시간: ",time.time() - start)
============
ray start
결과값: [1, 4, 9, 16, 25]
실행 시간: 6.043393850326538
>>>
ray.wait를 사용하면 준비된 Object Ref와 그렇지 않은 Object Ref를 반환받을 수 있다. 이를 통해 get에 준비된 Object Ref를 원하는 만큼씩만 넣고 중간 결과값을 반환받을 수 있다.
import time
import ray
ray.init()
@ray.remote
def ray_function(n):
time.sleep(n)
return n*n
start = time.time()
print("ray start")
obj_refs = []
for n in range(1, 6):
obj_refs.append(ray_function.remote(n))
results = []
total_task_size = len(obj_refs)
task_time_check = time.time()
while obj_refs:
done, obj_refs = ray.wait(obj_refs)
result = ray.get(done[0])
print(f'[{total_task_size-len(ret)}/{total_task_size}] 결과값: {result}, 실행 시간: {time.time()-task_time_check}')
results.append(result)
task_time_check = time.time()
print("총 결과값: ", results)
print("총 실행 시간: ",time.time() - start)
============
ray start
[1/5] 결과값: 1, 실행 시간: 1.020899772644043
[2/5] 결과값: 4, 실행 시간: 1.0051283836364746
[3/5] 결과값: 9, 실행 시간: 0.9929628372192383
[4/5] 결과값: 16, 실행 시간: 1.0078113079071045
[5/5] 결과값: 25, 실행 시간: 2.0091397762298584
총 결과값: [1, 4, 9, 16, 25]
총 실행 시간: 6.051797389984131
>>>