[DRF]Django REST Framework에서 PyTest로 테스트하기(2)

김동욱·2024년 3월 28일
post-thumbnail

테스트 개요와 관련하여 정리한 이전 글에서 이어서 작성해 보겠다.

테스트를 구성하는 방법

우선, 이전에 언급했듯이, 테스트는 pytest.ini나 setup.cfg 파일에 명시된 패턴을 따라 파일 안에 있어야 한다.

파일 내의 테스트를 정리하는 방법으로는, 테스트되는 코드 단위마다(예를 들어, Django의 APIView) 카멜 케이스(Camel Case)를 사용하여 Test로 시작하는 클래스를 생성하고, 그 안에 해당 단위에 대한 테스트들을 생성한다(예를 들어, 뷰에서 허용하는 모든 메소드에 대한 테스트).

테스트 함수 아래 테스트를 구성하는 데에는 “AMAA” 기준을 사용하는 것을 추천한다. 이는 “AAA” 기준의 사용자 정의 버전이다. 즉, 테스트는 다음 순서를 따라야 한다.

  1. Arrange (준비): 테스트에 필요한 모든 것을 설정한다.
  2. Mock (모의): 테스트를 격리하기 위해 필요한 모든 것을 Mocking한다.
  3. Act (실행): 코드 단위를 실행한다.
  4. Assert (확인): 결과가 정확히 예상대로 나왔는지 확인한다.

따라서 테스트 구조는 다음과 같아야 한다.

# tests/test_app/app/test_some_part.py

...
# inside test_something.py

class TestUnitName:
    def test_<functionality_1>(self):
        # Arrange

        # Mock

        # Act

        # Assert
...

테스트 예제들

예시로, Transaction과 Currency 모델을 구성하고 이 모델을 중심으로 예제를 구축했다.

# inside apps/app/models.py

import string

from django.db import models
from django.utils import timezone
from hashid_field import HashidAutoField

from apps.transaction.utils import create_payment_intent, PaymentStatuses


class Currency(models.Model):
    """Currency model"""
    name    = models.CharField(max_length=120, null=False, blank=False, unique=True)
    code    = models.CharField(max_length=3, null=False, blank=False, unique=True)
    symbol  = models.CharField(max_length=5, null=False, blank=False, default='$')

    def __str__(self) -> str:
        return self.code


class Transaction(models.Model):
    """Transaction model."""
    id                  = HashidAutoField(primary_key=True, min_length=8, alphabet=string.printable.replace('/', ''))
    name                = models.CharField(max_length=50, null=False, blank=False)
    email               = models.EmailField(max_length=50, null=False, blank=False)
    creation_date       = models.DateTimeField(auto_now_add=True, null=False, blank=False)
    currency            = models.ForeignKey(Currency, null=False, blank=False, default=1, on_delete=models.PROTECT)
    payment_status      = models.CharField(choices=PaymentStatuses.choices, default=PaymentStatuses.WAI, max_length=21)
    payment_intent_id   = models.CharField(max_length=100, null=True, blank=False, default=None)
    message             = models.TextField(null=True, blank=True)

    @property
    def link(self):
        """
        Link to a payment form for the transaction
        """
        return settings.ALLOWED_HOSTS[0] + f'/payment/{str(self.id)}'

그리고 tests/test_app/conftest.py 내부에서 우리는 팩토리를 fixture로 설정하여 나중에 테스트 함수에서 매개변수로 접근할 수 있게 할 것이다. 항상 사용자가 모든 필드를 채운 모델 인스턴스가 필요한 것은 아니다. 때때로 우리는 백엔드에서 자동으로 필드를 채우고 싶을 수 있다. 그런 경우에는 특정 필드가 채워진 사용자 정의 팩토리를 만들 수 있다.

def utbb():
    def unfilled_transaction_bakery_batch(n):
        utbb = baker.make(
            'transaction.Transaction',
            amount_in_cents=1032000, # --> Passes min. payload restriction in every currency
            _fill_optional=[
                'name',
                'email',
                'currency',
                'message'
            ],
            _quantity=n
        )
        return utbb
    return unfilled_transaction_bakery_batch

@pytest.fixture
def ftbb():
    def filled_transaction_bakery_batch(n):
        utbb = baker.make(
            'transaction.Transaction',
            amount_in_cents=1032000, # --> Passes min. payload restriction in every currency
            _quantity=n
        )
        return utbb
    return filled_transaction_bakery_batch

