스프링 ,flask AWS OpenSearch 를 활용한 유사 이미지 검색

이진우·2024년 11월 27일
0

스프링 학습

목록 보기
46/46

도입 계기

졸업 프로젝트 cre8 에서
구인자가 원하는 분야의 구직자를 조금 더 편리하게 찾을 수 있도록

구인자가 원하는 이미지를 넣으면
그에 대한 응답값으로 유사도가 높은 이미지를 가진 포트폴리오를 추천하는 기능을 만드는 것이 회의 결과 도출 되었다.

도입 기술

CNN

CNN 이란 Convolutional Neural Network 의 약자로 딥러닝에서 주로 이미지나 영상 데이터를 처리할 때 쓰이며 Convolution 이라는 전처리 작업이 들어간다.

일반적으로 입력층과 은닉층, 출력층으로 구성된 DNN 은 2차원 형태의 이미지 역시 flatten 시켜 한 줄의 데이터로 만드는데 이 과정에서 지역적인 정보가 손상이 된다.

사람의 귀는 어디인지 , 이런 부분에 대한 정보가 없다는 것이다.

따라서 한 개의 벡터로 Neural Network 를 구성하기 보다
합성곱 연산을 사용한다.

이러한 지역적인 세부 특성을 잡아내기 위해서 합성곱 연산을 사용하는데

아래 그림과 같은 형식이다.

위 Filter 를 거치는 방식으로 이미지의 로컬리티를 살린 특성을 추출할 수 있다.

이러한 정보를 polling 을 거쳐 사이즈를 줄인다.

이런 식으로 학습하고 머신 러닝에서 모델의 예측값과 실제 레이블의 차이를 계산하는 손실함수를 줄이는 방향으로 filter 의 크기 등이 조정된다.

VGG 16

Layer 를 16개의 층을 가졌기 때문에 VGG16 이라고 명명받았다.
CNN을 기반으로 설계되었고
Fully Connected Layer (완전 연결층의 출력값) 을 벡터로 사용하여 이미지의 특징을 고차원 벡터화할 수 있다.

코사인 유사도

이렇게 이미지의 특징을 벡터 즉 숫자로 표현하였다.
그렇다면 유사한 이미지는 어떻게 생각할 수 있을까?

단순한 방법은 벡터간의 거리를 계산하는 것이다.

이를 유클리드 방법이라고 부른는데

이러한 방법은 전체적인 형태나 유사도 보다

부분적인 부분에 집중할 가능성이 커진다.

(얼굴에 눈코입이 비슷하게 달려있으면, 비슷하다고 생각해야 하는데 얼굴의 크기나 이런 부분이 영향이 가면 안되기 때문이다)

따라서 코사인 유사도라는 방법을 사용한다.

크기에 대한 부분은 상관이 없으므로

방향 벡터를 구하기 위해서 (내적 / 크기 곱) 을 활용한다.

1에 가까울 수록 유사도가 큰 것을 의미하고 -1 일 수록 반대 방향이다.

OpenSearch

이러한 방법을 적용하려면 각 포트폴리오의 이미지 마다
벡터를 저장하고 보관해야 하는데

mongoDB를 활용한다 가정하면
유사한 이미지를 찾을 때 결국 서버 메모리에 모든 포트폴리오 이미지의 벡터를 가져온 이후 사용자가 원하는 이미지와 각각 코사인 유사도를 계산해야 하므로 비효율적인 부분이 많다고 생각했다.

따라서 이러한 부분을 제공해주는 ElasticSearch 를 활용하려 했으나 ,

EC2 프리티어의 메모리 부족으로 OpenSearch 프리티어를 사용하여 문제를 해결한다.

결론

CNN 을 기반으로 한 잘 훈련된 모델인 VGG16 을 사용하여 이미지의 특징 벡터를 추출하고, 이미지 유사도를 측정하기 위해 코사인 유사도를 활용하고 이러한 벡터는 OpenSearch 에 저장한다.

OpenSearch 에는

PUT /portfolio
{
  "settings": {
    "index.knn": true,
    "index.knn.space_type": "cosinesimil"
  },
  "mappings": {
    "properties": {
      "vector": {
        "type": "knn_vector",
        "dimension": 512
      },
      "portfolioId": {
        "type": "keyword"
      },
      "accessUrl": {
        "type": "text"
      },
      "portfolioImageId": {
        "type": "keyword"
      }
    }
  }
}

위와 같은 타입을 지정해주고 cosinesimil 을 통해서 특정 점에서 가장 가까운 벡터를 찾을 때 코사인 유사도를 기준으로 가지고 오라고 설정하여 준다 .

Flask & Python 코드

