FastAPI, Query Parameter Filtering

Junha Kim·2021년 4월 23일
0

FastAPI

목록 보기
6/16
post-custom-banner

저번 Query Parameter 글에서 한 분이 댓글에 질문을 남겨주셨다.

필터링 처리를 쿼리파라미터로 7개 이상으로 받게될 경우 어떻게 해야하나요?
파이썬 패킹 처리를 할 것 같은데... FastAPI에서 어떻게 하는지 궁금하네요

이 부분은 나도 궁금해서 한번 실험을 해보았다.

from fastapi import Query
my_query_parmas = {"a": Query(...), "b": Query(None)}

@app.get("/")
async def p(*args, **kwargs):
    pass

@app.get("/test")
async def p(x: dict = my_query_params):
	print(x)

path url에 따로 path 파라미터가 없고, get 요청이니 따로 명시하지 않아도 Query param으로 인식이 될 것이다.

하지만 실패!

http://0.0.0.0:8000/?x=2
{
   "detail":[
      {
         "loc":[
            "query",
            "args"
         ],
         "msg":"field required",
         "type":"value_error.missing"
      },
      {
         "loc":[
            "query",
            "kwargs"
         ],
         "msg":"field required",
         "type":"value_error.missing"
      }
   ]
}

http://0.0.0.0:8000/test?x=2
{'a': Query(Ellipsis), 'b': Query(None)}

args, kwargs 자체를 매개변수 이름으로 인식하고 해당 값이 없다고 에러를 일으킨다.

이렇게 보면, 쿼리 파라미터는 어쩔 수 없이 매개변수로 직접 이름을 명시해줘야하는 듯 하다.

하지만 쿼리 파라미터가 엄청 많아질 경우엔? 너무 불편하다.

그래서 나는 따로 filter 하는 함수, 클래스를 만들어 사용했다.

