회사에서 대용량의 데이터를 정제해 하나의 파일(gzip 형식)로 저장해야할 일이 생겼다.
데이터를 정제해서 새로운 테이블로 만드는 건 익숙한 일이라 고민없이 작업을 하고 있었는데, 왠걸 이렇게 대용량의 데이터를 처리해본 적이 없어서 기존에 무지성으로 작성한 코드가 어마무시하게 리소스를 잡아먹어버리는 문제가 발생했다.
예시로 발생한 문제는 아래처럼 되어있다.
(예시) 데이터프레임의 칼럼 group_id마다 "rank=1"이라는 조건을 가지는 row에 대해 해당 group_id의 price를 모두 더한 값을 새로운 칼럼 sum_price으로 group_id에 맞게 추가해 데이터프레임을 바꾸고 싶다.
평소같으면 아래와 같이 코드를 무지성으로 짰을 것이다.
from pyspark.sql import functions as f
df2 = df1.filter(
f.col('rank')==1
).groupBy(
f.col('group_id')
).agg(
f.sum(f.col('price')).alias('sum_price')
).select(
f.col('group_id'),
f.col('sum_price')
)
res_df = df1.join(
df2, on = ['group_id'], how = 'inner'
)
데이터프레임이 작으면 전혀 문제없이 넘어갈 수 있는 코드이다. 그런데 만약 group_id가 무지막지하게 다양하다면? join연산하는데 엄청나게 많은 리소스가 들 것이다.(왜 그런거지?) 이러한 문제를 Window를 통해서 해소할 수 있다는 것을 알게 되었다.
그래서 구글링을 해보니 Window라고 하는 메소드가 있다는 것을 알게 되었다. Window는 "그룹된 행을 대상으로" 정해준 작업을 수행해 결과를 리턴해주는 객체이다. Window를 사용할 때 어떠한 파티션으로 작업을 수행할지 정의할 수 있다. 예를 들면, 위의 문제에 대해서 우리는 group_id가 같은 row들끼리만 비교하면 되니까 group_id를 파티션 기준으로 설정할 것이다. 그러면 Window는 group_id가 같은 row들을 조그만 파티션으로 만들어서 그 안에서 연산을 수행할 것이다. 찾아보니 정렬을 해서 row 넘버를 붙이고 싶다면 orderBy()를 사용할 수도 있다.
Window.partitionBy("group_id")
Window는 over()라는 메소드와 함께 사용하는데, over는 인자로 Window를 받고 수행하려는 작업뒤에 붙어서 사용하게 된다. 즉, 내가 만약 col1에 대해서 파티션을 나누고 (그룹화를 하고)
해당 파티션 내에서 col2에 대한 max값만을 추출하고 싶다면 아래와 같이 사용하면 된다.
partition_col1 = Window.partitionBy(f.col('col1'))
# new_col에 원하는 값을 넣는다.
df.withColumn(
'new_col', f.max(f.col('col2')).over(partition_col1)
)
그리하야 Window를 사용해서 리소스를 좀 덜 먹는 코드는 아래와 같다. 우리는 rank=1인 row들의 price의 합만 구할 것이니까 f.expr()를 통해서 조건을 걸어야한다. 실제로 차이는 좀 많이 나던데 아무래도 데이터가 무지막지하게 크다보니 오래 걸리긴 하더라..
from pyspark.sql import functions as f
from pyspark.sql.window import Window
partition_group_id = Window.partitionBy("group_id")
res_df = df.withColumn(
'sum_price', f.sum(
f.expr("case when rank=1 then price else 0 end")
).over(partition_group_id)
)
여기서 끝내면 그냥 Window를 쓰고 끝나는 포스팅이다. 그런데 내부 동작 방식에 대한 이해를 하려고 이 블로그를 시작한 것이기 때문에 좀 깊게 파보자.
stackoverflow에서는 explain를 사용해서 자세하게 설명을 해준다.
우선 첫 번째로 Window는 [스캔 -> 리파티셔닝 -> 소팅 -> 윈도우]의 physical plan으로 동작을 하게 된다.
다음으로 join을 하게 되면 분기점이 생기는데 각 분기점에서 [스캔 -> 리파티셔닝 -> 소팅]을 진행하게 되고 이 두 RDD를 SortMergeJoin을 한다. 여기서 cardinality에 따라서 차이가 크게 날 것으로 예상할 수 있다. cardinality는 집합의 크기라고 하는데, 조인할 키값으로 이해하면 될 것이고 이 예시에서는 유니크한 group_id의 크기로 보면 될 것이다. 맨 앞에서 말한 것처럼 cardinality가 작은 작업에서는 브로드캐스트조인이 무시할 수준이어서 윈도우 연산보다는 적게 걸리지만 이번처럼 ad_group_id가 무지막지하게 다양한 경우라면? 브로드캐스트조인이 소트머지조인으로 바뀌면서 엄청나게 많은 리소스를 요구할 것이다.(메모리에 올리기에 용량이 너무 크기 때문이다.) 그래서 걸리는 시간에 차이가 많이 났던 것이다.