@pytest.fixture
def ftb():
    def filled_transaction_bakery():
        utbb = baker.make(
            'transaction.Transaction',
            amount_in_cents=1032000, # --> Passes min. payload restriction in every currency
            currency=baker.make('transaction.Currency')   
        )
        return utbb
    return filled_transaction_bakery

E2E Tests

모델과 해당 팩토리가 이미 생성되었다면, 우리 앱에서 원하는 흐름의 엔드 투 엔드 테스트를 수행하는 것이 좋다. 이것은 tutor로서뿐만 아니라 sanity check로서도 유용하다. 이 테스트들은 우리가 만들자마자 통과되지 않을 것이다. 이것은, 다시 말하지만, 프로그래머가 개발 과정에서 각 엔드포인트가 각 입력에 대해 반환해야 할 것이 무엇인지에 대한 유연한 테스트이다.

우리가 작업 중인 기능의 복잡성이나 제품 소유자에 의한 변경에 따라, 이 테스트들을 과정 중에 완전히 변경할 수도 있다.

이러한 엔드 투 엔드 테스트를 위해, 많은 사례에 대한 예제를 가지기 위해, 우리는 두 모델 모두에 대해 전체 CRUD 기능을 모든 여섯 가지 HTTP 메소드를 사용하여 다음과 같이 구현하고자 한다고 가정한다.

GET     api/transactions        List all transaction objects
POST    api/transactions        Create a transaction object
GET     api/transactions        Retrieve a transaction object
PUT     api/transactions/hash   Update a transaction object
PATCH   api/transactions/hash   Update a field of a transaction object
DELETE  api/transactions/hash   Delete a transaction object  

이를 위해 다음을 수행한다.

DRF의 API Client 개체에 대한 fixture를 만들고 이를 api_client와 같이 명명하여 엔드포인트에서 직접 테스트를 시작한다.

# conftest.py
@pytest.fixture
def api_client():
    return APIClient()

이 fixture는 이 앱과 관련이 있을 뿐만 아니라 모든 앱 위에 conftest.py 파일로 정의하여 모든 앱 간에 공유할 수 있도록 하는 것이 좋다.

....
├── tests
│   ├── __init__.py
│   ├── conftest.py <--
│   ├── test_app1
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── factories.py
│   │   ├── e2e_tests.py
│   │   ├── test_models.py
│   │   ├── test_signals.py
│   │   ├── test_serializers.py
│   │   ├── test_utils.py
│   │   ├── test_views.py
│   │   └── test_urls.py
│   │
│   └── ...
└── ...

이제 모든 엔드포인트를 테스트해보자.

from model_bakery import baker
import factory
import json
import pytest

from apps.transaction.models import Transaction, Currency


pytestmark = pytest.mark.django_db

class TestCurrencyEndpoints:

    endpoint = '/api/currencies/'

    def test_list(self, api_client):
        baker.make(Currency, _quantity=3)

        response = api_client().get(
            self.endpoint
        )

        assert response.status_code == 200
        assert len(json.loads(response.content)) == 3

    def test_create(self, api_client):
        currency = baker.prepare(Currency) 
        expected_json = {
            'name': currency.name,
            'code': currency.code,
            'symbol': currency.symbol
        }

        response = api_client().post(
            self.endpoint,
            data=expected_json,
            format='json'
        )

        assert response.status_code == 201
        assert json.loads(response.content) == expected_json

    def test_retrieve(self, api_client):
        currency = baker.make(Currency)
        expected_json = {
            'name': currency.name,
            'code': currency.code,
            'symbol': currency.symbol
        }
        url = f'{self.endpoint}{currency.id}/'

        response = api_client().get(url)

        assert response.status_code == 200
        assert json.loads(response.content) == expected_json

    def test_update(self, rf, api_client):
        old_currency = baker.make(Currency)
        new_currency = baker.prepare(Currency)
        currency_dict = {
            'code': new_currency.code,
            'name': new_currency.name,
            'symbol': new_currency.symbol
        } 

        url = f'{self.endpoint}{old_currency.id}/'

        response = api_client().put(
            url,
            currency_dict,
            format='json'
        )

        assert response.status_code == 200
        assert json.loads(response.content) == currency_dict

    @pytest.mark.parametrize('field',[
        ('code'),
        ('name'),
        ('symbol'),
    ])
    def test_partial_update(self, mocker, rf, field, api_client):
        currency = baker.make(Currency)
        currency_dict = {
            'code': currency.code,
            'name': currency.name,
            'symbol': currency.symbol
        } 
        valid_field = currency_dict[field]
        url = f'{self.endpoint}{currency.id}/'

        response = api_client().patch(
            url,
            {field: valid_field},
            format='json'
        )

        assert response.status_code == 200
        assert json.loads(response.content)[field] == valid_field

    def test_delete(self, mocker, api_client):
        currency = baker.make(Currency)
        url = f'{self.endpoint}{currency.id}/'

        response = api_client().delete(url)

        assert response.status_code == 204
        assert Currency.objects.all().count() == 0

