Java

[Java] Stream API

이덩우 2023. 12. 21. 19:06

Stream API의 등장

Java8 이전에 자바는 객체지향 언어이기 때문에 함수형 프로그래밍을 적용할 수 없었다.

하지만 Java8부터 함수형 인터페이스, 람다 표현식, Stream API가 등장하면서 함수형 프로그래밍을 할 수 있는 다양한 API를 제공하기 시작했다.

그 중 Stream API는 배열, Collection 등의 *데이터 자체를 추상화*하고 How보다는 What에 초점을 두어 데이터를 처리하는 다양한 함수를 정의해두었다.

 


동작 흐름 및 주요 특징

Stream API는 데이터를 처리하는 다양한 연산을 지원한다. 크게 아래와 같이 세 단계로 나눌 수 있다.

  1. 스트림 생성
  2. 중간 연산
  3. 최종 연산

동작 흐름, 출처 : tcpschool

주요 특징을 알아보자.

  • 스트림은 원본 데이터를 변경하지 않고, 최초 원본 데이터를 복사한 스트림을 생성한다.
  • 스트림은 재사용이 불가능하다.
  • 스트림은 *지연 연산(Lazy)*을 통해 성능을 최적화한다.
    • 스트림은 중간 연산인 map, filter 등을 순차적으로 처리할 것 같지만, 실제로는 수직적으로 실행한다.
    • 이 과정에서 앞 단의 조건에 부합하지 않는다면 뒷 단의 연산 자체를 수행하지 않아 성능을 최적화 할 수 있다.
  • 중간 연산은 Stream 형태로 반환하기 때문에 연속적으로 연결하여 사용할 수 있다.
  • parallelStream() 메소드를 통해 병렬 처리가 가능한 스트림을 생성할 수 있다.

 


Stream 생성

Stream은 컬렉션이나 배열을 보다 보다 직관적으로 다루기위해 등장했다.

생성하는 법은 아래와 같다.

Integer[] intArr = new Integer[]{1, 2, 3, 4, 5, 10};
List<Integer> intList = Arrays.asList(intArr);

Stream<Integer> arrStream = Arrays.stream(intArr);
Stream<Integer> listStream = intList.stream();
  • 배열 : Arrays.stream()
  • 컬렉션 : 객체명.stream()

위 예제에서는 일반적인 Stream을 생성했다.

배열을 다룰 때 primitive 자료형에 대해서는 Stream<T>이 아닌 IntStream, DoubleStream 등 숫자 연산에 최적화된 스트림이 자동으로 생성된다.

일반적인 Stream<T>에서 평균값, 최댓값, 최솟값 등 연산을 위해서는 매핑을 통해 IntStream, DoubleStream 등으로 변경해야하는데, primitive 자료형 배열로 스트림을 생성할 수 있다면 매핑 단계를 건너뛸 수 있다.

아래는 IntStream, DoubleStream을 생성하는 예시이다.

int[] intArr = new int[]{1, 2, 3, 4, 10};
double[] doubleArr = new double[]{1.0, 2.0, 3.0};

IntStream intStream = Arrays.stream(intArr);
DoubleStream doubleStream = Arrays.stream(doubleArr);

 

 


중간 연산 (가공)

스트림은 조건에 따른 필터링, 매핑, 요소 건너뛰기, 정렬 등 다양한 중간 연산을 위한 메소드를 제공한다.

자주 다루는 중간 연산자에 대해 살펴보자

  • filter(Predicate<? super T> predicate) : 스트림에서 주어진 조건에 맞는 요소만으로 구성된 새로운 스트림을 반환한다.
  • map(Function<? super T, ? extends R> mapper) : 스트림의 요소들을 함수의 인자로 전달해, 함수의 반환값으로 이루어진 새로운 스트림을 반환한다.
    • mapToInt(), mapToLong(), mapToDouble()IntStream, LongStream, DoubleStream을 반환한다.
    • flatMap() : 2차원 요소를 1차원 요소로 평탄화할 때 사용한다.
  • distinct() : 중복 요소를 제거한 뒤 새로운 스트림을 반환한다.
  • skip(long n) : 스트림의 첫 번째 요소부터 인자로 전달된 n개를 제외한 나머지 요소만으로 이루어진 새로운 스트림을 반환한다.
  • limit(long n) : skip이 n개를 건너뛰는 것이라면, limit은 첫 번째 요소부터 n개로 이루어진 새로운 스트림을 반환한다.
  • sorted(Comparator<? super T> comparator) : 스트림 내 요소를 Comparator를 이용해 정렬한다. 인자를 전달하지 않으면 natural order로 정렬한다.

 


최종 연산 (결과 만들기)

중간 연산을 통해 변환된 Stream은 최종 연산을 통해 각 요소를 소모하여 결과를 표시한다.

즉, 지연되었던 모든 연산이 최종 연산 시에 모두 수행된다.

최종 연산을 통해 요소를 모두 소모하면 해당 스트림은 더 이상 사용할 수 없다.