인턴하는 곳에서 Django → FastAPI 전환 작업을 했어서(현재는 pending이지만...) Django기반 url 쿼리 스트링, 예를 들어 id__in과 같은 형식에 맞춰서 코드를 짰다.

  • filter_functions.py
    import sqlalchemy as sa
    from sqlalchemy import or_
    from sqlalchemy import func as F

    def join_models(query, list_of_models):
        for _model in list_of_models:
            query = query.join(_model)
        return query

    def filter_get(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column == value)
        return query

    def filter_gte(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column >= value)
        return query

    def filter_gt(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column > value)
        return query

    def filter__lte(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column <= value)
        return query

    def filter__lt(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column < value)
        return query

    def filter__startswith(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column.like(f'{value}%'))
        return query

    def filter__icontains(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column.ilike(f'%{value}%'))
        return query

    def filter__contains(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column.like(f'%{value}%'))
        return query

    def filter__in(query, model, column_name, value_list):
        model_column = getattr(model, column_name)
        query = query.filter(model_column.in_(value_list))
        return query

    def filter__list_in(query, model, column_name, value_list):
        model_column = getattr(model, column_name)
        query = query.filter(model_column.op('&&')(value_list))
        return query

    def filter__list_icontains(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(F.array_to_string(model_column, ',').ilike(f'%{value}%'))
        return query

    def filter__exclude(query, model, column_name, value):
        model_column = getattr(model, column_name)
        query = query.filter(model_column != value)
        return query

    def filter_array_any(query, model, column_name, value):
        column_name_map = {
            'tag_id__in': 'category_tag_ids',
            'city_id': 'city_ids',
            'country_id': 'country_ids'
        }
        if type(value) == str:
            values = [int(_v) for _v in value.split(",")]
        elif type(value) == int:
            values = [value]
        else:
            values = value
        column_name = column_name_map.get(column_name, column_name)
        model_column = getattr(model, column_name)
        _filters = []
        for i in values:
            _filters.append(model_column.any(i))
        query = query.filter(or_(*_filters))
        return query

    def filter_date_range(query, model, column_name, start, end):
        model_column = getattr(model, column_name)
        query = query.filter(F.date(model_column).between(start, end))
        return query

    # TODO: should consdier like (city__code=123) and the integer/string type 
    async def query_parser(request_param_func, value=None):
        filter_function_mapping = {
            "gt": filter_gt,
            "gte": filter_gte,
            "lt": filter__lt,
            "lte": filter__lte,
            "startswith": filter__startswith,
            "icontains": filter__icontains,
            "contains": filter__contains,
            "in": filter__in,
            "&&": filter__list_in,
            "icontains_list": filter__list_icontains,
            "range": filter_date_range
        }
        field_func = request_param_func.split("__")

        if len(field_func) == 1:
    				# if query string is like "id=1,3,4" -> treat as `in` filtering
            if isinstance(value, list) or value and len(value.split(',')) > 1:
                return filter__in, field_func[0]
            return filter_get, field_func[0]

        field, func = field_func
        return filter_function_mapping[func], field

filter functions는 쿼리셋을 필터링하는 함수들이 모여있으며, 마지막 query_parser는 쿼리 스트링을 보고 해당 쿼리에 맞는 함수와 필드를 리턴해준다.

  • filter_class.py
    from app.utils import filter_functions as filter
    from sqlaclhemy.orm import Query
    import copy

    class FilterBase:
        fields = {}

        async def exec_filter(self, queryset: Query, model, request_params: dict):
            for key, value in request_params.items():
                filter_func, key = await filter.query_parser(key)
                queryset = filter_func(queryset, model, key, value)
            return queryset

        async def check_allowed_filter_func(self, request_params: dict):
            pram_func = request_params.keys()
            temp_param_func = copy.deepcopy(list(pram_func))

            for condition in temp_param_func:
                if "__" in condition:
                    pram, func = condition.split('__')
                    # param is not in allowed fields or doesn't allow the method
                    if pram not in self.fields.keys() or func not in self.fields[pram]:
                        request_params.pop(condition)

                elif condition in self.fields.keys():
                    # if func is exact query but if the func is not allowed
                    if 'exact' not in self.fields[condition]:
                        request_params.pop(condition)
                # func is exact query but param does not exist on allowed fields
                else:
                    request_params.pop(condition)
            return request_params

    class TemplateFilter(FilterBase):
        fields = {
            'id': ['in'],
            'create_dt': ['date'],
            'business_name': ['exact'],
            'channel': ['exact'],
            'template_type': ['exact', 'in']
        }

    TemplateListFilter = TemplateFilter()

filter_class에는 FilterBase라는 추상 클래스가 존재하여, 기본적인 필터링 메소드를 제공한다.

fields{"필드": ['메소드',..] 형식으로 허용할 필드와 메소드를 작성한다.

check_allowed_filter_func는 쿼리 파라미터와 현재 허용되는 필드와 메소드가 아닌 것은 제거를 한다.

exec_filter 는 그렇게 걸러진 쿼리 파라미터의 필터링을 수행하는 최종 함수이다.

그리고 밑의 TemplateFilter는 직접 활용하는 클래스 예시이다.

FilterBase를 상속받아, 본인이 허용할 필드, 메소드를 fields에 오버라이딩을 하고 export하기 위해서 객체를 생성하여 변수에 담아주면 된다.

만약 필터링에 추가적인 작업이 필요하다면 상속받은 함수를 적절히 오버라이딩을 하면 된다.

  • 활용 예시
    async def get_template_list(db: Session, request: Request, page: int, page_size: int):
        queryset = db.query(models.Template)

        query_params = dict(request.query_params)
        valid_query_params = await utils.TemplateListFilter.check_allowed_filter_func(query_params)
        queryset = await utils.TemplateListFilter.exec_filter(queryset, models.Template, valid_query_params)
        ```

request 객체에서 쿼리 파라미터를 받고 Filter 객체를 불러와서 실행만 하면 끝난다.
post-custom-banner

0개의 댓글