Train / Test Split function

olxtar·2022년 11월 27일
0

Comment :

어떤 DataFrame이 있을때, is_train 이라는 새로운 Column을 생성하여 is_train = 1이면 TrainSet, is_train = 0이면 TestSet이 되게 만드는 Custom Function

  • train_ratio 입력 가능
  • Strata, 즉 Stratifiy할 Column 입력 가능


def train_test_split(dat, train_ratio = 0.7, strata=None, seed=741):
	"""
    param dat: train test split할 DataFrame
    param train_ratio: Split ratio
    param strata: 균등하게 분배할, 즉 계층적으로 분배할 기준 Column명, 복수개의 경우 ['column1', 'column2']
    return: 입력 DataFrame에 'is_train' column을 추가
    """
    
    # 입력 strata들이 DataFrame에 존재하는지 Check
    if strata and not all(strata_ in dat.columns for strata_ in strata):
    	return
        
    np.random.seed(seed)
    
    
    # strata 없을 경우
    if strata is None:
    	tr_idx = dat.sample(frac=train_ratio).index
        
    else:
    	tmp = dat.copy()
        tmp[strata] = tmp[strata].fillna('NaN')
        # 아래의 'V1'을 DataFrame에 기본적으로 존재하는 다른 Column명으로 바꿔야함
        sampling = tmp.groupby(strata).apply(lambda x: x.sample(n=int(np.floor(x['V1'].count() * train_ratio))))['V1']
        tr_idx = [idx[-1] for idx in sampling.index]
    
    # 기본으로 is_train = 0
    dat['is_train'] = 0
    # Strata된 인덱스 is_train = 1
    dat.loc[dat.index.isin(tr_idx), 'is_train'] = 1
    
    print("+ is_train")

if strata and not all(starta_ in dat.columns for starta_ in start): 

strata 값이 존재하는데 \rightarrow list형태의 strata의 요소들이 dat의 Column에 모두 존재하지 않을 경우 if문 조건만족하여 함수 탈출 \rightarrow return



if starta is None:
	tr_idx = dat.sample(frac=train_ratio).index

strata가 None일 경우, 입력 DataFrame을 Sampling 후 인덱스 추출

profile
예술과 기술

0개의 댓글