Spark를 활용하여 대용량의 데이터를 다루다보면, 필요에 의해 데이터의 샘플링을 진행할 경우가 다수 발생한다. 아래 2가지 예시 포함 다양한 이유로 샘플링을 사용한다.
샘플링을 위해 Spark에서 많이 사용되는 API는 sample(), randomSplit() 정도가 있다. 이 문서에서는 sample()에 대해서 지원하는 기능을 살펴본다.
3가지 파라미터
를 입력받을 수 있으며, 모두 optional 하다- withReplacement : 복원추출 여부. 기본설정이 False
- fraction: 추출할 sample의 비율을 지정하며, 0.0 ~ 1.0 사이의 값을 입력받음. (가장 많이 쓰이는 파라미터)
- seed: 샘플링에 활용되는 seed값. 미력입시 랜덤으로 적용되어, 출력때마다 다른 값이 나온다.
fraction 비율에 딱 맞는 sample개수
sample() 을 통해 뽑힌 df_sample은 고정된 row 들을 보유하고 있다.
# df -> id=[0,1,2,3,4,5,6,7,8,9] 10개
# df_sample -> id=[1,3] 2개
# ex. row 10개 일때 frac=0.2 이면, 항상 2개가 나와야해
df_sample = df.sample(fraction=0.2)
# ex. df_sample -> id=[1,3]이니까, tb_test에도 id=[1,3]이 담겨있을거야
# tb_test -> id=[1,3] 2개 ?
df_sample.write.saveAsTable("tb_test")
하지만, 이 API만 믿고 사용하다가는, A/B 테스트등의 샘플링을 활용한 실험에서 실수를 범할수도 있다.
unique_key
값을 생성하였고, 그 수는 약 2.75억개이다q = """
select *
from tb_test
where log_type_code='...' and date_id='2022-09-01'
"""
df = spark.sql(q)
df = df.select("time_id","user_id","item_id").distinct()
df = df.withColumn("unique_key",f.concat_ws("_","time_id","user_id","item_id"))
print(f"{df.count():,}") # 275,791,779
fraction 비율에 딱 맞는 sample개수
’는 보장되지 않음을 확인할수있다val1 = df.sample(fraction=0.1).count()
val2 = df.sample(fraction=0.1).count()
print(f"{val1:,}") # 27,573,170
print(f"{val2:,}") # 27,569,067
df_sample
을 테이블 2개에 저장함row count
도 같고, 포함된 unique_key
가 같아야 할것이라 예상해본다.########### 1. sampling 후 테이블에 저장 ###########
df_sample = df.sample(0.1,seed=10)
print(f"{df_sample.count():,}") # 27,576,224
tb_name_v1 = "user_rupert.z_tmp_spark_random_seed_test_v1"
tb_name_v2 = "user_rupert.z_tmp_spark_random_seed_test_v2"
df_sample.write.mode("overwrite").saveAsTable(tb_name_v1)
df_sample.write.mode("overwrite").saveAsTable(tb_name_v2)
df_v1 = spark.sql(f"select * from {tb_name_v1}")
df_v2 = spark.sql(f"select * from {tb_name_v2}")
########### 2. 테이블 v1, v2 비교 ###########
val1 = df_v1.count() # 27,576,224
val2 = df_v2.count() # 27,576,224
########### 2-1. 테이블 v1, v2 교차량 비교 ###########
# 차집합 (v1 - v2, v1테이블에만 있는 개수)
# 24,792,477
n_only_in_v1 = df_v1.join(df_v2,on="unique_key",how="left_anti").count()
# v1 ⋂ v2 교집합 개수
# 2,783,747
n_intersection = df_v1.join(df_v2,on="unique_key",how="left_semi").count()
# 차집합 (v2 - v1, v2테이블에만 있는 개수)
# 24,792,477
n_only_in_v2 = df_v2.join(df_v1,on="unique_key",how="left_anti").count()
fraction 비율에 딱 맞는 sample개수
는 보장되지 않는다sample에 담기는 row가 바뀔수 있다
기대값에 따르면 50번 앞이 나올 확률
’이 가장 높지만, 50번보다 더 나올수도 덜 나올수도 있음.1-1의 동전던지기
를 생각해보자. 0~1 사이의 랜덤값이 나오는 상황에서 𝜇=0.1 로 설정하고 각각의 row가 독립적인 random 값을 부여받고, 확률구간(0.0 ~ 0.1)에 포함되면 sample에 포함! 이라고 판단함ex) 앞면이 나올 확률이 𝜇=0.1 인 동전을 10명이 각자 던져서, 앞면이 나온 사람이 뽑힌다
df.sample(0.1,seed=10).explain()
# output
== Physical Plan ==
*(2) Sample 0.0, 0.1, false, 10
+- *(2) HashAggregate(keys=[time_id#6, user_id#7L, item_id#26], functions=[])
+- Exchange hashpartitioning(time_id#6, user_id#7L, item_id#26, 200)
+- *(1) HashAggregate(keys=[time_id#6, user_id#7L, item_id#26], functions=[])
+- *(1) Project [time_id#6, user_id#7L, item_id#10 AS item_id#26]
+- Scan hive tb_test [log_info#10, time_id#6, user_id#7L], HiveTableRelation `tb_test`, org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe, [time_id#6, user_id#7L, ...]
sample 0.0, 0.5, false, 11
로 plan 되어있다고 생각할 수 있음(𝜇=0.5)
https://www.waitingforcode.com/apache-spark-sql/randomsplit-implementation-apache-spark-sql/read
########### 1. sampling 후 테이블에 저장 ###########
df_sample = df.sample(0.1,seed=10)
print(f"{df_sample.count():,}") # 27,576,224
df_sample.write.mode("overwrite").saveAsTable("user_rupert.z_tmp_spark_random_seed_test_v1")
df_sample.write.mode("overwrite").saveAsTable("user_rupert.z_tmp_spark_random_seed_test_v2")
df_v1 = spark.sql("select * from user_rupert.z_tmp_spark_random_seed_test_v1")
df_v2 = spark.sql("select * from user_rupert.z_tmp_spark_random_seed_test_v2")
import pyspark.sql.functions as f
df.sample(0.15,seed=10).sort(f.rand(seed=10)).limit(1_000_000)
########## 방법 1 ###########
df_sample = df.sample(0.1,seed=10).cache() # 한번 action이 되면, df_sample값이 고정됨
df_sample.count() # 단순히 action 수행목적
df_v1 = df_sample
df_v2 = df_sample
########## 방법 2 ###########
df_sample = df.sample(0.1,seed=10)
df_sample.write.mode("overwrite").saveAsTable("user_rupert.v1")
df_sample = spark.sql("select * from user_rupert.v1") # 파일로 쓰고, 이후엔 그걸 사용함
df_v1 = df_sample
df_v2 = df_sample
row_number()
사용할때 발생할 수 있음