flask project refactoring

유동헌·2022년 2월 6일
0

리팩토링의 목적은 다음과 같습니다.

  1. flask로 API를 구축하는 방법에 대해 복습
  2. 팀 프로젝트였었던 코드를 다시 뜯어보고 직접 참여하지 않은 기능에 대해서 공부
  3. 기존 복잡했었던 모델링을 수정하여 간단하게 바꿔보기
  4. flask-restful → flask-restx 프레임워크 바꾸기 (Swagger 문서화 작업의 편의성)

app = Flask(name) → 팩토리 함수로 바꾸기

기존 코드

from flask         import Flask
from flask_restful import Api
from flask_migrate import Migrate

from company.models     import db
from company.controller import CompanyList, GetDetailCompany, CompanyCreateView
from config             import DB_URL

app = Flask(__name__)
api = Api(app)

app.config["SQLALCHEMY_DATABASE_URI"]        = DB_URL
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SQLALCHEMY_COMMIT_ON_TEARDOWN"]  = True

db.init_app(app)
db.app = app
db.create_all()
migrate = Migrate(app, db)

api.add_resource(CompanyList, "/search")
api.add_resource(GetDetailCompany, "/companies/<comp_name>")
api.add_resource(CompanyCreateView, '/companies')

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000, debug=False)

리팩토링

from flask import Flask
from flask_restx import Api
from flask_migrate import Migrate

from company import config
from company.controller import Company_Info
from company.models import db

migrate = Migrate()

def create_app(test_db_url=None):
    app = Flask(__name__)

    if test_db_url:
        app.config["SQLALCHEMY_DATABASE_URI"] = test_db_url
        app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    else:
        app.config.from_object(config)
    
    db.init_app(app)
    db.app = app
    db.create_all()
    migrate.init_app(app, db)

    api = Api(app, version='1.0', title='API 문서', description='Swagger 문서', doc="/api-docs")
    api.add_namespace(Company_Info, "/companies")

    return app, db

