Comment :
어떤 DataFrame이 있을때,
is_train
이라는 새로운 Column을 생성하여is_train
= 1이면 TrainSet,is_train
= 0이면 TestSet이 되게 만드는 Custom Function
train_ratio
입력 가능Strata
, 즉 Stratifiy할 Column 입력 가능
def train_test_split(dat, train_ratio = 0.7, strata=None, seed=741):
"""
param dat: train test split할 DataFrame
param train_ratio: Split ratio
param strata: 균등하게 분배할, 즉 계층적으로 분배할 기준 Column명, 복수개의 경우 ['column1', 'column2']
return: 입력 DataFrame에 'is_train' column을 추가
"""
# 입력 strata들이 DataFrame에 존재하는지 Check
if strata and not all(strata_ in dat.columns for strata_ in strata):
return
np.random.seed(seed)
# strata 없을 경우
if strata is None:
tr_idx = dat.sample(frac=train_ratio).index
else:
tmp = dat.copy()
tmp[strata] = tmp[strata].fillna('NaN')
# 아래의 'V1'을 DataFrame에 기본적으로 존재하는 다른 Column명으로 바꿔야함
sampling = tmp.groupby(strata).apply(lambda x: x.sample(n=int(np.floor(x['V1'].count() * train_ratio))))['V1']
tr_idx = [idx[-1] for idx in sampling.index]
# 기본으로 is_train = 0
dat['is_train'] = 0
# Strata된 인덱스 is_train = 1
dat.loc[dat.index.isin(tr_idx), 'is_train'] = 1
print("+ is_train")
if strata and not all(starta_ in dat.columns for starta_ in start):
strata
값이 존재하는데 list형태의 strata
의 요소들이 dat
의 Column에 모두 존재하지 않을 경우 if문 조건만족하여 함수 탈출 return
if starta is None:
tr_idx = dat.sample(frac=train_ratio).index
strata
가 None일 경우, 입력 DataFrame을 Sampling 후 인덱스 추출