인공지능 오픈소스

윤수환·2024년 6월 3일

인공지능

목록 보기
4/10
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
//Pytorch의 torchvision 패키지에서 몇 가지 사전 훈련된 컨볼루션 신경망 아키텍처를 가져옴
//사용되는 아키텍처 Resnet, Vgg, Inception / torch, numpy모듈
from torchvision.models import resnet18, resnet50, resnet101, resnet152, vgg16, vgg19, inception_v3
import torch
import torch.nn as nn
import random
import numpy as np

//이미지 특성을 추출하기 위한 CNN인코더 모델 정의(사전학습된 RetNet아키텍처 기반)
//nn.Module을 상속 - Pytorch모든 사용자 정의 모델이 따라야하는 규칙
class EncoderCNN(nn.Module):
	//인코더 모델 초기화
    //embed_size - 출력 임베딩의 크기 / dropout - 드롭아웃 확률 지정 / img_model - 사용할 이미지 모델 지정 /pretrained - 사전 훈련된 모델을 사용할지 여부 결정
    def __init__(self, embed_size, dropout=0.5, image_model='resnet101', pretrained=True):
        """Load the pretrained ResNet-152 and replace top fc layer."""
        super(EncoderCNN, self).__init__()
        //globals()[image_model]을 사용하여 지정된 이름의 모델을 동적으로 가져옵니다.
        resnet = globals()[image_model](pretrained=pretrained)
        modules = list(resnet.children())[:-2]  # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)

        self.linear = nn.Sequential(nn.Conv2d(resnet.fc.in_features, embed_size, kernel_size=1, padding=0),
                                    nn.Dropout2d(dropout))

    def forward(self, images, keep_cnn_gradients=False):
        """Extract feature vectors from input images."""

        if keep_cnn_gradients:
            raw_conv_feats = self.resnet(images)
        else:
            with torch.no_grad():
                raw_conv_feats = self.resnet(images)
        features = self.linear(raw_conv_feats)
        features = features.view(features.size(0), features.size(1), -1)

        return features


class EncoderLabels(nn.Module):
    def __init__(self, embed_size, num_classes, dropout=0.5, embed_weights=None, scale_grad=False):

        super(EncoderLabels, self).__init__()
        embeddinglayer = nn.Embedding(num_classes, embed_size, padding_idx=num_classes-1, scale_grad_by_freq=scale_grad)
        if embed_weights is not None:
            embeddinglayer.weight.data.copy_(embed_weights)
        self.pad_value = num_classes - 1
        self.linear = embeddinglayer
        self.dropout = dropout
        self.embed_size = embed_size

    def forward(self, x, onehot_flag=False):

        if onehot_flag:
            embeddings = torch.matmul(x, self.linear.weight)
        else:
            embeddings = self.linear(x)

        embeddings = nn.functional.dropout(embeddings, p=self.dropout, training=self.training)
        embeddings = embeddings.permute(0, 2, 1).contiguous()

        return embeddings

0개의 댓글