자주 사용하는 최종 연산자에 대해 알아보자.

  • void forEach(Consumer<? super T> action) : 스트림의 각 요소를 소모해 특정 동작을 수행한다.
  • Optional<T> reduce(BinaryOperator<T> accumulator) : 처음 두 요소를 가지고 특정 연산을 수행한 뒤, 그 결과와 다음 요소를 가지고 연산을 반복한다. 만약 인자로 초깃값을 첫 번째 자리에 넣어준다면 Optional이 빠진 타입을 반환한다.
  • collect(Collector<? super T, A, R> collector) : 인자로 전달되는 Collectors 객체에 구현된 방법대로 스트림의 요소를 수집한다.
    Collectors에는 List, Set, joining(String) 등 다양한 방식이 정의되어있다.
  • 요소 검색 관련
    • Optional<T> findFirst() : 해당 스트림에서 첫 번째 요소를 참조하는 Optional 객체를 반환한다.
    • Optional<T> findAny() : findFirst()와 동일, 병렬 스트림일때 주로 사용한다.
  • 요소 검사 관련
    • boolean anyMatch(Predicate<? super T> predicate) : 스트림 요소 중 하나라도 조건을 만족하면 true를 반환한다.
    • boolean allMatch(Predicate<? super T> predicate) : 스트림 요소 전부 조건을 만족하면 true를 반환한다.
    • boolean noneMatch(Predicate<? super T> predicate) : 스트림 요소 전부 조건을 만족하지 않으면 true를 반환한다.
  • 통계 관련
    • long count() : 스트림 요소의 개수를 반환한다.
    • Optional<T> max(Comparator<? super T> comparator) : 가장 큰 값을 가지는 요소를 참조하는 Optional 객체를 반환한다.
    • Optional<T> min(Comparator<? super T> comparator) : 가장 작은 값을 가지는 요소를 참조하는 Optional 객체를 반환한다.
  • 연산 관련 (IntStream, DoubleStream 등으로 변환 후 사용 가능)
    • sum() : 스트림 내 모든 요소의 합을 구하여 반환한다.
    • average() : 스트림 내 모든 요소의 평균값을 구하여 반환한다.

 


다양한 예제

public class Practice1 {
    public static void main(String[] args) {
        // 주어진 문자열 리스트에서 소문자로 이뤄진 단어들을 대문자로 변환한 뒤, 중복을 제거하고 정렬된 리스트를 반환해야 합니다.
        List<String> words = Arrays.asList("apple", "Banana", "ORANGE", "apple", "grapes", "banana");
        List<String> collect = words.stream()
                .map(s -> s.toUpperCase())
                .distinct()
                .sorted()
                .collect(Collectors.toList());
        System.out.println(collect);
    }
}

public class Practice2 {
    public static void main(String[] args) {
        // 주어진 숫자 리스트에서 짝수를 선택하고 그 제곱을 오름차순으로 정렬한 리스트 만들기
        List<Integer> numbers = Arrays.asList(5, 3, 8, 2, 9, 4, 7, 6);
        List<Integer> collect = numbers.stream()
                .filter(n -> n % 2 == 0)
                .map(n -> n * n)
                .sorted()
                .collect(Collectors.toList());

        System.out.println(collect);
    }
}

public class Practice3 {
    public static void main(String[] args) {
        // 주어진 숫자들의 평균값 구하기
        List<Integer> numbers = Arrays.asList(3, 1, 7, 5, 9, 2, 8);
        double v = numbers.stream()
                .mapToInt(num -> num)
                .average()
                .orElse(-1);
        System.out.println(v);
    }
}


public class Practice4 {
    public static void main(String[] args) {
        // 주어진 나라를 글자 수 별로 그룹화하기
        List<String> countries = Arrays.asList("Japan", "Korea", "USA", "China", "Canada");

        Map<Integer, List<String>> collect = countries.stream()
                .collect(Collectors.groupingBy(c -> c.length()));
        collect.forEach((k, v) -> System.out.println(k + ": " + v.toString()));
    }
}

public class PracticeFinal {
    public static void main(String[] args) {
        // 주어진 사람들의 리스트에서 나이가 30세 이상인 사람들을 성별로 그룹화하고,
        // 각 그룹에서는 그들의 이름으로 이루어진 리스트를 만들어야 합니다. 이를 나이 내림차순으로 정렬한 후 출력해야 합니다.
        List<Person> people = Arrays.asList(
                new Person("Alice", 28, "Female"),
                new Person("Bob", 35, "Male"),
                new Person("Charlie", 30, "Male"),
                new Person("Diana", 25, "Female"),
                new Person("Eve", 32, "Female")
        );
        Map<String, List<String>> collect = people.stream()
                .filter(p -> p.getAge() >= 30)
                .collect(Collectors.groupingBy(p -> p.getGender(),
                        Collectors.mapping(p -> p.getName(), Collectors.toList())));
        collect.forEach((k, v) -> v.sort(Comparator.reverseOrder()));
        System.out.println(collect);
    }
}