if __name__ == '__main__':
    app = create_app()[0]
    app.run(host="0.0.0.0", port=6000, debug=True)
  1. flask-restxFlask(__name__) 사용법은 전역에 선언하는 것으로 검색 결과가 많이 나오기도 했고, flask-restx가 꾸준하게 업데이트가 되는 라이브러리로 알고 진행을 하였습니다. 하지만 pytest 진행 과정 중에서 test_client를 생성해 요청을 보내야 하는데 그 부분에서 애를 먹다가 팩토리 함수라는 개념을 알게 되어 전역 환경에 선언했던 app = Flask(__name__)을 함수로 만들어 반환해주는 create_app 함수를 만들어 처리하였습니다.

  2. 테스트 코드를 위한 부분

     def create_app(test_db_url=None):
        app = Flask(__name__)
    
        if test_db_url:
            app.config["SQLALCHEMY_DATABASE_URI"] = test_db_url
            app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
        else:
            app.config.from_object(config)

    create_app() 함수가 실행될 때 만약에 test_db_url이라는 매개변수가 들어간다면, DATABASE_URI는 테스트 데이터베이스가 됩니다. 이를 위해 따로 테스트 데이터베이스를 만들어주어야 합니다.

    만약 그렇지 않다면 from_object(config)라는 문장이 실행되어 config 파일에 있는 정보를 읽어옵니다.

    from_object 옵션에 대해 flask-docs을 확인해 보았습니다.

    대부분의 어플리케이션은 하나 이상의 설정(구성)이 필요 하다. 적어도 운영 환경과 개발환경은 독립된 설정값을 가지고 있어야만 한다. 이것을 다루는 가장 쉬운 방법은 버전 관리를 통하여 항상 로드되는 기본 설정값을 사용하는 것이다. 그리고 독립된 설정을값들을 필요에 따라 위에서 언급했던 방식으로 덮어쓰기한다:

    app = Flask(__name__)
    app.config.from_object('yourapplication.default_settings')
    app.config.from_envvar('YOURAPPLICATION_SETTINGS')

    이렇게 설명이 되어 있었는데요, 그냥 object로 불러오기만 해도 설정값이 잘 적용되는 것 같았습니다!

    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False 옵션의 경우 추가적인 메모리를 필요로 해 꺼두는 것을 추천한다고 합니다.

  3. 위의 flask 객체에 대한 설정값을 조금 더 자세하게 알아보겠습니다. flask-docs을 참고하였습니다.

    [config](https://flask-docs-kr.readthedocs.io/ko/latest/ko/api.html#flask.Flask.config)은 실제로는 dictionary 의 서브클래스이며, 다른 dictionary 처럼 다음과 같이 수정될 수 있다:

    app = Flask(__name__)
    app.config['DEBUG'] = True

    확정된 설정값들은 또한 [Flask](https://flask-docs-kr.readthedocs.io/ko/latest/ko/api.html#flask.Flask)객체로 전달될 수 있으며, 그 객체를 통해 설정값들을 읽거나 쓸수 있다. 한번에 다수의 키(key)들을 업데이트 하기 위해서는 [dict.update()](http://docs.python.org/dev/library/stdtypes.html#dict.update) 함수를 사용 할 수 있다.

    app.debug = True
    
    app.config.update(
        DEBUG=True,
        SECRET_KEY='...'
    )

    내장된 고유 설정값들은 직접 확인해 보시면 될 것 같습니다.

  4. 이번 부분은 테스트와 연관이 되어 있는 코드인데요,

    import pytest
    import json
    
    from app import create_app
    from my_settings import TEST_DB_URL
    from company.models import Company, CompanyName, Tag, companies_tags
    
    db = create_app(TEST_DB_URL)[1]
    
    @pytest.fixture
    def client():
        app = create_app(TEST_DB_URL)[0]
        return app.test_client()
    
    def setup_function():
    ....

    테스트 코드는 이렇게 작성이 되었습니다. create_app 함수에서 app, db가 반환이 되는 이유도 db = create_app(TEST_DB_URL)[1]이 코드에서 사용을 해야하기 때문입니다. 다른 방법은 찾지 못했습니다. db 자체를 인식하지 못하더라고요.

    create_app 이라는 함수에서 test_db_url을 매개로 만들어진 db 객체를 받아와야 되는 것 같았습니다. test_client() 또한 이러한 과정에서 만들어진 말 그대로 테스트용 클라이언트로 전송이 되기 때문에 그쪽에 연결된 db와 통신하여 테스트를 진행할 수 있었습니다.

기능 구현 부분

import string

from flask_restx import Resource, Namespace
from flask import request

from company.models import *

Company_Info = Namespace("About Company", description="회사 정보 info")

@Company_Info.route('/search')
class CompanySearchView(Resource):
    def get(self):

        code = request.headers.get("x-wanted-language", "ko")
        company_name = request.args.get("query")
    
        if company_name == "":
            return {"message" : "KEY ERROR"}, 404
        
        company_name_datas = db.session.query(CompanyName).filter(CompanyName.name.like(f"%{company_name}%"), CompanyName.lang==code).all()
        
        if not company_name_datas:
            return {"message" : "COMPANY NOT FOUND"}, 404
        
        result = [
            {"company_name" : company_name.name}
            for company_name in company_name_datas
        ]

        return result, 200

@Company_Info.route('')
class CompanyInfoView(Resource):
    def post(self):
        code = request.headers.get("x-wanted-language", "ko")
        data = request.get_json()

        companies_dict = list(data.get("company_name").items())

        company = db.session.query(CompanyName).filter_by(lang=companies_dict[0][0], name=companies_dict[0][1]).first()

        if company:
                return {"message": "ALREADY EXIST COMPANY"}, 404
        
        company = Company()
        db.session.add(company)
        db.session.commit()

        for lang, name in companies_dict:
            company_name = CompanyName(name=name, lang=lang, company_id=company.id)
            db.session.add(company_name)

        tag_list = []
        for tag in data.get("tags"):
            for lang, name in tag.get("tag_name").items():
                tag_obj = db.session.query(Tag).filter_by(lang=lang, name=name).first()

                if not tag_obj:
                    tag_obj = Tag(name=name, lang=lang)
                    db.session.add(tag_obj)
                    db.session.commit()
                
                if lang == code:
                    tag_list.append(name)

                companies_tags_query = companies_tags.insert().values(
                    company_id = company.id,
                    tag_id = tag_obj.id
                )
                db.session.execute(companies_tags_query)

        db.session.commit()

        return {"company_name": data.get("company_name").get(code), "tags": tag_list}, 201

@Company_Info.route('/<company_name>')
class CompanyDetailView(Resource):
    def get(self, company_name):
        code = request.headers.get("x-wanted-language")

        company = CompanyName.query.filter_by(name=string.capwords(company_name), lang=code).first()

        if company is None:
            return {"message": "NOT FOUND COMPANY"}, 404

        tag_list = []
        for com_tag in db.session.query(companies_tags).filter_by(company_id=company.company_id):
            for tag in Tag.query.filter_by(id=com_tag.tag_id):
                if tag.lang == code:
                    tag_list.append(tag.name)

        return {"company_name": string.capwords(company_name), "tags": tag_list}, 200
  1. 중복된 데이터가 뽑히더라도 기존 다대다 모델은 너무 복잡했기 때문에 데이터베이스 상의 중복 문제는 뒤로 하고 쉽게 이해할 수 있는 모델링을 진행했습니다.
  2. flask에서는 blueprint라는 url 관리 라이브러리가 존재합니다. flask-restx에서는 namespace가 같은 역할을 해 이를 이용해 보는 것으로 리팩토링을 하였습니다.

기초적인 SQLAlchemy 사용법 익히기

SELECT

>>> db.session.query(CompanyName).all()
[<CompanyName 4>, <CompanyName 5>]

SELECT company_names.id AS company_names_id
FROM company_name
  • count 함수를 사용.
>>> db.session.query(func.count(CompanyName.id)).all()
[(18,)]

SELECT count(company_names.id) AS count_1
FROM company_names
  • 가져오기
>>> tag = db.session.query(Tag).filter(Tag.name=='간식').first()

WHEHE

>>> db.session.query(CompanyName).filter(CompanyName.name=='원티드랩').all()
[<CompanyName 4>]

SELECT company_names.id AS company_names_id, company_names.name AS company_names_name, company_names.lang AS company_names_lang, company_names.company_id AS company_names_company_id
FROM company_names
WHERE company_names.name = %(name_1)s

>>> db.session.query(CompanyName).filter(or_(CompanyName.name=='원티드랩', CompanyName.lang=='tw')).all()
[<CompanyName 4>, <CompanyName 5>, <CompanyName 8>, <CompanyName 11>, <CompanyName 14>, <CompanyName 17>, <CompanyName 20>]

SELECT company_names.id AS company_names_id, company_names.name AS company_names_name, company_names.lang AS company_names_lang, company_names.company_id AS company_names_company_id
FROM company_names
WHERE company_names.name = %(name_1)s OR company_names.lang = %(lang_1)s

UPDATE

>>> tag = db.session.query(Tag).filter(Tag.name == '간식').update({'Tag.name' : '수정'});

INSERT

>>> tag = Tag(name='간식', lang='ko')
>>> db.session.add(tag)
>>> db.session.commit()

ORDER_BY

>>> db.session.query(Tag).filter(Tag.name=='간식').order_by(Tag.created_at)

INNER_JOIN

>>> db.session.query(CompanyName, Tag).filter(CompanyName.id, Tag.id).all()
<console>:1: SAWarning: SELECT statement has a cartesian product between FROM element(s) "company_names" and FROM element "tags".  Apply join condition(s) between each element to resolve.

OUTER_JOIN

>>> db.session.query(CompanyName).outerjoin(Tag, CompanyName.id == Tag.id).all()
[<CompanyName 4>, <CompanyName 5>, <CompanyName 6>, <CompanyName 7>, <CompanyName 8>, <CompanyName 9>, <CompanyName 10>, <CompanyName 11>, <CompanyName 12>, <CompanyName 13>, <CompanyName 14>, <CompanyName 15>, <CompanyName 16>, <CompanyName 17>, <CompanyName 18>, <CompanyName 19>, <CompanyName 20>, <CompanyName 21>]

GROUP_BY

>>> db.session.query(CompanyName).group_by(CompanyName.id).all()
[<CompanyName 4>, <CompanyName 5>, <CompanyName 6>, <CompanyName 7>, <CompanyName 8>, <CompanyName 9>, <CompanyName 10>, <CompanyName 11>, <CompanyName 12>, <CompanyName 13>, <CompanyName 14>, <CompanyName 15>, <CompanyName 16>, <CompanyName 17>, <CompanyName 18>, <CompanyName 19>, <CompanyName 20>, <CompanyName 21>]

pytest 적용하기

이번 리팩토링에서 가장 어려웠었던 부분입니다. 함께 해주신 팀원분의 많은 도움을 받아 완성할 수 있었습니다!

import pytest
import json

from app import create_app
from my_settings import TEST_DB_URL
from company.models import Company, CompanyName, Tag, companies_tags

db = create_app(TEST_DB_URL)[1]

@pytest.fixture
def client():
    app = create_app(TEST_DB_URL)[0]
    return app.test_client()

def setup_function():
    datas = [
        {"name": "라인", "lang": "ko"}, 
        {"name": "라인 프레쉬", "lang": "ko"}
    ]
    tags = ["태그_1", "태그_8", "태그_15"]
    
    for data in datas:
        company = Company()
        db.session.add(company)
        db.session.commit()

        company_name = CompanyName(name=data["name"], lang=data["lang"], company_id=company.id)
        db.session.add(company_name)

        for tag in tags:
            tag_obj = Tag.query.filter_by(name=tag, lang=data["lang"]).first()
            if not tag_obj:
                tag_obj = Tag(name=tag, lang=data["lang"])
                db.session.add(tag_obj)
                db.session.commit()

            companies_tags_query = companies_tags.insert().values(
                    company_id = company.id,
                    tag_id = tag_obj.id
                )
            db.session.execute(companies_tags_query)

        db.session.commit()

def teardown_function():
    [db.session.delete(company) for company in CompanyName.query.all()]
    db.session.commit()
    [company_tag.tags.clear() for company_tag in Company.query.all()]
    db.session.commit()
    [db.session.delete(tag) for tag in Tag.query.all()]
    db.session.commit()
    [db.session.delete(company) for company in Company.query.all()]
    db.session.commit()

def test_company_name_autocomplete(client):

    resp = client.get('/companies/search', query_string=dict(query="라인"), headers=[("x-wanted-language", "ko")])
    searched_companies = json.loads(resp.data.decode("utf-8"))

    assert resp.status_code == 200
    assert searched_companies == [
        {"company_name": "라인"},
        {"company_name": "라인 프레쉬"},
    ]
  1. query string 처리에 대하여

    How can I fake request.POST and GET params for unit testing in Flask?

    Testing WSGI Applications

    위 두개 글을 보고 참고하여 해결하였습니다. 특별하게 query string 처리는 ?기호로 처리하는 것이 아니라 query_string=dict(query="라인") 이런 식으로 처리를 해야 하더라고요.

새롭게 알게된 점

새롭게 알게 된 점을 두서없이 정리해 봅니다.

SQL LIKE 사용하기

mysql> SELECT * from company_names WHERE name LIKE "원%";
+----+------------------------+------+------------+
| id | name                   | lang | company_id |
+----+------------------------+------+------------+
|  1 | 원티드 주식회사        | ko   |          1 |
|  4 | 원티드                 | ko   |          2 |
+----+------------------------+------+------------+
2 rows in set (0.01 sec)

mysql> SELECT * from company_names WHERE name LIKE "%원";
Empty set (0.00 sec)

mysql> SELECT * from company_names WHERE name LIKE "%원티드%";
+----+------------------------+------+------------+
| id | name                   | lang | company_id |
+----+------------------------+------+------------+
|  1 | 원티드 주식회사        | ko   |          1 |
|  4 | 원티드                 | ko   |          2 |
+----+------------------------+------+------------+
2 rows in set (0.01 sec)

flask에서 post로 넘어온 JSON 형식 데이터 읽기

{
	"user_id"   : "test01",
	"user_name" : "테스트01"
}

위와 같은 데이터를 POST 방식으로 body에 실어서 보내면 flask에서는 request.get_json()을 이용해 파이썬 데이터 형식으로 변환해 가져올 수 있습니다.

@user_bp.route('/create', methods=['POST'])
def create():
    print(request.is_json)
    params = request.get_json()
    print(params['user_id'])
    return 'ok'
True
{'user_id': 'test01', 'user_name': '테스트01'}
test01

해당 API 를 호출하는 클라이언트에서는 Content-Typeapplication/json 로 해서 호출해야 합니다.

request.get.json('data')

# value 값만 출력된다. 
data = request.get_json()

# 전체 다 나온다. 

아쉬운 점

  1. DB API, ORM 라이브러리 등을 더 자세하게 파악해야 한다는 것을 깨달았습니다. 라이브러리 마다 이렇게 상이한 결과를 도출할지 전혀 알지 못해 생각보다 많은 시간을 쏟아부어야 했습니다. 초반에 잘 파악하는 것이 중요한 것 같습니다.
  2. SQL을 사용할 수 있도록 create_engine을 사용해 보지 못한 것이 아쉽습니다. 다음에는 create_engine을 이용해 조금 더 직관적이고 쉬운 코드를 짤 수 있도록 해야겠습니다.

Reference

도움 받았었던 블로그들을 정리해 두었습니다.

https://fun25.co.kr/blog/python-flask-post-request-get-json/?category=002

flask SQLAlchemy ORM 사용해보기

쿼리 스트링에 대한 내용 참고

Path Parameter, Query String

pytest 관련 내용 참고

Introduction Into Contexts

flask 어플리케이션 팩토리

플라스크 프로젝트 구조와 애플리케이션 팩토리

플라스크 공식문서

profile
지뢰찾기 개발자

0개의 댓글