[Pytorch] 1. Config module 만들기

장해웅·2020년 12월 17일
2

Framework 만들기

목록 보기
2/2

Introduction


어떻게 framework의 입력을 다룰 것인가

Configuration은 해외 논문 코드의 경우 argparser를 이용하는 것이 일반적이다. 그러나, 연구 환경에서는 트레이닝 된 모델의 argument를 알고 있어야만 하기 때문에, argparser를 이용하는 방법은 관리하기가 상당히 번거롭다. 또한, ablation study 등, parameter를 조금씩 조정하면서 코드를 실행해야 하는 경우(물론 cell script로 가능하지만)에 다루기가 힘드며, 궁극적으로 관리해야될 파라미터가 너무 많다.

YAML


넘 편한 YAML

나는 일반적으로 모든 parameter를 yaml 파일을 통하여 관리하는 편이다. yaml의 장점은 어떤 데이터 구조를 사람이 읽기 쉬운 구조로 저장할 수 있다는 점에 있다. 이러한 구조는 모델 parameter 저장을 쉽게하고, 내가 원하는 파라미터를 조정하기 쉽게 만든다. 또한, argument를 저장하기 위해서는 단순히 파일을 복사, 붙여넣기 하면된다!

Class: Config


Config class를 만듬에 있어서 내가 고려한 요소는 다음과 같다.

  • Dict like class
  • "덮어 씌우기 연산"이 가능해야 된다

dict like class. 공식 명칭은 모르것으나 dictionary를 마치 class처럼 사용하는 방법이다. 가장 직관적인 이해는 다음을 보자.

if __name__ == '__main__':
    dictionary = {'A':5}
    config = ConfigClass()...
    print(config.A)
    print(dictionary.A)

본래, dictionary는 class의 멤버변수처럼 접근할 수 없다. 그러나 dict like class 트릭을 이용하여 dictionary를 class 처럼 이용할 수 있다!

두번째로, 덮어씌우기 연산이라는 것은 여러개의 configuration을 담고 있는 yaml이 있을 때, 이것을 합치는 연산을 말한다. 예시는 다음과 같다.

# A.yaml
DATA_LOADER:
    model_name: 'name'
    data_path: 'previous'
    data_input: ['a','b']

# A.yaml
DATA_LOADER:
    data_path: 'others'
    data_input: ['a','b','c']
    data_new_arg: 'abc'
    
# new config
DATA_LOADER:
    model_name: 'name'
    data_path: 'others'
    data_input: ['a','b','c']
    data_new_arg: 'abc'

이러한 기능을 넣고 싶은 이유는, 이전에 말한 극도의 추상화를 위해서이다. 예로들어, argument 중 하나의 값이 바뀐다고, 모든 파일을 바꾸는 것을 원하지 않는다. 일반적으로, 전체적인 parameter는 고정된 채로, 몇몇 적은 수의 파라미터만이 수정될 것이다. 따라서, 덮어 씌우기 연산을 지원하므로써, 수정에는 닫혀있고, 확장에는 열려있는 코드가 가능하다고 생각한다.

Config.py


import os
import yaml
class Config(object):
    def __init__(self, dict_config=None):
        super().__init__()
        self.set_attribute(dict_config)

    @staticmethod
    def from_yaml(path):
        with open(path, 'r') as stream:
            return Config(yaml.load(stream, Loader=yaml.FullLoader))

    @staticmethod
    def from_dict(dict):
        return Config(dict)
        
    @staticmethod
    def get_empty():
        return Config()

    def __getattr__(self, item):
        return self.__dict__[item]

    def __setattr__(self, key, value):
        self.set_attribute({key:value})

    def set_attribute(self, dict_config):
        if dict_config is None:
            return

        for key in dict_config.keys():
            if isinstance(dict_config[key], dict):
                self.__dict__[key] = Config(dict_config[key])
            else:
                self.__dict__[key] = dict_config[key]

    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __delitem__(self, key):
        del self.__dict__[key]

    def __contains__(self, key):
        return key in self.__dict__

    def __len__(self):
        return len(self.__dict__)

    def __repr__(self):
        return repr(self.__dict__)

    def update(self, dict_config):
        for key in dict_config.keys():
            if key in self.__dict__.keys():
                if isinstance(dict_config[key], Config):
                    self.__dict__[key].update(dict_config[key])
                else:
                    self.__dict__[key] = dict_config[key]
            else:
                self.__setitem__(key, dict_config[key])

    @classmethod
    def extraction_dictionary(cls, config):
        out = {}
        for key in config.keys():
            if isinstance(config[key], Config):
                out[key] = cls.extraction_dictionary(config[key])
            else:
                out[key] = config[key]
        return out

작성된 코드는 다음과 같다.
하나하나 살펴보자.

    @staticmethod
    def from_yaml(path):
        with open(path, 'r') as stream:
            return Config(yaml.load(stream, Loader=yaml.FullLoader))

    @staticmethod
    def from_dict(dict):
        return Config(dict)

    @staticmethod
    def get_empty():
        return Config()

우선, 다양한 객체 생성 함수를 만들었다. 새로운 config 생성 방식이 필요하면, 그냥 함수를 늘리면 된다. 생성 모듈을 따로 클래스로 뺄까도 생각해보았으나, 사용할 때 너무 헷갈려서 수정하였다..

코드에 보면 __dict__ 변수를 많이 사용하는 것을 볼 수 있다. __dict__는 python 객체의 모든 멤버변수의 이름과 값을 담고 있다. 찍어보자.

class test(object):
    def __init__(self):
        self.a = 1
        self.b = 2

if __name__ == '__main__':
    print(test().__dict__)

out: {'a': 1, 'b': 2}

이 점을 활용하여, python의 magic method들을 오버라이딩 해주므로써 dictionary를 class처럼 사용할 수 있다! 단 중첩되는 경우(dict.A.B... 이런식으로) 약간의 처리가 필요하다. 나는 다음과 같은 재귀적 방식으로 구현했다.

    def set_attribute(self, dict_config):
        if dict_config is None:
            return

        for key in dict_config.keys():
            if isinstance(dict_config[key], dict):
                self.__dict__[key] = Config(dict_config[key])
            else:
                self.__dict__[key] = dict_config[key]

attribute를 set할 때, 만약 그 attribute가 아직도 dictionary면, 새로운 Config 객체를 생성해 리턴하도록 조정한다.

다음은 업데이트 부분이다.

    def update(self, dict_config):
        for key in dict_config.keys():
            if key in self.__dict__.keys():
                if isinstance(dict_config[key], Config):
                    self.__dict__[key].update(dict_config[key])
                else:
                    self.__dict__[key] = dict_config[key]
            else:
                self.__setitem__(key, dict_config[key])

update는 두 config를 합치는 연산이다. 만약 내 dictionray에 존재하지 않는 key가 들어온다면, 그 키를 업데이트 한다. 물론, key 내부의 값이 config 객체라면, 다시 한번 재귀적으로 초기화해준다. 만약 key가 존재하지 않는다면, 그대로 추가해준다!

Conclusion


이제 parameter를 저장할 수 있는 Config class가 완성되었다. 다음은 해당 딥 러닝 프레임워크의 전체 컨트롤을 담당하는 App class에 대해서 포스팅할 예정이다.

profile
So I leap!

1개의 댓글

comment-user-thumbnail
2021년 8월 16일

도움이 많이 됐어요! 다음 글도 이어서 해주시면 좋겠네요:)

답글 달기