class TestTransactionEndpoints:

    endpoint = '/api/transactions/'

    def test_list(self, api_client, utbb):
        client = api_client()
        utbb(3)
        url = self.endpoint
        response = client.get(url)

        assert response.status_code == 200
        assert len(json.loads(response.content)) == 3

    def test_create(self, api_client, utbb):
        client = api_client()
        t = utbb(1)[0]
        valid_data_dict = {
            'amount_in_cents': t.amount_in_cents,
            'currency': t.currency.code,
            'name': t.name,
            'email': t.email,
            'message': t.message
        }

        url = self.endpoint

        response = client.post(
            url,
            valid_data_dict,
            format='json'
        )

        assert response.status_code == 201
        assert json.loads(response.content) == valid_data_dict
        assert Transaction.objects.last().link

    def test_retrieve(self, api_client, ftb):
        t = ftb()
        t = Transaction.objects.last()
        expected_json = t.__dict__
        expected_json['link'] = t.link
        expected_json['currency'] = t.currency.code
        expected_json['creation_date'] = expected_json['creation_date'].strftime(
            '%Y-%m-%dT%H:%M:%S.%fZ'
        )
        expected_json.pop('_state')
        expected_json.pop('currency_id')            
        url = f'{self.endpoint}{t.id}/'

        response = api_client().get(url)

        assert response.status_code == 200 or response.status_code == 301
        assert json.loads(response.content) == expected_json

    def test_update(self, api_client, utbb):
        old_transaction = utbb(1)[0]
        t = utbb(1)[0]
        expected_json = t.__dict__
        expected_json['id'] = old_transaction.id.hashid
        expected_json['currency'] = old_transaction.currency.code
        expected_json['link'] = Transaction.objects.first().link
        expected_json['creation_date'] = old_transaction.creation_date.strftime(
            '%Y-%m-%dT%H:%M:%S.%fZ'
        )
        expected_json.pop('_state')
        expected_json.pop('currency_id')    

        url = f'{self.endpoint}{old_transaction.id}/'

        response = api_client().put(
            url,
            data=expected_json,
            format='json'            
        )

        assert response.status_code == 200 or response.status_code == 301
        assert json.loads(response.content) == expected_json

    @pytest.mark.parametrize('field',[
        ('name'),
        ('billing_name'),
        ('billing_email'),
        ('email'),
        ('amount_in_cents'),
        ('message'),
    ])
    def test_partial_update(self, api_client, field, utbb):
        utbb(2)
        old_transaction = Transaction.objects.first()
        new_transaction = Transaction.objects.last()
        valid_field = {
            field: new_transaction.__dict__[field],
        }
        url = f'{self.endpoint}{old_transaction.id}/'

        response = api_client().patch(
            path=url,
            data=valid_field,
            format='json',
        )

        assert response.status_code == 200 or response.status_code == 301 
        try:
            assert json.loads(response.content)[field] == valid_field[field]
        except json.decoder.JSONDecodeError as e:
            pass

    def test_delete(self, api_client, utbb):
        transaction = utbb(1)[0]
        url = f'{self.endpoint}{transaction.id}/'

        response = api_client().delete(
            url
        )

        assert response.status_code == 204 or response.status_code == 301 

우리 앱의 엔드포인트에 대한 예상 출력에 대한 테스트를 마련하고 나면, 모델부터 나머지 앱을 구축해 나갈 수 있다.

Utils

유틸은 우리 코드 전반에 퍼져 있는 helper 함수이므로, 어떤 순서로든 이들을 구축하고 해당하는 테스트를 만들 수 있다.

우리가 만들 첫 번째 유틸은 fill_transaction 함수이다. 이 함수는 Transaction 모델의 인스턴스가 주어지면, 사용자가 입력하지 않아야 할 필드를 채운다.

