이번에 알고리즘 문제를 풀며 익숙했던 C 스타일의 반복문 대신, 좀 더 'Java스러운' 방식으로 풀고 싶어 Stream을 적극적으로 활용했다. 그러던 중 헷갈렸던 부분들을 정리해 두고자 이 글을 쓴다.
Stream API는 java.util.stream 패키지에 포함되어 있으며, 공식 Java Doc에서는 아래와 같이 설명한다.
Classes to support functional-style operations on streams of elements, such as map-reduce transformations on collections.
java.util.stream (Java Platform SE 8 )
이어지는 java.util.stream.Stream의 설명은 다음과 같다.
A sequence of elements supporting sequential and parallel aggregate operations.
Stream (Java Platform SE 8 )
즉, 컬렉션 등의 데이터 소스에서 요소들을 스트림 형태로 꺼내 함수형 스타일의 연산을 적용할 수 있고, 순차적 또는 병렬 집계 연산을 지원하는 요소들의 시퀀스 라고 볼 수 있다. 여기서 sequential과 parallel은 각각 순차 스트림과 병렬 스트림처럼 순차 또는 병렬 방식으로 집계 연산을 수행할 수 있다는 의미이며, 이 시퀀스는 데이터 소스에 따라 순서가 정의될 수도 있고, 정의되지 않을 수도 있다.
이는 Java 8에서 도입된 기능으로, Collection, 배열, I/O 등 데이터 소스에서 추출한 요소들에 대해 map, filter, sorted, reduce와 같은 함수형 연산들을 파이프라인 형태로 구성할 수 있게 해 준다. 이 파이프라인은 중간 연산과 최종 연산으로 이루어지며, 중간 연산은 지연(lazy) 평가되고, 최종 연산이 호출될 때 실제로 실행된다.
Java 8의 람다식, 메서드 참조, 함수형 인터페이스와 함께, Stream API는 기존의 객체지향적인 Java 코드에서도 함수형 스타일로 데이터를 처리할 수 있게 해 주는 핵심 구성 요소라고 이해할 수 있다.
Stream 연산은 크게 두 종류로 나뉜다.
map, filter, sorted, distinct, limit, skip 등forEach, collect, reduce, count, anyMatch, findFirst 등이 연산들이 실제로 어떻게 적용되는지 아래 코드 예시와 함께 알아보자.
// Stream API 예시
Arrays.asList(1, 2, 3, 4).stream()
.filter(n -> n > 2) // 3, 4 만 통과
.map(n -> n * 10) // 30, 40 으로 변환
.collect(Collectors.toList()); // [30, 40]
// 결과 = [30, 40]
이 코드는 Collection에서 stream() 메서드를 활용하여 Stream<T>으로 변환 후, filter와 map이라는 중간 연산을 체인으로 이어 붙이고, 마지막 collect라는 최종 연산이 호출되는 시점에 비로소 전체 파이프라인이 실행된다.
이제 이 코드가 Stream에서 어떻게 지연 평가되는지, 아래 다이어그램을 통해 확인해보자.
위 다이어그램은 이해를 위해 단계별로 나눠 그린 것이고, 실제로는 각 요소가 1 -> filter -> map -> collect 순서로 차례대로 파이프라인을 통과한다. 이 예시는 단일 스레드에서 순차적으로 실행되는 순차 스트림이다.
Stream은 일회성(one-shot)이라는 중요한 특징을 가진다.
Stream<String> stream = list.stream();
long count = stream.count();
// 이미 소비된 stream에 대해 다시 최종 연산을 호출하면 IllegalStateException이 발생한다.
stream.forEach(System.out::println); // IllegalStateException 발생
따라서 동일한 데이터 소스에 대해 여러 번 연산이 필요하다면, 매번 데이터 소스에서 새로운 Stream을 생성해야 한다.
long count = list.stream().count(); // 첫 번째 연산
list.stream().forEach(System.out::println); // 두 번째 연산 (새 스트림)
이런 특성 때문에 Stream은 흔히 “데이터 자체가 아니라, 데이터에 대한 일회용 뷰(view)” 로 보는 게 이해하기 편하다.
Stream API에 대한 설명을 마쳤으므로 이제 Java에서 어떻게 구현되어 있는지와 사용 예제에 대해 알아보자.
Stream API에는 기초가 되는 BaseStream<T, S> 인터페이스가 있으며, 이를 상속하는 Stream<T>와 일부 원시 자료형인 int, long, double에 대해 특화된 IntStream, LongStream, DoubleStream 이 있다.
BaseStream의 선언은 대략 아래와 같이 생겼다.
public interface BaseStream<T, S extends BaseStream<T, S>>
extends AutoCloseable {
S sequential();
S parallel();
S unordered();
S onClose(Runnable closeHandler);
void close();
}
이로 인해 BaseStream<T, S>를 상속하는 구체적인 스트림 인터페이스들은 항상 자기 자신의 스트림 타입을 반환할 수 있다는 것이다.
Stream<T>의 중간 연산들은 다시 Stream<T>를 반환한다.IntStream의 중간 연산들은 다시 IntStream를 반환한다.LongStream, DoubleStream도 동일하다.이어서 실제로 BaseStream<T, S>를 상속받는 인터페이스들을 알아보자.
Stream<T>Stream<T>는 객체 타입을 다루기 위한 스트림이다.
일반적으로는 다음과 같은 방식으로 얻을 수 있다.
Collection 구현체(List<>, Set<> 등)의 stream()/parallelStream()List<String> names = Arrays.asList("Kim", "Lee", "Jang");
Stream<String> stream = names.stream();Stream.of(...)를 이용해 직접 생성Stream<String> stream = Stream.of("Kim", "Lee", "Jang");Arrays.stream(...) 사용String[] names = {"Kim", "Lee", "Jang"};
Stream<String> stream = Arrays.stream(names);사용자 정의 타입도 마찬가지로, 예를 들어 List<User>에 담겨 있으면 Stream<User>로 다룰 수 있다.아래 예제들에서는 설명의 편의를 위해 다음 리스트를 기본으로 사용한다고 가정한다.
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);
Stream<T> 주요 연산 예제Intermediate Operations
filter()
조건에 맞는 요소만 걸러낸다.
List<Integer> evens = numbers.stream()
.filter(n -> n % 2 == 0)
.collect(Collectors.toList());
map()
요소를 다른 형태로 변환한다.
List<String> labels = numbers.stream()
.map(n -> "num=" + n)
.collect(Collectors.toList());
mapTo...()
객체 스트림을 기본형 스트림으로 변환한다.
int sum = numbers.stream()
.mapToInt(Integer::intValue)
.sum();
flatMap()
중첩된 스트림(또는 컬렉션)을 평탄화 한다.
List<String> sentences = Arrays.asList("a b", "c d");
List<String> words = sentences.stream()
.flatMap(s -> Arrays.stream(s.split(" ")))
.collect(Collectors.toList());
flatMapTo...()
평탄화하면서 동시에 기본형 스트림으로 변환한다.
Stream<String> strings = Stream.of("abc", "de");
IntStream codePoints = strings.flatMapToInt(String::chars);
distinct()
중복을 제거한다.
내부적으로
equals()를 기준으로 중복을 제거하므로 사용자 지정 클래스인 경우equals()와hashCode()를 모두 오버라이드해야 올바르게 작동한다.
List<Integer> distinct = numbers.stream()
.distinct()
.collect(Collectors.toList());
sorted()
정렬된 스트림을 만든다.
List<Integer> sortedDesc = numbers.stream()
.sorted(Comparator.reverseOrder())
.collect(Collectors.toList());
peek()
스트림을 변경하지 않고 값을 확인할 수 있다.
List<Integer> result = numbers.stream()
.peek(n -> System.out.println("before filter: " + n))
.filter(n -> n % 2 == 0)
.peek(n -> System.out.println("after filter: " + n))
.collect(Collectors.toList());
limit()
앞에서부터 일정 개수만 남긴다.
List<Integer> first3 = numbers.stream()
.limit(3)
.collect(Collectors.toList());
skip()
앞의 N개 요소를 건너뛰고, 이후 요소들만 남긴다.
List<Integer> skipped = numbers.stream()
.skip(2)
.collect(Collectors.toList());
Terminal Operations
forEach()
각 요소에 대해 동작을 수행한다.
순차 스트림에서는 원본 순서(Encounter-Order)가 유지되지만,
병렬 스트림이나 unordered() 이후에는 순서가 보장되지 않을 수 있다.
numbers.stream()
.forEach(System.out::println);
forEachOrdered()
순차 스트림 혹은 병렬 스트림 관계 없이 Encounter-Order를 지켜서 처리한다.
numbers.parallelStream()
.forEachOrdered(System.out::println);
toArray()
배열로 변환한다.
Integer[] arr = numbers.stream()
.toArray(Integer[]::new);
collect()
스트림을 다른 컬렉션이나 형식으로 모은다.
List<Integer> collected = numbers.stream()
.filter(n -> n > 2)
.collect(Collectors.toList());
reduce()
스트림의 모든 요소를 하나의 값으로 접어서(누적해서) 만든다.
흔히 말하는 “fold” 연산으로, identity(초기값)과 accumulator(누적 함수)를 사용한다.
// 합계 구하기 (1 + 2 + 3 + 4 + 5)
int sum = numbers.stream()
.reduce(0, (acc, n) -> acc + n);
// (((0 + 1) + 2) + 3) + 4 ...
// 곱 계산 (identity를 1로 두고 곱셈)
int product = numbers.stream()
.reduce(1, (acc, n) -> acc * n);
// identity 없이 사용하는 경우: Optional<T> 반환
int max = numbers.stream()
.reduce((a, b) -> a > b ? a : b)
.orElseThrow(NoSuchElementException::new);
min()/max()
최소값/최대값을 구한다.
이 때 반환 값은
Optional<T>로, 스트림 소스에 요소가 없을 경우 값이 존재하지 않을 수 있어 null-safe처리를 위함이다.
int min = numbers.stream()
.min(Integer::compareTo)
.orElseThrow(NoSuchElementException::new);
int max = numbers.stream()
.max(Integer::compareTo)
.orElseThrow(NoSuchElementException::new);
count()
요소 개수를 센다.
long evenCount = numbers.stream()
.filter(n -> n % 2 == 0)
.count();
anyMatch()/allMatch()/noneMatch()
boolean hasEven = numbers.stream().anyMatch(n -> n % 2 == 0);
boolean allPositive = numbers.stream().allMatch(n -> n > 0);
boolean noneNegative = numbers.stream().noneMatch(n -> n < 0);
findFirst()/findAny()
조건에 맞는 요소 하나를 반환한다.
이 때 반환 값은
Optional<T>로, 스트림 소스에 요소가 없을 경우 값이 존재하지 않을 수 있어 null-safe처리를 위함이다.
int firstEven = numbers.stream()
.filter(n -> n % 2 == 0)
.findFirst()
.orElseThrow(NoSuchElementException::new);
int anyEven = numbers.parallelStream()
.filter(n -> n % 2 == 0)
.findAny()
.orElseThrow(NoSuchElementException::new);
이 세 스트림은 int, long, double 타입 연산에 특화되어 있다.
Stream<T>를 쓰면 내부적으로 박싱/언박싱이 발생하는데, 기본형 스트림을 사용할 경우 이런 비용을 줄일 수 있다.
대표적으로 다음과 같은 방식으로 생성할 수 있다.
IntStream.of()를 이용해 직접 생성IntStream stream = IntStream.of(1, 2, 3);Arrays.stream(...) 사용int[] arr = {1, 2, 3, 4, 5};
IntStream stream = Arrays.stream(arr);IntStream range = IntStream.range(1, 5); // 1, 2, 3, 4
IntStream rangeClosed = IntStream.rangeClosed(1, 5); // 1, 2, 3, 4, 5사용되는 메서드는 Stream<T>와 상당히 유사하며, 여기에 합계, 평균, 통계를 위한 기본형 전용 메서드가 추가된 형태로 볼 수 있다.
이곳 예제에서도 설명의 편의를 위해 다음 배열을 기본으로 사용한다고 가정한다.
int[] numbers = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
Intermediate Operations
filter()/map()
기본형에 맞게 IntUnaryOperator, IntPredicate 등을 받는다는 점만 다르다.
IntStream evenStream = Arrays.stream(numbers)
.filter(n -> n % 2 == 0);
IntStream squared = Arrays.stream(numbers)
.map(n -> n * n); // 1, 4, 9, 16, 25
mapTo...()
기본형 스트림을 다른 타입의 스트림으로 바꿀 수 있다.
Stream<String> labels = Arrays.stream(numbers)
.mapToObj(n -> "num=" + n);
distinct(), sorted(), limit(), skip() 등
사용 방식은 Stream<T>와 동일하다.
int[] first3EvensAsc = Arrays.stream(numbers)
.filter(n -> n % 2 == 0)
.sorted()
.limit(3)
.toArray();
Terminal Operations
sum()
합계를 구할 때 쓰인다.
int sum = Arrays.stream(numbers)
.sum();
min()/max()/average()
최소값/최대값/평균값을 구한다.
int min = Arrays.stream(numbers)
.min()
.orElseThrow(NoSuchElementException::new);
long count = Arrays.stream(numbers)
.count();
double avg = Arrays.stream(numbers)
.average()
.orElse(0.0);
count()
요소 개수를 센다.
long evenCount = Arrays.stream(numbers)
.filter(n -> n % 2 == 0)
.count();
reduce()
요소들을 하나의 값으로 접는 데 사용한다.
// 1~10까지 제곱의 합: 1^2 + 2^2 + ... + 10^2
int sumOfSquares = Arrays.stream(numbers)
.reduce(0, (acc, n) -> acc + n * n);
// identity 없이 사용하면 OptionalInt 반환
int max = Arrays.stream(numbers)
.reduce((a, b) -> a > b ? a : b)
.orElseThrow(NoSuchElementException::new);
toArray()
기본형 배열로 변환한다.
int[] evens = Arrays.stream(numbers)
.filter(n -> n % 2 == 0)
.toArray();
summaryStatistics()
개수, 합, 최소값, 최대값, 평균을 한 번에 얻을 수 있다.
IntSummaryStatistics stats = Arrays.stream(numbers)
.summaryStatistics();
long count = stats.getCount();
int min = stats.getMin();
int max = stats.getMax();
long sum = stats.getSum();
double avg = stats.getAverage();
boxed() + collect(...)
기본형 스트림을 다시 Stream<Integer> 등으로 바꾼 뒤, Collectors와 함께 쓸 수도 있다.
List<Integer> list = Arrays.stream(numbers)
.boxed()
.collect(Collectors.toList());
parallelStream()은 멀티코어를 활용해 성능을 높일 수 있지만, 모든 경우에 적합한 것은 아니다.
forEachOrdered() 사용 필요// 나쁜 예: 상태 공유
List<Integer> result = new ArrayList<>();
numbers.parallelStream()
.forEach(n -> result.add(n)); // 스레드 안전하지 않음!
// 좋은 예: collect 사용
List<Integer> result = numbers.parallelStream()
.collect(Collectors.toList());
이로써 기본적인 Stream API와 그 사용방법에 대해서 다뤄 보았지만, 더 효율적으로 사용하기 위해서는 선언 순서나 단일/병렬 스트림 등 고려할 점이 많아 보인다.
단순 알고리즘 문제 해결을 위한 활용이 아닌 어플리케이션 코드로서의 Stream API 응용은 추후 기회가 된다면 작성해보고자 한다.