from flask import Flask
from flask_restful import Api
from flask_migrate import Migrate
from company.models import db
from company.controller import CompanyList, GetDetailCompany, CompanyCreateView
from config import DB_URL
app = Flask(__name__)
api = Api(app)
app.config["SQLALCHEMY_DATABASE_URI"] = DB_URL
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SQLALCHEMY_COMMIT_ON_TEARDOWN"] = True
db.init_app(app)
db.app = app
db.create_all()
migrate = Migrate(app, db)
api.add_resource(CompanyList, "/search")
api.add_resource(GetDetailCompany, "/companies/<comp_name>")
api.add_resource(CompanyCreateView, '/companies')
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000, debug=False)
from flask import Flask
from flask_restx import Api
from flask_migrate import Migrate
from company import config
from company.controller import Company_Info
from company.models import db
migrate = Migrate()
def create_app(test_db_url=None):
app = Flask(__name__)
if test_db_url:
app.config["SQLALCHEMY_DATABASE_URI"] = test_db_url
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
else:
app.config.from_object(config)
db.init_app(app)
db.app = app
db.create_all()
migrate.init_app(app, db)
api = Api(app, version='1.0', title='API 문서', description='Swagger 문서', doc="/api-docs")
api.add_namespace(Company_Info, "/companies")
return app, db
if __name__ == '__main__':
app = create_app()[0]
app.run(host="0.0.0.0", port=6000, debug=True)
flask-restx
의 Flask(__name__)
사용법은 전역에 선언하는 것으로 검색 결과가 많이 나오기도 했고, flask-restx
가 꾸준하게 업데이트가 되는 라이브러리로 알고 진행을 하였습니다. 하지만 pytest
진행 과정 중에서 test_client
를 생성해 요청을 보내야 하는데 그 부분에서 애를 먹다가 팩토리 함수라는 개념을 알게 되어 전역 환경에 선언했던 app = Flask(__name__)
을 함수로 만들어 반환해주는 create_app
함수를 만들어 처리하였습니다.
테스트 코드를 위한 부분
def create_app(test_db_url=None):
app = Flask(__name__)
if test_db_url:
app.config["SQLALCHEMY_DATABASE_URI"] = test_db_url
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
else:
app.config.from_object(config)
create_app()
함수가 실행될 때 만약에 test_db_url
이라는 매개변수가 들어간다면, DATABASE_URI
는 테스트 데이터베이스가 됩니다. 이를 위해 따로 테스트 데이터베이스를 만들어주어야 합니다.
만약 그렇지 않다면 from_object(config)
라는 문장이 실행되어 config 파일
에 있는 정보를 읽어옵니다.
from_object 옵션에 대해 flask-docs을 확인해 보았습니다.
대부분의 어플리케이션은 하나 이상의 설정(구성)이 필요 하다. 적어도 운영 환경과 개발환경은 독립된 설정값을 가지고 있어야만 한다. 이것을 다루는 가장 쉬운 방법은 버전 관리를 통하여 항상 로드되는 기본 설정값을 사용하는 것이다. 그리고 독립된 설정을값들을 필요에 따라 위에서 언급했던 방식으로 덮어쓰기한다:
app = Flask(__name__)
app.config.from_object('yourapplication.default_settings')
app.config.from_envvar('YOURAPPLICATION_SETTINGS')
이렇게 설명이 되어 있었는데요, 그냥 object로 불러오기만 해도 설정값이 잘 적용되는 것 같았습니다!
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
옵션의 경우 추가적인 메모리를 필요로 해 꺼두는 것을 추천한다고 합니다.
위의 flask 객체에 대한 설정값을 조금 더 자세하게 알아보겠습니다. flask-docs을 참고하였습니다.
[config](https://flask-docs-kr.readthedocs.io/ko/latest/ko/api.html#flask.Flask.config)
은 실제로는 dictionary 의 서브클래스이며, 다른 dictionary 처럼 다음과 같이 수정될 수 있다:
app = Flask(__name__)
app.config['DEBUG'] = True
확정된 설정값들은 또한
[Flask](https://flask-docs-kr.readthedocs.io/ko/latest/ko/api.html#flask.Flask)
객체로 전달될 수 있으며, 그 객체를 통해 설정값들을 읽거나 쓸수 있다. 한번에 다수의 키(key)들을 업데이트 하기 위해서는[dict.update()](http://docs.python.org/dev/library/stdtypes.html#dict.update)
함수를 사용 할 수 있다.
app.debug = True
app.config.update(
DEBUG=True,
SECRET_KEY='...'
)
내장된 고유 설정값들은 직접 확인해 보시면 될 것 같습니다.
이번 부분은 테스트와 연관이 되어 있는 코드인데요,
import pytest
import json
from app import create_app
from my_settings import TEST_DB_URL
from company.models import Company, CompanyName, Tag, companies_tags
db = create_app(TEST_DB_URL)[1]
@pytest.fixture
def client():
app = create_app(TEST_DB_URL)[0]
return app.test_client()
def setup_function():
....
테스트 코드는 이렇게 작성이 되었습니다. create_app
함수에서 app, db
가 반환이 되는 이유도 db = create_app(TEST_DB_URL)[1]
이 코드에서 사용을 해야하기 때문입니다. 다른 방법은 찾지 못했습니다. db 자체를 인식하지 못하더라고요.
create_app
이라는 함수에서 test_db_url
을 매개로 만들어진 db 객체
를 받아와야 되는 것 같았습니다. test_client()
또한 이러한 과정에서 만들어진 말 그대로 테스트용 클라이언트로 전송이 되기 때문에 그쪽에 연결된 db와 통신하여 테스트를 진행할 수 있었습니다.
import string
from flask_restx import Resource, Namespace
from flask import request
from company.models import *
Company_Info = Namespace("About Company", description="회사 정보 info")
@Company_Info.route('/search')
class CompanySearchView(Resource):
def get(self):
code = request.headers.get("x-wanted-language", "ko")
company_name = request.args.get("query")
if company_name == "":
return {"message" : "KEY ERROR"}, 404
company_name_datas = db.session.query(CompanyName).filter(CompanyName.name.like(f"%{company_name}%"), CompanyName.lang==code).all()
if not company_name_datas:
return {"message" : "COMPANY NOT FOUND"}, 404
result = [
{"company_name" : company_name.name}
for company_name in company_name_datas
]
return result, 200
@Company_Info.route('')
class CompanyInfoView(Resource):
def post(self):
code = request.headers.get("x-wanted-language", "ko")
data = request.get_json()
companies_dict = list(data.get("company_name").items())
company = db.session.query(CompanyName).filter_by(lang=companies_dict[0][0], name=companies_dict[0][1]).first()
if company:
return {"message": "ALREADY EXIST COMPANY"}, 404
company = Company()
db.session.add(company)
db.session.commit()
for lang, name in companies_dict:
company_name = CompanyName(name=name, lang=lang, company_id=company.id)
db.session.add(company_name)
tag_list = []
for tag in data.get("tags"):
for lang, name in tag.get("tag_name").items():
tag_obj = db.session.query(Tag).filter_by(lang=lang, name=name).first()
if not tag_obj:
tag_obj = Tag(name=name, lang=lang)
db.session.add(tag_obj)
db.session.commit()
if lang == code:
tag_list.append(name)
companies_tags_query = companies_tags.insert().values(
company_id = company.id,
tag_id = tag_obj.id
)
db.session.execute(companies_tags_query)
db.session.commit()
return {"company_name": data.get("company_name").get(code), "tags": tag_list}, 201
@Company_Info.route('/<company_name>')
class CompanyDetailView(Resource):
def get(self, company_name):
code = request.headers.get("x-wanted-language")
company = CompanyName.query.filter_by(name=string.capwords(company_name), lang=code).first()
if company is None:
return {"message": "NOT FOUND COMPANY"}, 404
tag_list = []
for com_tag in db.session.query(companies_tags).filter_by(company_id=company.company_id):
for tag in Tag.query.filter_by(id=com_tag.tag_id):
if tag.lang == code:
tag_list.append(tag.name)
return {"company_name": string.capwords(company_name), "tags": tag_list}, 200
>>> db.session.query(CompanyName).all()
[<CompanyName 4>, <CompanyName 5>]
SELECT company_names.id AS company_names_id
FROM company_name
>>> db.session.query(func.count(CompanyName.id)).all()
[(18,)]
SELECT count(company_names.id) AS count_1
FROM company_names
>>> tag = db.session.query(Tag).filter(Tag.name=='간식').first()
>>> db.session.query(CompanyName).filter(CompanyName.name=='원티드랩').all()
[<CompanyName 4>]
SELECT company_names.id AS company_names_id, company_names.name AS company_names_name, company_names.lang AS company_names_lang, company_names.company_id AS company_names_company_id
FROM company_names
WHERE company_names.name = %(name_1)s
>>> db.session.query(CompanyName).filter(or_(CompanyName.name=='원티드랩', CompanyName.lang=='tw')).all()
[<CompanyName 4>, <CompanyName 5>, <CompanyName 8>, <CompanyName 11>, <CompanyName 14>, <CompanyName 17>, <CompanyName 20>]
SELECT company_names.id AS company_names_id, company_names.name AS company_names_name, company_names.lang AS company_names_lang, company_names.company_id AS company_names_company_id
FROM company_names
WHERE company_names.name = %(name_1)s OR company_names.lang = %(lang_1)s
>>> tag = db.session.query(Tag).filter(Tag.name == '간식').update({'Tag.name' : '수정'});
>>> tag = Tag(name='간식', lang='ko')
>>> db.session.add(tag)
>>> db.session.commit()
>>> db.session.query(Tag).filter(Tag.name=='간식').order_by(Tag.created_at)
>>> db.session.query(CompanyName, Tag).filter(CompanyName.id, Tag.id).all()
<console>:1: SAWarning: SELECT statement has a cartesian product between FROM element(s) "company_names" and FROM element "tags". Apply join condition(s) between each element to resolve.
>>> db.session.query(CompanyName).outerjoin(Tag, CompanyName.id == Tag.id).all()
[<CompanyName 4>, <CompanyName 5>, <CompanyName 6>, <CompanyName 7>, <CompanyName 8>, <CompanyName 9>, <CompanyName 10>, <CompanyName 11>, <CompanyName 12>, <CompanyName 13>, <CompanyName 14>, <CompanyName 15>, <CompanyName 16>, <CompanyName 17>, <CompanyName 18>, <CompanyName 19>, <CompanyName 20>, <CompanyName 21>]
>>> db.session.query(CompanyName).group_by(CompanyName.id).all()
[<CompanyName 4>, <CompanyName 5>, <CompanyName 6>, <CompanyName 7>, <CompanyName 8>, <CompanyName 9>, <CompanyName 10>, <CompanyName 11>, <CompanyName 12>, <CompanyName 13>, <CompanyName 14>, <CompanyName 15>, <CompanyName 16>, <CompanyName 17>, <CompanyName 18>, <CompanyName 19>, <CompanyName 20>, <CompanyName 21>]
이번 리팩토링에서 가장 어려웠었던 부분입니다. 함께 해주신 팀원분의 많은 도움을 받아 완성할 수 있었습니다!
import pytest
import json
from app import create_app
from my_settings import TEST_DB_URL
from company.models import Company, CompanyName, Tag, companies_tags
db = create_app(TEST_DB_URL)[1]
@pytest.fixture
def client():
app = create_app(TEST_DB_URL)[0]
return app.test_client()
def setup_function():
datas = [
{"name": "라인", "lang": "ko"},
{"name": "라인 프레쉬", "lang": "ko"}
]
tags = ["태그_1", "태그_8", "태그_15"]
for data in datas:
company = Company()
db.session.add(company)
db.session.commit()
company_name = CompanyName(name=data["name"], lang=data["lang"], company_id=company.id)
db.session.add(company_name)
for tag in tags:
tag_obj = Tag.query.filter_by(name=tag, lang=data["lang"]).first()
if not tag_obj:
tag_obj = Tag(name=tag, lang=data["lang"])
db.session.add(tag_obj)
db.session.commit()
companies_tags_query = companies_tags.insert().values(
company_id = company.id,
tag_id = tag_obj.id
)
db.session.execute(companies_tags_query)
db.session.commit()
def teardown_function():
[db.session.delete(company) for company in CompanyName.query.all()]
db.session.commit()
[company_tag.tags.clear() for company_tag in Company.query.all()]
db.session.commit()
[db.session.delete(tag) for tag in Tag.query.all()]
db.session.commit()
[db.session.delete(company) for company in Company.query.all()]
db.session.commit()
def test_company_name_autocomplete(client):
resp = client.get('/companies/search', query_string=dict(query="라인"), headers=[("x-wanted-language", "ko")])
searched_companies = json.loads(resp.data.decode("utf-8"))
assert resp.status_code == 200
assert searched_companies == [
{"company_name": "라인"},
{"company_name": "라인 프레쉬"},
]
query string 처리에 대하여
How can I fake request.POST and GET params for unit testing in Flask?
위 두개 글을 보고 참고하여 해결하였습니다. 특별하게 query string 처리는 ?기호로 처리하는 것이 아니라 query_string=dict(query="라인")
이런 식으로 처리를 해야 하더라고요.
새롭게 알게 된 점을 두서없이 정리해 봅니다.
mysql> SELECT * from company_names WHERE name LIKE "원%";
+----+------------------------+------+------------+
| id | name | lang | company_id |
+----+------------------------+------+------------+
| 1 | 원티드 주식회사 | ko | 1 |
| 4 | 원티드 | ko | 2 |
+----+------------------------+------+------------+
2 rows in set (0.01 sec)
mysql> SELECT * from company_names WHERE name LIKE "%원";
Empty set (0.00 sec)
mysql> SELECT * from company_names WHERE name LIKE "%원티드%";
+----+------------------------+------+------------+
| id | name | lang | company_id |
+----+------------------------+------+------------+
| 1 | 원티드 주식회사 | ko | 1 |
| 4 | 원티드 | ko | 2 |
+----+------------------------+------+------------+
2 rows in set (0.01 sec)
{
"user_id" : "test01",
"user_name" : "테스트01"
}
위와 같은 데이터를 POST 방식
으로 body
에 실어서 보내면 flask
에서는 request.get_json()
을 이용해 파이썬 데이터 형식으로 변환해 가져올 수 있습니다.
@user_bp.route('/create', methods=['POST'])
def create():
print(request.is_json)
params = request.get_json()
print(params['user_id'])
return 'ok'
True
{'user_id': 'test01', 'user_name': '테스트01'}
test01
해당 API 를 호출하는 클라이언트에서는 Content-Type
을 application/json
로 해서 호출해야 합니다.
request.get.json('data')
# value 값만 출력된다.
data = request.get_json()
# 전체 다 나온다.
도움 받았었던 블로그들을 정리해 두었습니다.
https://fun25.co.kr/blog/python-flask-post-request-get-json/?category=002