@app.route('/find_similar_image', methods=['POST'])
def find_similar_image():
    if 'query_image_file' in request.files:
        query_image = request.files['query_image_file']  # multipart file
    elif 'query_image_url' in request.form:
        query_image = request.form['query_image_url']  # URL (텍스트 데이터)

    similar_images = VectorElasticSearch.find_similar_image(VGGVector.extract_features(query_image))

    response_list = []

    for most_similar_mongo_id, similarity_score in similar_images:
        query_result = VectorElasticSearch.find_by_portfolio_image_id(most_similar_mongo_id)

        response_list.append({
            "most_similar_portfolio_id": query_result['portfolioId'],
            "most_similar_portfolio_image_id": query_result['portfolioImageId'],
            'most_similar_access_url': query_result['accessUrl'],
            "similarity_score": json.dumps(str(round(similarity_score, 4)))
        })

    # 최종 결과 반환
    return jsonify(response_list)
def find_similar_image(query_vector):
    cosine_query = {
        "size": 5,
        "query": {
            "knn": {
                "vector": {
                    "vector": query_vector,
                    "k": 5
                }
            }
        }
    }

    response = es.search(index=index_name, body=cosine_query)

    similar_images = []

    # 결과 처리
    for hit in response['hits']['hits']:
        most_similar_mongo_id = hit['_source']['portfolioImageId']
        similarity_score = hit['_score']

        similar_images.append((most_similar_mongo_id, similarity_score))

    return similar_images

위와 같은 코드로 상위 5개의 유사도를 가진 portfolioImageId 를 반환한다.

스프링

데이터 전달.

@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class PortfolioRecommendService {

    private final WebClient webClient;

    private static final String QUERY_IMAGE_URL = "query_image_url";
    private static final String QUERY_IMAGE_FILE = "query_image_file";

    private static final String ML_RECOMMEND_API="/find_similar_image";

    private static final String SAVE_VECTOR = "/portfolio/vector";

    public List<PortfolioRecommendResponseDto> showRecommendPortfolio(final PortfolioRecommendRequestDto portfolioRecommendRequestDto){


        List<PortfolioAIRequestDto> portfolioAIRequestDtoList = webClient.post()
                .uri(ML_RECOMMEND_API)
                .body(getBody(portfolioRecommendRequestDto))
                .retrieve()
                .bodyToMono(new ParameterizedTypeReference<List<PortfolioAIRequestDto>>() {})
                .block();

        return portfolioAIRequestDtoList.stream()
                .map(portfolioAIRequestDto -> PortfolioRecommendResponseDto.builder()
                        .id(portfolioAIRequestDto.getMost_similar_portfolio_id())
                        .accessUrl(portfolioAIRequestDto.getMost_similar_access_url())
                        .similarity(portfolioAIRequestDto.getSimilarity_score())
                        .build())
                .toList();

    }

    public void savePortfolioWithVector(final Long portfolioImageId){

        webClient.post()
                .uri(uriBuilder -> uriBuilder
                        .path(SAVE_VECTOR)
                        .queryParam("portfolioImageId", portfolioImageId)
                        .build())
                .retrieve()
                .bodyToMono(Void.class)
                .block();
    }

    private MultipartInserter getBody(final PortfolioRecommendRequestDto portfolioRecommendRequestDto) {

        if (!multiPartFileBlank(portfolioRecommendRequestDto.getImageFile())) {

            return BodyInserters.fromMultipartData(QUERY_IMAGE_FILE, portfolioRecommendRequestDto.getImageFile().getResource());
        }

        if(imageUrlBlank(portfolioRecommendRequestDto.getImageUrl())){
            throw new BadRequestException(ErrorCode.CANT_ALL_BLANK_FILE_URL);
        }

        return BodyInserters.fromMultipartData(QUERY_IMAGE_URL, portfolioRecommendRequestDto.getImageUrl());
    }

    private boolean multiPartFileBlank(final MultipartFile multipartFile){
        return multipartFile==null || multipartFile.isEmpty();
    }

    private boolean imageUrlBlank(final String imageUrl){
        return imageUrl==null || imageUrl.isBlank();
    }

}

python 서버와 통신하기 위해서 WebClient 를 사용하고 그대로 사용자에게 반환하여 주는 간단한 코드이다 .

벡터 저장

포트폴리오 이미지를 생성하려면 포트폴리오를 update 해야 한다.

Update 코드

이미지를 저장하고 삭제할 때 현재 S3 를 사용하고 있다.
다만 우려 되는 부분은
포트폴리오를 수정할 떄 이미지 뿐만 아니라 다른 것들도 함께 수정이 가능하기 때문에
@Transactional 이 끝날 때 변경 감지가 동작하며 이때 sql 의 제약조건과 위배되어 롤백된다면 이미지가 실제로는 DB에 없는데 s3 나 opensearch 에 저장이 되어있는 상태가 되버린다.

따라서 @Transactional 에서 sql 제약조건의 위배 없이 성공적으로 동작할 떄