백엔드에서 채울 수 있는 필드 중 하나는 payment_intent_id 필드이다. "payment intent"는 Stripe(결제 서비스)가 예상되는 거래를 표현하는 방식이며, 그 ID는 간단히 말해 그들의 데이터베이스에서 해당 데이터를 찾는 방법이다.

따라서 Stripe의 파이썬 라이브러리를 사용하여 payment intent ID를 생성하고 검색하는 유틸은 다음과 같을 수 있다.

def fill_transaction(transaction):
    payment_intent_id = stripe.PaymentIntent.create(
        amount=amount,
        currency=currency.code.lower(),
        payment_method_types=['card'],
    ).id

    t = transaction.__class__.objects.filter(id=transaction.id)

    t.update( # We use update not to trigger a save-signal recursion Overflow
        payment_intent_id=payment_intent_id,
    )

이 util에 대한 테스트는 API 호출과 2 db 호출을 mocking해야 한다.

class TestFillTransaction:

    def test_function_code(self, mocker):

        t = FilledTransactionFactory.build()
        pi = PaymentIntentFactory()

        create_pi_mock = mocker.Mock(return_value=pi)
        stripe.PaymentIntent.create = create_pi_mock       
        filter_call_mock = mocker.Mock()
        Transaction.objects.filter = filter_call_mock
        update_call_mock = mocker.Mock()
        filter_call_mock.return_value.update = update_call_mock

        utils.fill_transaction(t)

        filter_call_mock.assert_called_with(id=t.id)
        update_call_mock.assert_called_with(
            payment_intent_id=pi.id,
            stripe_response=pi.last_response.data,
            billing_email=t.email,
            billing_name=t.name,
        ) 

Signals

Signal의 경우 Transaction이 생성될 때 fill_transaction 유틸리티를 실행하는 Signal를 가질 수 있다.

from django.db.models.signals import pre_save
from django.dispatch import receiver

from apps.transaction.models import Transaction
from apps.transaction.utils import fill_transaction


@receiver(pre_save, sender=Transaction)
def transaction_filler(sender, instance, *args, **kwargs):
    """Fills fields"""
    if not instance.id:
        fill_transaction(instance)

이 Signal은 e2e에서 암묵적으로 테스트된다. 이 Signal에 대한 좋은 명시적 단위 테스트는 다음과 같다.

import pytest

from django.db.models.signals import pre_save

from apps.transaction.models import Transaction
from tests.test_transaction.factories import UnfilledTransactionFactory, FilledTransactionFactory


pytestmark = pytest.mark.unit

class TestTransactionFiller:

    def test_pre_save(self, mocker):
        instance = UnfilledTransactionFactory.build()
        mock = mocker.patch(
            'apps.transaction.signals.fill_transaction'
        )

        pre_save.send(Transaction, instance=instance, created=True)

        mock.assert_called_with(instance)

Serializers

우리 앱에서는 Currency 모델에 대해 하나의 시리얼라이저를, Transaction 모델에 대해서는 두 개의 시리얼라이저를 가질 예정이다.

  • 거래 관리자(거래를 생성하고 삭제하는 사람)가 수정할 수 있는 필드를 포함하는 Serializer
  • 거래 관리자 뿐만 아니라 결제를 담당할 사람들에게 보여질 필드를 포함하는 Serializer
from hashid_field.rest import HashidSerializerCharField
from rest_framework import serializers

from django.conf import settings
from django.core.validators import MaxLengthValidator, ProhibitNullCharactersValidator
from rest_framework.validators import ProhibitSurrogateCharactersValidator

from apps.transaction.models import Currency, Transaction


class CurrencySerializer(serializers.ModelSerializer):

    class Meta:
        model = Currency
        fields = ['name', 'code', 'symbol']
        if settings.DEBUG == True:
            extra_kwargs = {
                'name': {
                    'validators': [MaxLengthValidator, ProhibitNullCharactersValidator]
                },
                'code': {
                    'validators': [MaxLengthValidator, ProhibitNullCharactersValidator]
                }
            }

class UnfilledTransactionSerializer(serializers.ModelSerializer):
    currency = serializers.SlugRelatedField(
        slug_field='code',
        queryset=Currency.objects.all(),
    )

    class Meta:
        model = Transaction
        fields = (
            'name',
            'currency',
            'email',
            'amount_in_cents',
            'message'
        )

