[C++] 이진탐색트리 구현(raw pointer, smart pointer)

haeryong·2023년 5월 17일
0

참고한 강의


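STL 컨테이너 중 std::map을 구현해보는 실습을 진행하였다. 실제 std::map은 보통 레드-블랙 트리 같은 균형 이진탐색트리로 구현되어 O(log n) 탐색을 보장하지만, 여기서는 균형을 맞추지 않는 단순 이진탐색트리로 구현했다. 따라서 정렬된 순서로 삽입하면 트리가 한쪽으로 치우쳐 연산이 O(n)까지 느려질 수 있다.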

1 Struct, Class

1-1 Pair

  • std::map의 경우 각 노드마다 std::pair를 저장하고 있다. pair의 first가 key, second가 value이다.
// Excerpt: the key/value element stored in every tree node, modelled on the
// std::pair that std::map stores per node (members elided here).
template <typename T1, typename T2>
struct Pair
{
	T1 first;   // key
	T2 second;  // mapped value

	...
};

1-2 TreeNode

  • Pair에 key와 value를 저장하고, {부모 노드, 왼쪽 자식, 오른쪽 자식}을 크기가 3인 포인터 배열로 가지고 있다.
  • 포인터 배열에서 parent, left, right의 인덱스는 enum class를 정의해 사용하였다.
// Indices into TreeNode::nodeArray. END is the number of slots (the array
// size), not a slot of its own.
enum class NODE_TYPE
{
	PARENT,  // slot 0: parent link
	LEFT,    // slot 1: left child
	RIGHT,   // slot 2: right child
	END      // slot count
};

// Excerpt of the tree node: one key/value Pair plus parent/left/right links
// kept in a pointer array indexed by NODE_TYPE (constructors and helpers elided).
template <typename T1, typename T2>
struct TreeNode
{
	Pair<T1, T2> pair;  // pair.first is the key, pair.second the value
	// nodeArray[PARENT], nodeArray[LEFT], nodeArray[RIGHT]; NODE_TYPE::END is the size.
	TreeNode* nodeArray[static_cast<int>(NODE_TYPE::END)];

	...
};

1-3 BinarySearchTree

  • 이진탐색트리 클래스. 루트노드의 포인터를 멤버변수로 가지고 있다.
  • 내부에 iterator 클래스를 가지고 있다.
// Excerpt of the tree class: it keeps only the root pointer and an element
// count, and nests its own iterator type (members elided).
template <typename T1, typename T2>
class BinarySearchTree
{
	...
    
	// Walks the nodes in key order (see getInOrderSuccessor/Predecessor in the full code).
	class iterator
	{
    	...
	};

	...
	
private:
	TreeNode<T1, T2>* rootNode;  // nullptr when the tree is empty
	int size;                    // number of stored elements

};

1-4 BinarySearchTree<T1, T2>::iterator

  • 현재 iterator가 가리키는 노드의 포인터와, 그 iterator가 속한 이진탐색트리의 포인터를 가지고 있다. 노드 포인터가 nullptr이면 end()/rend()를 의미한다.
  • ++, --, *, ->, ==, != 연산을 지원한다.
	// Excerpt: an iterator is a (tree, node) pair; node == nullptr marks end()/rend().
	class iterator
	{
    	...
	private:
		BinarySearchTree<T1, T2>* bst;  // owning tree; erase() rejects iterators of another tree
		TreeNode<T1, T2>* node;         // current node, or nullptr for end()/rend()
	};

2 주요 Class method

2-1 insert(const Pair<T1, T2>& input)

2-2 find(const T1& key)

2-3 erase(const iterator& iter)

3 사용

#include <iostream>

#include "BinarySearchTree.h"

int main()
{
    BinarySearchTree<int, int> bst;
    bst.insert(makePair(100, 2));
    bst.insert(makePair(150, 2));
    bst.insert(makePair(50, 2));
    bst.insert(makePair(25, 2));
    bst.insert(makePair(75, 2));
    bst.insert(makePair(125, 2));
    bst.insert(makePair(175, 2));

    BinarySearchTree<int, int>::iterator iter = bst.begin();

    iter = bst.find(120);
    if (iter != bst.end())
    {
        std::cout << iter->first << ", " << iter->second << "\n";
    }
    iter = bst.find(150);
    if (iter != bst.end())
    {
        std::cout << iter->first << ", " << iter->second << "\n";
    }

    for (auto it = bst.begin(); it != bst.end(); ++it)
    {
        std::cout << it->first << ", " << it->second << "\n";
    }
    
    for (auto it = bst.rbegin(); it != bst.rend(); --it)
    {
        std::cout << it->first << ", " << it->second << "\n";
    }

    iter = bst.find(150);
    iter = bst.erase(iter);

    return 0;
}

4 전체 코드

4-1 Raw pointer 버전

  • BinarySearchTree.h
#pragma once

#include <exception>
#include <stdexcept>

/// Slot indices for TreeNode::nodeArray. END is the slot count (array size),
/// not a usable slot.
enum class NODE_TYPE
{
	PARENT = 0,
	LEFT = 1,
	RIGHT = 2,
	END = 3,
};

/// Minimal key/value pair, the element type stored in every TreeNode
/// (first = key, second = mapped value), modelled on std::pair.
template <typename T1, typename T2>
struct Pair
{
	T1 first;
	T2 second;

	/// Value-initializes both members (0 for arithmetic types, nullptr for
	/// pointers) instead of leaving built-in types indeterminate.
	Pair()
		: first()
		, second()
	{}

	/// Copies the given key and value. The arguments are taken by const
	/// reference so each one is copied once, straight into its member.
	Pair(const T1& first, const T2& second)
		: first(first)
		, second(second)
	{}
};

/// Factory in the spirit of std::make_pair: deduces T1 and T2 from the
/// arguments and returns the corresponding Pair.
template <typename T1, typename T2>
Pair<T1, T2> makePair(const T1& first, const T2& second)
{
	return { first, second };
}

/// Node of the raw-pointer binary search tree.
///
/// Holds one key/value Pair plus three links in nodeArray, indexed by
/// NODE_TYPE: PARENT, LEFT and RIGHT. The links are non-owning; nodes are
/// allocated and freed by BinarySearchTree.
template <typename T1, typename T2>
struct TreeNode
{
	Pair<T1, T2> pair;
	TreeNode* nodeArray[static_cast<int>(NODE_TYPE::END)];

	TreeNode()
		: pair()
		, nodeArray{nullptr, nullptr, nullptr}
	{}

	TreeNode(const Pair<T1, T2>& pair)
		: pair(pair)
		, nodeArray{nullptr, nullptr, nullptr}
	{}

	TreeNode(const Pair<T1, T2>& pair, TreeNode* parent, TreeNode* left, TreeNode* right)
		: pair(pair)
		, nodeArray{parent, left, right}
	{}

	~TreeNode() = default;

	/// True when the node has no parent link.
	bool isRoot() const
	{
		return nodeArray[static_cast<int>(NODE_TYPE::PARENT)] == nullptr;
	}

	/// True when this node is its parent's left child; false for the root
	/// (guarded so the parent link is never dereferenced when it is null).
	bool isLeftChild() const
	{
		const TreeNode* parent = nodeArray[static_cast<int>(NODE_TYPE::PARENT)];
		return parent != nullptr && parent->nodeArray[static_cast<int>(NODE_TYPE::LEFT)] == this;
	}

	/// True when this node is its parent's right child; false for the root.
	bool isRightChild() const
	{
		const TreeNode* parent = nodeArray[static_cast<int>(NODE_TYPE::PARENT)];
		return parent != nullptr && parent->nodeArray[static_cast<int>(NODE_TYPE::RIGHT)] == this;
	}

	/// True when the node has neither a left nor a right child.
	bool isLeaf() const
	{
		return nodeArray[static_cast<int>(NODE_TYPE::LEFT)] == nullptr
			&& nodeArray[static_cast<int>(NODE_TYPE::RIGHT)] == nullptr;
	}

	/// True when the node has both a left and a right child.
	bool isFull() const
	{
		return nodeArray[static_cast<int>(NODE_TYPE::LEFT)] != nullptr
			&& nodeArray[static_cast<int>(NODE_TYPE::RIGHT)] != nullptr;
	}
};

template <typename T1, typename T2>
class BinarySearchTree
{
public:
	BinarySearchTree();

	~BinarySearchTree()
	{
		auto iter = this->begin();
		while (iter != this->end())
		{
			iter = this->erase(iter);
		}
	}

	int getSize() const { return size; }

	TreeNode<T1, T2>* getInOrderSuccessor(TreeNode<T1, T2>* node);
	TreeNode<T1, T2>* getInOrderPredecessor(TreeNode<T1, T2>* node);

	class iterator
	{
	public:
		iterator()
			: bst(nullptr)
			, node(nullptr)
		{}

		iterator(BinarySearchTree<T1, T2>* bst, TreeNode<T1, T2>* node)
			: bst(bst)
			, node(node)
		{}

		bool operator ==(const iterator& other)
		{
			return (bst == other.bst && node == other.node);
		}

		bool operator !=(const iterator& other)
		{
			return !(*this == other);
		}

		const Pair<T1, T2>& operator *()
		{
			if (node == nullptr)
			{
				throw std::runtime_error("Attempted to dereference a null pointer.");
			}
			return node->pair;
		}

		const Pair<T1, T2>* operator ->()
		{
			if (node == nullptr)
			{
				throw std::runtime_error("Attempted to dereference a null pointer.");
			}
			return &(node->pair);
		}

		iterator& operator ++()
		{
			node = bst->getInOrderSuccessor(node);
			return *this;
		}

		iterator& operator --()
		{
			node = bst->getInOrderPredecessor(node);
			return *this;
		}

	private:
		BinarySearchTree<T1, T2>* bst;
		TreeNode<T1, T2>* node;

		friend class BinarySearchTree<T1, T2>;
	};

	iterator begin();
	iterator end();
	iterator rbegin();
	iterator rend();

	bool insert(const Pair<T1, T2>& input);
	iterator find(const T1& key);
	iterator erase(const iterator& iter);

private:
	TreeNode<T1, T2>* deleteNode(TreeNode<T1, T2>* targetNode);

	TreeNode<T1, T2>* rootNode;
	int size;

};

template <typename T1, typename T2>
inline BinarySearchTree<T1, T2>::BinarySearchTree()
	: rootNode(nullptr)
	, size(0)
{}

/// Inserts input unless a node with an equal key already exists.
///
/// The key's slot is located first and the node is allocated only afterwards.
/// A rejected duplicate therefore allocates nothing, so it cannot leak a node.
///
/// @param input key/value pair copied into the new node.
/// @return true if a node was added, false if the key was already present
///         (the tree is then unchanged).
template <typename T1, typename T2>
inline bool BinarySearchTree<T1, T2>::insert(const Pair<T1, T2>& input)
{
	TreeNode<T1, T2>* parentNode = nullptr;
	NODE_TYPE side = NODE_TYPE::END;
	TreeNode<T1, T2>* currentNode = rootNode;

	// Descend to the empty child slot where the key belongs.
	while (currentNode != nullptr)
	{
		if (currentNode->pair.first < input.first)
		{
			side = NODE_TYPE::RIGHT;
		}
		else if (currentNode->pair.first > input.first)
		{
			side = NODE_TYPE::LEFT;
		}
		else
		{
			return false;
		}

		parentNode = currentNode;
		currentNode = currentNode->nodeArray[static_cast<int>(side)];
	}

	TreeNode<T1, T2>* newNode = new TreeNode<T1, T2>(input, parentNode, nullptr, nullptr);
	if (parentNode == nullptr)
	{
		rootNode = newNode;
	}
	else
	{
		parentNode->nodeArray[static_cast<int>(side)] = newNode;
	}

	++size;
	return true;
}

/// Returns the node holding the next larger key after node, or nullptr when
/// node is nullptr or already holds the largest key.
template<typename T1, typename T2>
inline TreeNode<T1, T2>* BinarySearchTree<T1, T2>::getInOrderSuccessor(TreeNode<T1, T2>* node)
{
	if (node == nullptr)
	{
		return nullptr;
	}

	constexpr int kParent = static_cast<int>(NODE_TYPE::PARENT);
	constexpr int kLeft = static_cast<int>(NODE_TYPE::LEFT);
	constexpr int kRight = static_cast<int>(NODE_TYPE::RIGHT);

	// A right subtree exists: its leftmost node comes next.
	TreeNode<T1, T2>* cursor = node->nodeArray[kRight];
	if (cursor != nullptr)
	{
		while (cursor->nodeArray[kLeft] != nullptr)
		{
			cursor = cursor->nodeArray[kLeft];
		}
		return cursor;
	}

	// Otherwise climb: the first ancestor entered from its left side is next.
	// Running out of ancestors means node holds the largest key.
	cursor = node;
	while (!cursor->isRoot())
	{
		const bool cameFromLeft = cursor->isLeftChild();
		cursor = cursor->nodeArray[kParent];
		if (cameFromLeft)
		{
			return cursor;
		}
	}
	return nullptr;
}

/// Returns the node holding the next smaller key before node, or nullptr when
/// node is nullptr or already holds the smallest key.
template<typename T1, typename T2>
inline TreeNode<T1, T2>* BinarySearchTree<T1, T2>::getInOrderPredecessor(TreeNode<T1, T2>* node)
{
	if (node == nullptr)
	{
		return nullptr;
	}

	constexpr int kParent = static_cast<int>(NODE_TYPE::PARENT);
	constexpr int kLeft = static_cast<int>(NODE_TYPE::LEFT);
	constexpr int kRight = static_cast<int>(NODE_TYPE::RIGHT);

	// A left subtree exists: its rightmost node comes before.
	TreeNode<T1, T2>* cursor = node->nodeArray[kLeft];
	if (cursor != nullptr)
	{
		while (cursor->nodeArray[kRight] != nullptr)
		{
			cursor = cursor->nodeArray[kRight];
		}
		return cursor;
	}

	// Otherwise climb: the first ancestor entered from its right side comes before.
	// Running out of ancestors means node holds the smallest key.
	cursor = node;
	while (!cursor->isRoot())
	{
		const bool cameFromRight = cursor->isRightChild();
		cursor = cursor->nodeArray[kParent];
		if (cameFromRight)
		{
			return cursor;
		}
	}
	return nullptr;
}

/// Returns an iterator to the smallest key, or end() when the tree is empty.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::begin()
{
	TreeNode<T1, T2>* currentNode = rootNode;

	// An empty tree has rootNode == nullptr. The loop then never runs and the
	// null node yields end(), instead of dereferencing a null root.
	while (currentNode != nullptr && currentNode->nodeArray[static_cast<int>(NODE_TYPE::LEFT)] != nullptr)
	{
		currentNode = currentNode->nodeArray[static_cast<int>(NODE_TYPE::LEFT)];
	}
	return iterator(this, currentNode);
}

/// Returns an iterator to the largest key, or rend() when the tree is empty.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::rbegin()
{
	TreeNode<T1, T2>* currentNode = rootNode;

	// An empty tree has rootNode == nullptr. The loop then never runs and the
	// null node yields rend(), instead of dereferencing a null root.
	while (currentNode != nullptr && currentNode->nodeArray[static_cast<int>(NODE_TYPE::RIGHT)] != nullptr)
	{
		currentNode = currentNode->nodeArray[static_cast<int>(NODE_TYPE::RIGHT)];
	}
	return iterator(this, currentNode);
}

/// Past-the-end iterator: belongs to this tree and points at no node.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::end()
{
	TreeNode<T1, T2>* const pastTheEnd = nullptr;
	return iterator{ this, pastTheEnd };
}

/// Before-the-beginning sentinel for reverse walks. Like end(), it points at
/// no node, so the two compare equal.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::rend()
{
	TreeNode<T1, T2>* const beforeTheBegin = nullptr;
	return iterator{ this, beforeTheBegin };
}

/// Looks key up by descending from the root.
///
/// @return iterator to the node whose pair.first equals key, or end() when no
///         such node exists (including when the tree is empty).
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::find(const T1& key)
{
	TreeNode<T1, T2>* probe = rootNode;
	while (probe != nullptr)
	{
		if (probe->pair.first < key)
		{
			probe = probe->nodeArray[static_cast<int>(NODE_TYPE::RIGHT)];
		}
		else if (probe->pair.first > key)
		{
			probe = probe->nodeArray[static_cast<int>(NODE_TYPE::LEFT)];
		}
		else
		{
			return iterator(this, probe);
		}
	}
	return end();
}

/// Removes targetNode's key from the tree.
///
/// - Two children: the in-order successor's pair is copied into targetNode
///   and the successor's node is removed recursively. targetNode stays in the
///   tree and is returned.
/// - Zero or one child: that child (nullptr for a leaf) takes targetNode's
///   place under its parent, or becomes the new root. targetNode is deleted.
///
/// @param targetNode node owned by this tree; must not be nullptr.
/// @return node holding the next key in order after the removal, or nullptr
///         if the removed key was the largest.
template<typename T1, typename T2>
inline TreeNode<T1, T2>* BinarySearchTree<T1, T2>::deleteNode(TreeNode<T1, T2>* targetNode)
{
	constexpr int kParent = static_cast<int>(NODE_TYPE::PARENT);
	constexpr int kLeft = static_cast<int>(NODE_TYPE::LEFT);
	constexpr int kRight = static_cast<int>(NODE_TYPE::RIGHT);

	// Resolved before any link is touched, while the tree is still intact.
	TreeNode<T1, T2>* const nextNode = getInOrderSuccessor(targetNode);

	if (targetNode->isFull())
	{
		targetNode->pair = nextNode->pair;
		deleteNode(nextNode);
		return targetNode;
	}

	// At most one child: the right one if present, else the left (nullptr for a leaf).
	TreeNode<T1, T2>* const replacement = (targetNode->nodeArray[kRight] != nullptr)
		? targetNode->nodeArray[kRight]
		: targetNode->nodeArray[kLeft];
	const bool removingRoot = (targetNode == rootNode);

	if (removingRoot)
	{
		rootNode = replacement;
	}
	else
	{
		// Pick the side before rewriting the parent's link, which isLeftChild() reads.
		const int side = targetNode->isLeftChild() ? kLeft : kRight;
		targetNode->nodeArray[kParent]->nodeArray[side] = replacement;
	}

	if (replacement != nullptr)
	{
		replacement->nodeArray[kParent] = removingRoot ? nullptr : targetNode->nodeArray[kParent];
	}

	delete targetNode;
	--size;
	return nextNode;
}


/// Removes the element that iter refers to.
///
/// If the erased node had two children, the successor's node is freed and its
/// pair moved into iter's node. Any other iterator to that successor then
/// dangles.
///
/// @param iter iterator obtained from this tree; must not be end()/rend().
/// @return iterator to the element that followed the erased one in key
///         order, or end() if the erased element was the largest.
/// @throws std::invalid_argument if iter belongs to another tree or points at
///         no node (end()/rend()/default-constructed), which would otherwise
///         dereference a null node.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::erase(const iterator& iter)
{
	if (this != iter.bst)
	{
		throw std::invalid_argument("Invalid iterator");
	}
	if (iter.node == nullptr)
	{
		throw std::invalid_argument("Cannot erase end iterator");
	}

	TreeNode<T1, T2>* successor = deleteNode(iter.node);
	return iterator(this, successor);
}


4-2 Smart Pointer 버전

  • BinarySearchTree.h
  • 자식 노드는 std::shared_ptr로 소유하고, 부모 노드는 순환 참조를 피하기 위해 std::weak_ptr로 가리킨다.

#pragma once

#include <array>
#include <exception>
#include <memory>

/// Indices into TreeNode::childNodes. END is the number of child slots (the
/// array size), not a usable slot.
enum class CHILD_TYPE
{
	LEFT = 0,
	RIGHT = 1,
	END = 2,
};

/// Key/value element stored in every TreeNode: first is the key, second the
/// mapped value (modelled on std::pair).
template <typename T1, typename T2>
struct Pair
{
	T1 first;
	T2 second;

	/// Both members are value-initialized (e.g. 0 for int).
	constexpr Pair() : first(T1()), second(T2()) {}

	/// Copies the given key and value into the pair.
	Pair(const T1& key, const T2& value) : first(key), second(value) {}
};

/// Factory in the spirit of std::make_pair: deduces T1 and T2 from the
/// arguments and returns the corresponding Pair.
template <typename T1, typename T2>
Pair<T1, T2> makePair(const T1& first, const T2& second)
{
	return { first, second };
}

/// Node of the smart-pointer binary search tree.
///
/// Children are owned through std::shared_ptr. The parent is observed through
/// a std::weak_ptr, so parent <-> child links do not form an ownership cycle.
template <typename T1, typename T2>
struct TreeNode
{
	Pair<T1, T2> pair;
	std::weak_ptr<TreeNode> parentNode;
	std::array<std::shared_ptr<TreeNode>, static_cast<int>(CHILD_TYPE::END)> childNodes;

	TreeNode()
		: pair()
		, parentNode()
		, childNodes{ nullptr, nullptr }
	{}

	TreeNode(const Pair<T1, T2>& pair)
		: pair(pair)
		, parentNode()
		, childNodes{ nullptr, nullptr }
	{}

	TreeNode(const Pair<T1, T2>& pair, std::shared_ptr<TreeNode> parent, std::shared_ptr<TreeNode> left, std::shared_ptr<TreeNode> right)
		: pair(pair)
		, parentNode(parent)
		, childNodes{ left, right }
	{}

	~TreeNode() = default;

	/// True when there is no (live) parent.
	bool isRoot() const
	{
		return parentNode.expired();
	}

	/// True when this node is its parent's left child; false for the root.
	bool isLeftChild() const
	{
		// Lock once and test the result rather than calling expired() and then lock().
		const std::shared_ptr<TreeNode> parent = parentNode.lock();
		return parent != nullptr && parent->childNodes[static_cast<int>(CHILD_TYPE::LEFT)].get() == this;
	}

	/// True when this node is its parent's right child; false for the root.
	bool isRightChild() const
	{
		const std::shared_ptr<TreeNode> parent = parentNode.lock();
		return parent != nullptr && parent->childNodes[static_cast<int>(CHILD_TYPE::RIGHT)].get() == this;
	}

	/// True when the node has neither a left nor a right child.
	bool isLeaf() const
	{
		return childNodes[static_cast<int>(CHILD_TYPE::LEFT)] == nullptr
			&& childNodes[static_cast<int>(CHILD_TYPE::RIGHT)] == nullptr;
	}

	/// True when the node has both a left and a right child.
	bool isFull() const
	{
		return childNodes[static_cast<int>(CHILD_TYPE::LEFT)] != nullptr
			&& childNodes[static_cast<int>(CHILD_TYPE::RIGHT)] != nullptr;
	}
};

/// Ordered key/value container modelled on std::map, implemented as an
/// unbalanced binary search tree. Child links own nodes (std::shared_ptr) and
/// parent links observe them (std::weak_ptr).
///
/// Copying is a deep copy. A compiler-generated copy would share the same
/// nodes between two trees, so erasing from one would corrupt the other.
/// The destructor tears the tree down iteratively. The default chain of
/// shared_ptr destructions recurses once per tree level and can exhaust the
/// stack on a degenerate tree (e.g. keys inserted in sorted order).
template <typename T1, typename T2>
class BinarySearchTree
{
public:
	BinarySearchTree();

	/// Deep copy: clones every node of other, preserving the tree shape.
	BinarySearchTree(const BinarySearchTree& other)
		: rootNode(cloneTree(other.rootNode))
		, size(other.size)
	{}

	/// Takes over other's nodes and leaves other empty.
	BinarySearchTree(BinarySearchTree&& other) noexcept
		: rootNode()
		, size(other.size)
	{
		rootNode.swap(other.rootNode);
		other.size = 0;
	}

	/// Replaces the contents with a deep copy of other. The copy is built
	/// before the current nodes are released, so a failed allocation leaves
	/// *this unchanged.
	BinarySearchTree& operator=(const BinarySearchTree& other)
	{
		if (this != &other)
		{
			std::shared_ptr<TreeNode<T1, T2>> copy = cloneTree(other.rootNode);
			destroyTree(rootNode);
			rootNode = copy;
			size = other.size;
		}
		return *this;
	}

	/// Releases the current nodes, then takes over other's nodes (other becomes empty).
	BinarySearchTree& operator=(BinarySearchTree&& other) noexcept
	{
		if (this != &other)
		{
			destroyTree(rootNode);
			rootNode.swap(other.rootNode);
			size = other.size;
			other.size = 0;
		}
		return *this;
	}

	/// Releases all nodes without recursion (see destroyTree).
	~BinarySearchTree()
	{
		destroyTree(rootNode);
	}

	/// Number of stored elements.
	int getSize() const { return size; }

	/// Next / previous node in key order, or nullptr past either end.
	std::shared_ptr<TreeNode<T1, T2>> getInOrderSuccessor(std::shared_ptr<TreeNode<T1, T2>> node);
	std::shared_ptr<TreeNode<T1, T2>> getInOrderPredecessor(std::shared_ptr<TreeNode<T1, T2>> node);

	/// Bidirectional iterator over the tree in key order. It shares ownership
	/// of its current node. A null node marks end()/rend(), and dereferencing
	/// such an iterator throws std::runtime_error.
	class iterator
	{
	public:
		iterator()
			: bst(nullptr)
			, node(nullptr)
		{}

		iterator(BinarySearchTree<T1, T2>* bst, std::shared_ptr<TreeNode<T1, T2>> node)
			: bst(bst)
			, node(node)
		{}

		bool operator ==(const iterator& other) const
		{
			return bst == other.bst && node == other.node;
		}

		bool operator !=(const iterator& other) const
		{
			return !(*this == other);
		}

		const Pair<T1, T2>& operator *() const
		{
			if (node == nullptr)
			{
				throw std::runtime_error("Attempted to dereference a null pointer.");
			}
			return node->pair;
		}

		/// Points directly at the stored pair, with no per-access heap
		/// allocation or copy. The pointer stays valid while this iterator
		/// (which co-owns the node) is alive.
		const Pair<T1, T2>* operator ->() const
		{
			if (node == nullptr)
			{
				throw std::runtime_error("Attempted to dereference a null pointer.");
			}
			return &(node->pair);
		}

		/// Moves to the next larger key (becomes end() after the largest).
		iterator& operator ++()
		{
			node = bst->getInOrderSuccessor(node);
			return *this;
		}

		/// Moves to the next smaller key (becomes rend() after the smallest).
		iterator& operator --()
		{
			node = bst->getInOrderPredecessor(node);
			return *this;
		}

	private:
		BinarySearchTree<T1, T2>* bst;
		std::shared_ptr<TreeNode<T1, T2>> node;

		friend class BinarySearchTree<T1, T2>;
	};

	iterator begin();
	iterator end();
	iterator rbegin();
	iterator rend();

	bool insert(const Pair<T1, T2>& input);
	iterator find(const T1& key);
	iterator erase(const iterator& iter);

private:
	std::shared_ptr<TreeNode<T1, T2>> deleteNode(std::shared_ptr<TreeNode<T1, T2>> targetNode);

	/// Deep-copies the whole tree rooted at source without recursion: it walks
	/// source and the copy in lockstep, going down into a not-yet-copied child
	/// or back up through the parent links. If an allocation (or a Pair copy)
	/// throws, the partial copy is released and the exception is rethrown.
	static std::shared_ptr<TreeNode<T1, T2>> cloneTree(const std::shared_ptr<TreeNode<T1, T2>>& source)
	{
		const int left = static_cast<int>(CHILD_TYPE::LEFT);
		const int right = static_cast<int>(CHILD_TYPE::RIGHT);

		if (source == nullptr)
		{
			return nullptr;
		}

		std::shared_ptr<TreeNode<T1, T2>> copyRoot = std::make_shared<TreeNode<T1, T2>>(source->pair);
		try
		{
			const TreeNode<T1, T2>* src = source.get();
			std::shared_ptr<TreeNode<T1, T2>> dst = copyRoot;
			while (true)
			{
				if (src->childNodes[left] != nullptr && dst->childNodes[left] == nullptr)
				{
					dst->childNodes[left] = std::make_shared<TreeNode<T1, T2>>(src->childNodes[left]->pair, dst, nullptr, nullptr);
					src = src->childNodes[left].get();
					dst = dst->childNodes[left];
				}
				else if (src->childNodes[right] != nullptr && dst->childNodes[right] == nullptr)
				{
					dst->childNodes[right] = std::make_shared<TreeNode<T1, T2>>(src->childNodes[right]->pair, dst, nullptr, nullptr);
					src = src->childNodes[right].get();
					dst = dst->childNodes[right];
				}
				else if (src == source.get())
				{
					break;  // both subtrees of the root have been copied
				}
				else
				{
					// The source parent stays owned by the source tree, so the raw pointer remains valid.
					src = src->parentNode.lock().get();
					dst = dst->parentNode.lock();
				}
			}
		}
		catch (...)
		{
			destroyTree(copyRoot);
			throw;
		}
		return copyRoot;
	}

	/// Empties tree and releases its nodes in O(n) without recursion.
	/// While the current node has a left child, a right rotation lifts that
	/// child up. Once there is no left child, the right subtree is detached
	/// first, so dropping the node never triggers a nested destruction chain.
	/// Nodes still shared by outstanding iterators survive, with no children.
	static void destroyTree(std::shared_ptr<TreeNode<T1, T2>>& tree) noexcept
	{
		const int left = static_cast<int>(CHILD_TYPE::LEFT);
		const int right = static_cast<int>(CHILD_TYPE::RIGHT);

		std::shared_ptr<TreeNode<T1, T2>> current = tree;
		tree.reset();
		while (current != nullptr)
		{
			if (current->childNodes[left] != nullptr)
			{
				std::shared_ptr<TreeNode<T1, T2>> pivot = current->childNodes[left];
				current->childNodes[left] = pivot->childNodes[right];
				pivot->childNodes[right] = current;
				current = pivot;
			}
			else
			{
				std::shared_ptr<TreeNode<T1, T2>> next = current->childNodes[right];
				current->childNodes[right].reset();
				current = next;
			}
		}
	}

	std::shared_ptr<TreeNode<T1, T2>> rootNode;
	int size;

};

template <typename T1, typename T2>
inline BinarySearchTree<T1, T2>::BinarySearchTree()
	: rootNode(nullptr)
	, size(0)
{}

/// Inserts input unless a node with an equal key already exists.
///
/// The key's slot is located first and the node is allocated only afterwards.
/// A rejected duplicate therefore performs no allocation.
///
/// @param input key/value pair copied into the new node.
/// @return true if a node was added, false if the key was already present
///         (the tree is then unchanged).
template <typename T1, typename T2>
inline bool BinarySearchTree<T1, T2>::insert(const Pair<T1, T2>& input)
{
	std::shared_ptr<TreeNode<T1, T2>> parent;
	CHILD_TYPE side = CHILD_TYPE::END;
	std::shared_ptr<TreeNode<T1, T2>> cursor = rootNode;

	// Descend to the empty child slot where the key belongs.
	while (cursor != nullptr)
	{
		if (cursor->pair.first < input.first)
		{
			side = CHILD_TYPE::RIGHT;
		}
		else if (cursor->pair.first > input.first)
		{
			side = CHILD_TYPE::LEFT;
		}
		else
		{
			return false;
		}

		parent = cursor;
		cursor = cursor->childNodes[static_cast<int>(side)];
	}

	// An empty parent yields an expired weak_ptr, i.e. the new node is the root.
	const std::shared_ptr<TreeNode<T1, T2>> newNode = std::make_shared<TreeNode<T1, T2>>(input, parent, nullptr, nullptr);
	if (parent == nullptr)
	{
		rootNode = newNode;
	}
	else
	{
		parent->childNodes[static_cast<int>(side)] = newNode;
	}

	++size;
	return true;
}

/// Returns the node holding the next larger key after node, or nullptr when
/// node is empty or already holds the largest key.
template<typename T1, typename T2>
inline std::shared_ptr<TreeNode<T1, T2>> BinarySearchTree<T1, T2>::getInOrderSuccessor(std::shared_ptr<TreeNode<T1, T2>> node)
{
	if (!node)
	{
		return nullptr;
	}

	const int kLeft = static_cast<int>(CHILD_TYPE::LEFT);
	const int kRight = static_cast<int>(CHILD_TYPE::RIGHT);

	// A right subtree exists: its leftmost node comes next.
	std::shared_ptr<TreeNode<T1, T2>> cursor = node->childNodes[kRight];
	if (cursor)
	{
		while (cursor->childNodes[kLeft])
		{
			cursor = cursor->childNodes[kLeft];
		}
		return cursor;
	}

	// Otherwise climb: the first ancestor entered from its left side is next.
	cursor = node;
	while (!cursor->isRoot())
	{
		const bool cameFromLeft = cursor->isLeftChild();
		cursor = cursor->parentNode.lock();
		if (cameFromLeft)
		{
			return cursor;
		}
	}
	return nullptr;
}

/// Returns the node holding the next smaller key before node, or nullptr when
/// node is empty or already holds the smallest key.
template<typename T1, typename T2>
inline std::shared_ptr<TreeNode<T1, T2>> BinarySearchTree<T1, T2>::getInOrderPredecessor(std::shared_ptr<TreeNode<T1, T2>> node)
{
	if (!node)
	{
		return nullptr;
	}

	const int kLeft = static_cast<int>(CHILD_TYPE::LEFT);
	const int kRight = static_cast<int>(CHILD_TYPE::RIGHT);

	// A left subtree exists: its rightmost node comes before.
	std::shared_ptr<TreeNode<T1, T2>> cursor = node->childNodes[kLeft];
	if (cursor)
	{
		while (cursor->childNodes[kRight])
		{
			cursor = cursor->childNodes[kRight];
		}
		return cursor;
	}

	// Otherwise climb: the first ancestor entered from its right side comes before.
	cursor = node;
	while (!cursor->isRoot())
	{
		const bool cameFromRight = cursor->isRightChild();
		cursor = cursor->parentNode.lock();
		if (cameFromRight)
		{
			return cursor;
		}
	}
	return nullptr;
}

/// Returns an iterator to the smallest key, or end() when the tree is empty.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::begin()
{
	std::shared_ptr<TreeNode<T1, T2>> currentNode = rootNode;

	// An empty tree has an empty rootNode. The loop then never runs and the
	// empty node yields end(), instead of dereferencing a null root.
	while (currentNode != nullptr && currentNode->childNodes[static_cast<int>(CHILD_TYPE::LEFT)] != nullptr)
	{
		currentNode = currentNode->childNodes[static_cast<int>(CHILD_TYPE::LEFT)];
	}
	return iterator(this, currentNode);
}

/// Returns an iterator to the largest key, or rend() when the tree is empty.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::rbegin()
{
	std::shared_ptr<TreeNode<T1, T2>> currentNode = rootNode;

	// An empty tree has an empty rootNode. The loop then never runs and the
	// empty node yields rend(), instead of dereferencing a null root.
	while (currentNode != nullptr && currentNode->childNodes[static_cast<int>(CHILD_TYPE::RIGHT)] != nullptr)
	{
		currentNode = currentNode->childNodes[static_cast<int>(CHILD_TYPE::RIGHT)];
	}
	return iterator(this, currentNode);
}

/// Past-the-end iterator: belongs to this tree and holds no node.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::end()
{
	const std::shared_ptr<TreeNode<T1, T2>> pastTheEnd;
	return iterator{ this, pastTheEnd };
}

/// Before-the-beginning sentinel for reverse walks. Like end(), it holds no
/// node, so the two compare equal.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::rend()
{
	const std::shared_ptr<TreeNode<T1, T2>> beforeTheBegin;
	return iterator{ this, beforeTheBegin };
}

/// Looks key up by descending from the root.
///
/// @return iterator to the node whose pair.first equals key, or end() when no
///         such node exists (including when the tree is empty).
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::find(const T1& key)
{
	std::shared_ptr<TreeNode<T1, T2>> probe = rootNode;
	while (probe != nullptr)
	{
		if (probe->pair.first < key)
		{
			probe = probe->childNodes[static_cast<int>(CHILD_TYPE::RIGHT)];
		}
		else if (probe->pair.first > key)
		{
			probe = probe->childNodes[static_cast<int>(CHILD_TYPE::LEFT)];
		}
		else
		{
			return iterator(this, probe);
		}
	}
	return end();
}

/// Removes targetNode's key from the tree.
///
/// - Two children: the in-order successor's pair is copied into targetNode
///   and the successor's node is removed recursively. targetNode stays in the
///   tree and is returned.
/// - Zero or one child: that child (empty for a leaf) takes targetNode's
///   place under its parent, or becomes the new root. targetNode is unlinked
///   and freed once no iterator still shares it.
///
/// @param targetNode node owned by this tree; must not be empty.
/// @return node holding the next key in order after the removal, or nullptr
///         if the removed key was the largest.
template<typename T1, typename T2>
inline std::shared_ptr<TreeNode<T1, T2>> BinarySearchTree<T1, T2>::deleteNode(std::shared_ptr<TreeNode<T1, T2>> targetNode)
{
	const int kLeft = static_cast<int>(CHILD_TYPE::LEFT);
	const int kRight = static_cast<int>(CHILD_TYPE::RIGHT);

	// Resolved before any link is touched, while the tree is still intact.
	const std::shared_ptr<TreeNode<T1, T2>> nextNode = getInOrderSuccessor(targetNode);

	if (targetNode->isFull())
	{
		targetNode->pair = nextNode->pair;
		deleteNode(nextNode);
		return targetNode;
	}

	// At most one child: the right one if present, else the left (empty for a leaf).
	const std::shared_ptr<TreeNode<T1, T2>> replacement = (targetNode->childNodes[kRight] != nullptr)
		? targetNode->childNodes[kRight]
		: targetNode->childNodes[kLeft];
	const bool removingRoot = (targetNode == rootNode);

	if (removingRoot)
	{
		rootNode = replacement;
	}
	else
	{
		// Pick the side before rewriting the parent's link, which isLeftChild() reads.
		const int side = targetNode->isLeftChild() ? kLeft : kRight;
		targetNode->parentNode.lock()->childNodes[side] = replacement;
	}

	if (replacement != nullptr)
	{
		if (removingRoot)
		{
			replacement->parentNode.reset();
		}
		else
		{
			replacement->parentNode = targetNode->parentNode;
		}
	}

	--size;
	return nextNode;
}


/// Removes the element that iter refers to.
///
/// If the erased node had two children, the successor's node is unlinked and
/// its pair moved into iter's node. Other iterators to that successor then
/// refer to a detached node.
///
/// @param iter iterator obtained from this tree; must not be end()/rend().
/// @return iterator to the element that followed the erased one in key
///         order, or end() if the erased element was the largest.
/// @throws std::invalid_argument if iter belongs to another tree or holds no
///         node (end()/rend()/default-constructed), which would otherwise
///         dereference an empty pointer.
template<typename T1, typename T2>
inline typename BinarySearchTree<T1, T2>::iterator BinarySearchTree<T1, T2>::erase(const iterator& iter)
{
	if (this != iter.bst)
	{
		throw std::invalid_argument("Invalid iterator");
	}
	if (iter.node == nullptr)
	{
		throw std::invalid_argument("Cannot erase end iterator");
	}

	std::shared_ptr<TreeNode<T1, T2>> successor = deleteNode(iter.node);
	return iterator(this, successor);
}

0개의 댓글