즉 commit 시 event 를 발행하여 후처리를 해주도록 아래와 같이 코드를 작성했다.

private void updatePortfolioImage(final List<MultipartFile> multipartFileList,final Portfolio portfolio,final List<Long> deletePortfolioImageId){


        List<ImageDeleteEventDto> imageDeleteEventDtos = new ArrayList<>();
        List<String> newAccessUrlList = new ArrayList<>();
        List<ImageSaveEventDto> imageSaveEventDtos = new ArrayList<>();


        if(deletePortfolioImageId!=null){

            deletePortfolioImageId.forEach(portfolioImageId->{
                imageDeleteEventDtos.add(
                        new ImageDeleteEventDto(portfolioImageId, portfolioImageRepository.findById(portfolioImageId).orElseThrow(()->new NotFoundException(ErrorCode.CANT_FIND_PORTFOLIO_IMAGE_ID)).getAccessUrl()));
                portfolioImageRepository.deleteById(portfolioImageId);
            });

        }


        if(multipartFileList!=null){

            multipartFileList.stream().forEach(multipartFile -> {

                String accessUrl = s3ImageService.saveImage(multipartFile,portFolioImage,multipartFile.getOriginalFilename());
                newAccessUrlList.add(accessUrl);

                PortfolioImage portfolioImage = PortfolioImage.builder()
                        .originalName(multipartFile.getOriginalFilename())
                        .portfolio(portfolio)
                        .accessUrl(accessUrl)
                        .build();

                portfolioImageRepository.save(portfolioImage);
                imageSaveEventDtos.add(new ImageSaveEventDto(portfolioImage.getId()));

            });

        }


        eventPublisher.publishEvent(S3UploadImageListRollbackEvent.builder().newAccessImageUrlList(newAccessUrlList).build());
        eventPublisher.publishEvent(UploadImageListCommitDeleteEvent.builder().imageDeleteEventDtos(imageDeleteEventDtos).build());
        eventPublisher.publishEvent(UploadImageListCommitSaveEvent.builder().imageSaveEventDtos(imageSaveEventDtos).build());

    }

Event 생성 시 처리 코드

@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
    @Async("threadPoolTaskExecutor")
    public void transactionalDeleteEventListenerAfterCommit(final UploadImageListCommitDeleteEvent uploadImageListCommitDeleteEvent) {

        uploadImageListCommitDeleteEvent.getImageDeleteEventDtos().forEach(imageDeleteEventDto->{
            s3ImageService.deleteImage(imageDeleteEventDto.deleteAccessUrl());
            portfolioImageDocumentService.delete(imageDeleteEventDto.deletePortfolioImageId());
        });
    }

    @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
    @Async("threadPoolTaskExecutor")
    public void transactionalSaveEventListenerAfterCommit(final UploadImageListCommitSaveEvent uploadImageListCommitSaveEvent){

        uploadImageListCommitSaveEvent.getImageSaveEventDtos().forEach(imageSaveEventDto -> {
            portfolioRecommendService.savePortfolioWithVector(imageSaveEventDto.savePortfolioImageId());
        });
    }

벡터를 사용하여 계산한다거나 , S3에 이미지를 저장하거나 삭제하는 것은 상대적으로 시간이 오래 걸리는 작업이기 때문에 비동기 처리를 하여 별도의 쓰레드에서 로직을 처리할 수 있게 하였다 .

Python 서버에 벡터 저장 요청

public void savePortfolioWithVector(final Long portfolioImageId){

        webClient.post()
                .uri(uriBuilder -> uriBuilder
                        .path(SAVE_VECTOR)
                        .queryParam("portfolioImageId", portfolioImageId)
                        .build())
                .retrieve()
                .bodyToMono(Void.class)
                .block();
    }
@app.route('/portfolio/vector', methods=['POST'])
def save_vector():

    portfolio_image_id = request.args.get('portfolioImageId')

    if portfolio_image_id:

        portfolio_image = PortFolioImageRepository.get_portfolio_image_one(portfolio_image_id)
        try:

            vector = VGGVector.extract_features(portfolio_image[1])
            document = {
                "portfolio_image_id": portfolio_image[0],
                "access_url": portfolio_image[1],
                "portfolio_id": portfolio_image[2],
                "vector": vector.tolist()
            }

            VectorElasticSearch.insert_portfolio_image_with_vector_elastic(document)
        except Exception as e:
            print("저장시 오류 발생 URL {access_url}: {e}")

        return jsonify({"message": "Received", "portfolioImageId": portfolio_image_id}), 200
    else:
        # 파라미터가 없을 경우 에러 응답
        return jsonify({"error": "portfolioImageId is required"}), 400

이후 파이썬 서버에서 이 응답을 받아 요청을 처리한다.

시연

profile
기록을 통해 실력을 쌓아가자

0개의 댓글