class FilledTransactionSerializer(serializers.ModelSerializer):
    id = HashidSerializerCharField(source_field='transaction.Transaction.id', read_only=True)
    currency = serializers.StringRelatedField(read_only=True)
    link = serializers.ReadOnlyField()

    class Meta:
        model = Transaction
        fields = '__all__'
        extra_kwargs = {
            """Non editable fields"""
            'id': {'read_only': True},
            'creation_date': {'read_only': True},
            'payment_date': {'read_only': True},
            'amount_in_cents': {'read_only': True},
            'payment_intent_id': {'read_only': True},
            'payment_status': {'read_only': True},

        }

시리얼라이저에 대한 단위 테스트는(관련이 있을 때) 두 가지를 테스트하는 것을 목표로 해야 한다.

  • model 인스턴스를 적절히 serialize할 수 있는지 여부
  • serialize된 유효한 데이터를 모델로 변환할 수 있는지 여부(deserialize)
import pytest
import factory

from rest_framework.fields import CharField

from apps.transaction.api.serializers import CurrencySerializer, UnfilledTransactionSerializer, FilledTransactionSerializer
from tests.test_transaction.factories import CurrencyFactory, UnfilledTransactionFactory, FilledTransactionFactory


class TestCurrencySerializer:

    transaction = UnfilledTransactionFactory.build()

    @pytest.mark.unit
    def test_serialize_model(self):
        currency = CurrencyFactory.build()
        serializer = CurrencySerializer(currency)

        assert serializer.data

    @pytest.mark.unit
    def test_serialized_data(self, mocker):
        valid_serialized_data = factory.build(
            dict,
            FACTORY_CLASS=CurrencyFactory
        )

        serializer = CurrencySerializer(data=valid_serialized_data)

        assert serializer.is_valid()
        assert serializer.errors == {}


class TestUnfilledTransactionSerializer:

    @pytest.mark.unit
    def test_serialize_model(self):
        t = UnfilledTransactionFactory.build()
        expected_serialized_data = {
            'name': t.name,
            'currency': t.currency.code,
            'email': t.email,
            'amount_in_cents': t.amount_in_cents,
            'message': t.message,
        }

        serializer = UnfilledTransactionSerializer(t)

        assert serializer.data == expected_serialized_data

    @pytest.mark.django_db
    def test_serialized_data(self, mocker):
        c = CurrencyFactory()
        t = UnfilledTransactionFactory.build(currency=c)
        valid_serialized_data = {
            'name': t.name,
            'currency': t.currency.code,
            'email': t.email,
            'amount_in_cents': t.amount_in_cents,
            'message': t.message,
        }

        serializer = UnfilledTransactionSerializer(data=valid_serialized_data)

        assert serializer.is_valid(raise_exception=True)
        assert serializer.errors == {}

ml_model_name_max_chars = 134

    @pytest.mark.parametrize("wrong_field", (
        {"name": "a" * (ml_model_name_max_chars + 1)},
        {"tags": "tag outside of array"},
        {"tags": ["--------wrong length tag--------"]},
        {"version": "wronglengthversion"},
        {"is_public": 1},
        {"is_public": "Nope"},
    ))
    def test_deserialize_fails(self, wrong_field: dict):
        transaction_fields = [field.name for field in UnfilledTransaction._meta.get_fields()]
        invalid_serialized_data = {
            k: v for (k, v) in self.transaction.__dict__.items() if k in transaction_fields and k != "id"
        } | wrong_field

        serializer = MLModelSerializer(data=invalid_serialized_data)

        assert not serializer.is_valid()
        assert serializer.errors != {}

class TestFilledTransactionSerializer:

    @pytest.mark.unit
    def test_serialize_model(self, ftd):
        t = FilledTransactionFactory.build()
        expected_serialized_data = ftd(t)

        serializer = FilledTransactionSerializer(t)

        assert serializer.data == expected_serialized_data

    @pytest.mark.unit
    def test_serialized_data(self):
        t = FilledTransactionFactory.build()
        valid_serialized_data = {
            'id': t.id.hashid,
            'name': t.name,
            'currency': t.currency.code,
            'creation_date': t.creation_date.strftime('%Y-%m-%dT%H:%M:%SZ'),
            'payment_date': t.payment_date.strftime('%Y-%m-%dT%H:%M:%SZ'),
            'stripe_response': t.stripe_response,
            'payment_intent_id': t.payment_intent_id,
            'billing_name': t.billing_name,
            'billing_email': t.billing_email,
            'payment_status': t.payment_status,
            'link': t.link,
            'email': t.email,
            'amount_in_cents': t.amount_in_cents,
            'message': t.message,
        }

        serializer = FilledTransactionSerializer(data=valid_serialized_data)

        assert serializer.is_valid(raise_exception=True)
        assert serializer.errors == {}

Viewsets

우리는 모델의 CRUD 작업을 위해 DRF(Django Rest Framework) viewsets을 사용할 것이며, 이를 통해 URL 구성을 테스트할 필요를 생략할 수 있다.

route_lists = [
    transaction_urls.route_list,
]
router = routers.DefaultRouter()
for route_list in route_lists:
    for route in route_list:
        router.register(route[0], route[1])

urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/', include(router.urls)),
]

# in views.py
from rest_framework.viewsets import ModelViewSet
from rest_framework.permissions import IsAuthenticated

from apps.transaction.api.serializers import CurrencySerializer, UnfilledTransactionSerializer, FilledTransactionSerializer
from apps.transaction.models import Currency, Transaction


class CurrencyViewSet(ModelViewSet):
    queryset = Currency.objects.all()
    serializer_class = CurrencySerializer


class TransactionViewset(ModelViewSet):
    """Transaction Viewset"""

    queryset = Transaction.objects.all()
    permission_classes = [IsAuthenticated]

    def get_serializer_class(self):
        if self.action == 'create':
            return UnfilledTransactionSerializer
        else:
            return FilledTransactionSerializer

우리가 테스트할 뷰에서 사용되는 모든 권한을 모의 처리하는 것이 첫 번째 단계이다. 이러한 권한들은 나중에 따로 테스트될 것이다.

from rest_framework.permissions import IsAuthenticated

@pytest.fixture(scope="session", autouse=True)
def mock_views_permissions():

    # little util I use for testing for DRY when patching multiple objects
    patch_perm = lambda perm: mock.patch.multiple(
        perm,
        has_permission=mock.Mock(return_value=True),
        has_object_permission=mock.Mock(return_value=True),
    )
    with ( 
        patch_perm(IsAuthenticated),
        # ...add other permissions you may have below
    ):
        yield

일반적인 클래스 기반 뷰(class-based-views)나 함수 기반 뷰(function-based-views)를 사용할 경우, 뷰 자체에서 테스트를 시작하여 해당 뷰를 트리거하는 URL 구성(urlconf)으로부터 뷰를 격리시켜 테스트한다. 하지만 이 경우 라우터를 사용하고 있기 때문에, API 클라이언트를 사용하여 엔드포인트 자체에서 뷰셋을 테스트하기 시작한다.

import factory
import json
import pytest

from django.urls import reverse
from django_mock_queries.mocks import MockSet
from rest_framework.relations import RelatedField, SlugRelatedField

from apps.transaction.api.serializers import UnfilledTransactionSerializer, CurrencySerializer
from apps.transaction.api.views import CurrencyViewSet, TransactionViewset
from apps.transaction.models import Currency, Transaction
from tests.test_transaction.factories import CurrencyFactory, FilledTransactionFactory, UnfilledTransactionFactory


pytestmark = [pytest.mark.urls('config.urls'), pytest.mark.unit]

class TestCurrencyViewset:

    def test_list(self, mocker, rf):
        # Arrange
        url = reverse('currency-list')
        request = rf.get(url)
        qs = MockSet(
            CurrencyFactory.build(),
            CurrencyFactory.build(),
            CurrencyFactory.build()
        )        
        view = CurrencyViewSet.as_view(
            {'get': 'list'}
        )
        #Mcking
        mocker.patch.object(
            CurrencyViewSet, 'get_queryset', return_value=qs
        )
        # Act
        response = view(request).render()
        #Assert
        assert response.status_code == 200
        assert len(json.loads(response.content)) == 3

    def test_retrieve(self, mocker, rf):
        currency = CurrencyFactory.build()
        expected_json = {
            'name': currency.name,
            'code': currency.code,
            'symbol': currency.symbol
        } 
        url = reverse('currency-detail', kwargs={'pk': currency.id})
        request = rf.get(url)
        mocker.patch.object(
            CurrencyViewSet, 'get_queryset', return_value=MockSet(currency)
        )
        view = CurrencyViewSet.as_view(
            {'get': 'retrieve'}
        )

        response = view(request, pk=currency.id).render()

        assert response.status_code == 200
        assert json.loads(response.content) == expected_json

    def test_create(self, mocker, rf):
        valid_data_dict = factory.build(
            dict,
            FACTORY_CLASS=CurrencyFactory
        )
        url = reverse('currency-list')
        request = rf.post(
            url,
            content_type='application/json',
            data=json.dumps(valid_data_dict)
        )
        mocker.patch.object(
            Currency, 'save'
        )
        view = CurrencyViewSet.as_view(
            {'post': 'create'}
        )

        response = view(request).render()

        assert response.status_code == 201
        assert json.loads(response.content) == valid_data_dict

    def test_update(self, mocker, rf):
        old_currency = CurrencyFactory.build()
        new_currency = CurrencyFactory.build()
        currency_dict = {
            'code': new_currency.code,
            'name': new_currency.name,
            'symbol': new_currency.symbol
        } 
        url = reverse('currency-detail', kwargs={'pk': old_currency.id})
        request = rf.put(
            url,
            content_type='application/json',
            data=json.dumps(currency_dict)
        )
        mocker.patch.object(
            CurrencyViewSet, 'get_object', return_value=old_currency
        )
        mocker.patch.object(
            Currency, 'save'
        )
        view = CurrencyViewSet.as_view(
            {'put': 'update'}
        )

        response = view(request, pk=old_currency.id).render()

        assert response.status_code == 200
        assert json.loads(response.content) == currency_dict

    @pytest.mark.parametrize('field',[
        ('code'),
        ('name'),
        ('symbol'),
    ])
    def test_partial_update(self, mocker, rf, field):
        currency = CurrencyFactory.build()
        currency_dict = {
            'code': currency.code,
            'name': currency.name,
            'symbol': currency.symbol
        } 
        valid_field = currency_dict[field]
        url = reverse('currency-detail', kwargs={'pk': currency.id})
        request = rf.patch(
            url,
            content_type='application/json',
            data=json.dumps({field: valid_field})
        )
        mocker.patch.object(
            CurrencyViewSet, 'get_object', return_value=currency
        )
        mocker.patch.object(
            Currency, 'save'
        )
        view = CurrencyViewSet.as_view(
            {'patch': 'partial_update'}
        )

        response = view(request).render()

        assert response.status_code == 200
        assert json.loads(response.content)[field] == valid_field

    def test_delete(self, mocker, rf):
        currency = CurrencyFactory.build()
        url = reverse('currency-detail', kwargs={'pk': currency.id})
        request = rf.delete(url)
        mocker.patch.object(
            CurrencyViewSet, 'get_object', return_value=currency
        )
        del_mock = mocker.patch.object(
            Currency, 'delete'
        )
        view = CurrencyViewSet.as_view(
            {'delete': 'destroy'}
        )

        response = view(request).render()

        assert response.status_code == 204
        assert del_mock.assert_called


class TestTransactionViewset:

    def test_list(self, mocker, rf):
        url = reverse('transaction-list')
        request = rf.get(url)
        qs = MockSet(
            FilledTransactionFactory.build(),
            FilledTransactionFactory.build(),
            FilledTransactionFactory.build()
        )
        mocker.patch.object(
            TransactionViewset, 'get_queryset', return_value=qs
        )
        view = TransactionViewset.as_view(
            {'get': 'list'}
        )

        response = view(request).render()

        assert response.status_code == 200
        assert len(json.loads(response.content)) == 3

    def test_create(self, mocker, rf):
        valid_data_dict = factory.build(
            dict,
            FACTORY_CLASS=UnfilledTransactionFactory
        )
        currency = valid_data_dict['currency']
        valid_data_dict['currency'] = currency.code
        url = reverse('transaction-list')
        request = rf.post(
            url,
            content_type='application/json',
            data=json.dumps(valid_data_dict)
        )
        retrieve_currency = mocker.Mock(return_value=currency)
        SlugRelatedField.to_internal_value = retrieve_currency

        mocker.patch.object(
            Transaction, 'save'
        )
        view = TransactionViewset.as_view(
            {'post': 'create'}
        )

        response = view(request).render()

        assert response.status_code == 201
        assert json.loads(response.content) == valid_data_dict

    def test_retrieve(self, api_client, mocker, ftd):
        transaction = FilledTransactionFactory.build()
        expected_json = ftd(transaction)
        url = reverse(
            'transaction-detail', kwargs={'pk': transaction.id}
        )
        TransactionViewset.get_queryset = mocker.Mock(
            return_value=MockSet(transaction)
        )

        response = api_client().get(url)

        assert response.status_code == 200
        assert json.loads(response.content) == expected_json

    def test_update(self, mocker, api_client, ftd):
        old_transaction = FilledTransactionFactory.build()
        new_transaction = FilledTransactionFactory.build()
        transaction_json = ftd(new_transaction, old_transaction)
        url = reverse(
            'transaction-detail',
            kwargs={'pk': old_transaction.id}
        )

        retrieve_currency = mocker.Mock(
            return_value=old_transaction.currency
        )
        SlugRelatedField.to_internal_value = retrieve_currency
        mocker.patch.object(
            TransactionViewset,
            'get_object',
            return_value=old_transaction
        )
        Transaction.save = mocker.Mock()

        response = api_client().put(
            url,
            data=transaction_json,
            format='json'            
        )

        assert response.status_code == 200
        assert json.loads(response.content) == transaction_json

    @pytest.mark.parametrize('field',[
        ('name'),
        ('billing_name'),
        ('billing_email'),
        ('email'),
        ('amount_in_cents'),
        ('message'),
    ])
    def test_partial_update(self, mocker, api_client, field):
        old_transaction = FilledTransactionFactory.build()
        new_transaction = FilledTransactionFactory.build()
        valid_field = {
            field: new_transaction.__dict__[field]
        }
        url = reverse(
            'transaction-detail',
            kwargs={'pk': old_transaction.id}
        )

        SlugRelatedField.to_internal_value = mocker.Mock(
            return_value=old_transaction.currency
        )
        mocker.patch.object(
            TransactionViewset,
            'get_object',
            return_value=old_transaction
        )
        Transaction.save = mocker.Mock()

        response = api_client().patch(
            url,
            data=valid_field,
            format='json'
        )

        assert response.status_code == 200
        assert json.loads(response.content)[field] == valid_field[field]

    def test_delete(self, mocker, api_client):
        transaction = FilledTransactionFactory.build()
        url = reverse('transaction-detail', kwargs={'pk': transaction.id})
        mocker.patch.object(
            TransactionViewset, 'get_object', return_value=transaction
        )
        del_mock = mocker.patch.object(
            Transaction, 'delete'
        )

        response = api_client().delete(
            url
        )

        assert response.status_code == 204
        assert del_mock.assert_called

테스트 스위트의 Coverage

우리는 코드에 다양한 논리적 분기가 있을 것이므로, pytest-cov 플러그인을 사용하여 테스트가 커버하는 코드의 양을 백분율로 나타내는 "커버리지"를 테스트할 수 있다. 테스트의 커버리지를 보기 위해서는 --cov 명령어를 사용해야 한다.

커버리지는 코드가 고장 날지 여부를 알려주지 않고, 단지 테스트로 얼마나 많이 커버했는지만 알려준다. 이는 만든 테스트의 관련성과는 무관하다.

setup.cfg 파일의 [coverage:run] 섹션에 수동으로 커버리지 설정을 수정하는 것이 좋다(또는 독립적인 파일을 원한다면 .coveragerc 파일의 [run] 섹션 안에 설정). 우리가 커버리지 테스트를 원하는 디렉토리와 그 안에서 제외하고 싶은 파일을 설정할 수 있습니다. 저는 모든 앱을 apps 디렉토리 안에 두고 있기 때문에, 제 setup.cfg는 다음과 같이 보일 것이다.

[tool:pytest]
...

[coverage:run]
source=apps
omit=*/migrations/*,

pytest --cov --cov --cov-config= setup.cfg(addopts에 포함시킬 수 있음)를 실행하면 다음과 같은 출력이 나올 수 있다.

100%의 커버리지를 달성했다고 해서 그 자체로 의미가 있는 것은 아니다. 다시 말해, 만약 관련 없는 코드 조각을 테스트하지 않아서 90%의 커버리지를 얻었다면, 그 커버리지는 100%만큼이나 좋을 수 있다.

주의사항: pytest-cov 플러그인은 VSCode 디버거와 호환되지 않는 것으로 보고되었으므로, addopts에서 이 명령어를 제거하거나 때로는 프로젝트에서 완전히 제거할 필요가 있을 수 있습니다.

flaky 테스트에 대하여

플래키 테스트는 서로 간에 테스트를 적절히 격리하지 않아 발생한다. 이는 엔드 투 엔드 테스트를 실행할 때는 받아들일 수 있지만, 단위 테스트에서 발생한다면 심각한 경고 신호가 되어야 한다.

profile
안녕하세요! 질문과 피드백은 언제든지 환영입니다:)

0개의 댓글