음악 분류 딥러닝을 만들자(21) - finetuning

응큼한포도·2024년 8월 16일
1

finetuner 구현

import torch
import torch.nn as nn
import torch.optim as optim

class Finetuner:
    """
    A class to finetune a neural network model for a specified number of steps.

    One step is one optimizer update on one mini-batch. The data loader is cycled
    (restarted from the beginning) as often as needed to reach the requested number
    of steps.

    Attributes
    ----------
    model : torch.nn.Module
        The neural network model to be finetuned.
    data_loader : torch.utils.data.DataLoader
        The DataLoader (or any iterable of ``(inputs, targets)`` pairs) used for finetuning.
    device : str
        The device ('cuda' or 'cpu') to run the finetuning process on.
    lr : float
        Learning rate used whenever the optimizer is (re)built.
    weight_decay : float
        Weight decay used whenever the optimizer is (re)built.
    optimizer : torch.optim.Optimizer
        The optimizer used for finetuning. It is rebuilt as AdamW when the model's
        parameters no longer match the ones it tracks (e.g. after a layer was swapped).
    criterion : torch.nn.Module
        The loss function used for finetuning.

    Methods
    -------
    finetune(steps):
        Finetunes the model for the specified number of steps.
    """

    def __init__(self, model, data_loader, device='cuda', lr=0.001, weight_decay=0.0005):
        """
        Initializes the Finetuner with a given model, data loader, and other parameters.

        Parameters
        ----------
        model : torch.nn.Module
            The neural network model to be finetuned.
        data_loader : torch.utils.data.DataLoader
            The DataLoader for the dataset used for finetuning.
        device : str, optional
            The device to run the finetuning process on (default is 'cuda').
        lr : float, optional
            The learning rate for the optimizer (default is 0.001).
        weight_decay : float, optional
            The weight decay (regularization) for the optimizer (default is 0.0005).
        """
        self.model = model
        self.data_loader = data_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.criterion = nn.CrossEntropyLoss()  # Change according to your problem

    def _sync_optimizer(self):
        """Rebuild the optimizer if the model's parameters differ from the tracked ones."""
        current = list(self.model.parameters())
        tracked = [p for group in self.optimizer.param_groups for p in group['params']]
        # Replacing a submodule (as pruning does) creates new Parameter objects; an
        # optimizer built earlier would silently never update them.
        if {id(p) for p in current} != {id(p) for p in tracked}:
            self.optimizer = optim.AdamW(current, lr=self.lr, weight_decay=self.weight_decay)

    def finetune(self, steps):
        """
        Finetunes the model for a given number of steps.

        Parameters
        ----------
        steps : int
            The number of optimizer updates (mini-batches) to perform. Values <= 0
            perform no updates.

        Returns
        -------
        None
        """
        self.model.to(self.device)
        self._sync_optimizer()

        # Remember the caller's mode so evaluation after finetuning is not accidentally
        # run with dropout / batch-norm statistics updates still enabled.
        was_training = self.model.training
        self.model.train()
        try:
            done = 0
            while done < steps:
                saw_batch = False
                for inputs, targets in self.data_loader:
                    saw_batch = True
                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    self.optimizer.zero_grad()
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, targets)
                    loss.backward()
                    self.optimizer.step()
                    done += 1
                    if done >= steps:
                        break
                if not saw_batch:
                    # An empty loader can never make progress; avoid looping forever.
                    break
        finally:
            self.model.train(was_training)

파인튜닝은 특별할 것 없이, 모두가 아는 표준 학습 루프(순전파 → 손실 계산 → 역전파 → 파라미터 갱신)로 구현했다. 옵티마이저는 GPT가 추천한 AdamW를 사용했다. AdamW에 대해서는 나중에 따로 공부할 예정이다.

pruning 수정

import sys, os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch.nn as nn
import torch
from resource_measurement import ResourceMeasurement
from evaluations.metrics import Evaluation
from fine_tuning import Finetuner  # Import the Finetuner class


class FilterPruner:
    """
    A class to prune filters from convolutional layers in a neural network model.

    Pruning is layer-wise. For every ``nn.Conv2d`` layer, each candidate number of
    filters to remove (the lowest-L2-norm filters go first) is applied to the model.
    The candidate is short-term finetuned, its resource metric and accuracy are measured,
    and the model is rolled back so every proposal starts from the same weights. Among the
    proposals that reduce the metric by more than the target, the one with the smallest
    relative accuracy drop per unit of relative metric reduction is then applied in place.

    Only plain feed-forward chains are handled. The output of a pruned convolution is
    assumed to feed the next ``nn.Conv2d`` (or an ``nn.Linear`` whose ``in_features``
    equals the channel count) in module registration order. It may pass through
    ``nn.BatchNorm2d`` layers, which are resized too, and parameter-free layers on the way.
    Layers with no such consumer are skipped. Residual or branching connections are NOT
    detected: registration order is assumed to match forward order (verify for your model).

    Attributes
    ----------
    model : torch.nn.Module
        The neural network model containing the layers to be pruned (modified in place).
    device : str
        The device ('cuda' or 'cpu') to run the pruning process on.
    resource_measurer : ResourceMeasurement
        A utility to measure the resource usage (e.g., latency) of the model.
    evaluator : Evaluation
        An instance of the Evaluation class to evaluate model accuracy.
    finetuner : Finetuner
        An instance of the Finetuner class for finetuning the model.
    finetune_lr : float
        Learning rate used when rebuilding the finetuning optimizer for a proposal.
    weight_decay : float
        Weight decay used when rebuilding the finetuning optimizer for a proposal.
    baseline_accuracy : float
        The accuracy of the current (possibly already pruned) model.
    finetune_steps : int
        The number of finetuning steps to perform on each pruned model.
    """

    def __init__(self, model, validation_loader, device='cuda', finetune_steps=10, weight_decay=0.0005,
                 finetune_loader=None):
        """
        Initializes the FilterPruner with a given model, validation data loader, and finetuning parameters.

        Parameters
        ----------
        model : torch.nn.Module
            The neural network model containing the layers to be pruned.
        validation_loader : torch.utils.data.DataLoader
            The DataLoader for the validation dataset used to evaluate model accuracy.
        device : str, optional
            The device to run the pruning process on (default is 'cuda').
        finetune_steps : int, optional
            The number of finetuning steps to perform on each pruned model (default is 10).
        weight_decay : float, optional
            The weight decay (regularization) for the optimizer (default is 0.0005).
        finetune_loader : torch.utils.data.DataLoader, optional
            Data used for finetuning. Defaults to ``validation_loader`` for backward
            compatibility, but finetuning on the evaluation data inflates the measured
            accuracy, so pass a separate (training) loader when available.
        """
        self.model = model
        self.device = device
        self.resource_measurer = ResourceMeasurement(metric='latency')
        self.evaluator = Evaluation(validation_loader, device)
        self.finetune_lr = 0.001
        self.weight_decay = weight_decay
        loader = validation_loader if finetune_loader is None else finetune_loader
        self.finetuner = Finetuner(model, loader, device, lr=self.finetune_lr, weight_decay=weight_decay)
        self.baseline_accuracy = None
        self.finetune_steps = finetune_steps

    def prune_layerwise(self,
                        input_tensor,
                        target_metric_reduction=0.05,
                        w=-0.15,
                        target_latency=None,
                        reward_threshold=1.0):
        """
        Prunes filters from the model's convolutional layers one layer at a time, evaluating performance after each step.

        Parameters
        ----------
        input_tensor : torch.Tensor
            The input tensor used to measure the model's performance (e.g., latency).
        target_metric_reduction : float, optional
            Minimum relative reduction of the measured metric a proposal must achieve to
            be accepted (default is 0.05 for 5%).
        w : float, optional
            Reserved reward weight; accepted for interface compatibility, currently unused.
        target_latency : float, optional
            Reserved latency target; accepted for interface compatibility, currently unused.
        reward_threshold : float, optional
            Reserved stopping threshold; accepted for interface compatibility, currently unused.

        Returns
        -------
        None
        """
        current_metric = self.resource_measurer.measure(self.model, input_tensor)
        self.baseline_accuracy = self.evaluator.evaluation_accuracy(self.model)

        # Snapshot the names up front: pruning swaps modules in and out, which must not
        # happen while a named_modules() generator is being consumed.
        conv_names = [name for name, module in self.model.named_modules()
                      if isinstance(module, nn.Conv2d)]

        for name in conv_names:
            print(f"Evaluating layer {name}")
            layer = self._get_module(self.model, name)
            if layer.groups != 1 or self._find_dependents(name) is None:
                print(f"Skipping layer {name}: no adjustable layer consumes its output")
                continue

            best_proposal = None
            best_score = float('inf')  # relative accuracy drop per relative metric reduction; lower is better

            for filters_to_prune in range(1, layer.out_channels):
                saved_state = self._clone_state()
                replaced = self._prune(name, filters_to_prune)

                # Finetune the model with the pruned layer
                self._finetune_current_model()

                new_metric = self.resource_measurer.measure(self.model, input_tensor)
                new_accuracy = self.evaluator.evaluation_accuracy(self.model)

                latency_change = self._relative_drop(current_metric, new_metric)
                accuracy_change = self._relative_drop(self.baseline_accuracy, new_accuracy)

                if latency_change > 0 and latency_change > target_metric_reduction:
                    score = accuracy_change / latency_change
                    if score < best_score:
                        best_score = score
                        best_proposal = (filters_to_prune, self._clone_state(), new_metric, new_accuracy)

                # Roll back so the next proposal starts from the same, unfinetuned model.
                self._restore(replaced, saved_state)

            if best_proposal is not None:
                filters_to_prune, state, new_metric, new_accuracy = best_proposal
                self._prune(name, filters_to_prune)
                # Load the exact finetuned weights/buffers this proposal was scored with.
                self.model.load_state_dict(state)
                current_metric = new_metric
                self.baseline_accuracy = new_accuracy

    def prune_filter(self, layer_name, num_filters_to_prune):
        """
        Prunes filters from the specified convolutional layer and replaces it with a new layer.

        The filters with the smallest L2 norm are removed. Following ``nn.BatchNorm2d``
        layers and the consuming ``nn.Conv2d`` / ``nn.Linear`` are resized to match.

        Parameters
        ----------
        layer_name : str
            The (possibly dotted, e.g. ``'features.0'``) name of the convolutional layer to prune.
        num_filters_to_prune : int
            The number of filters to remove; must leave at least one filter.

        Returns
        -------
        torch.nn.Conv2d
            The new convolutional layer with the remaining filters (already in the model).

        Raises
        ------
        ValueError
            If the layer is not an ungrouped ``nn.Conv2d``, the count is out of range,
            or no adjustable consumer of its output can be found.
        """
        self._prune(layer_name, num_filters_to_prune)
        return self._get_module(self.model, layer_name)

    def _prune(self, layer_name, num_filters_to_prune):
        """Prune ``layer_name`` in place; return {name: original module} for every swapped module."""
        layer = self._get_module(self.model, layer_name)
        if not isinstance(layer, nn.Conv2d) or layer.groups != 1:
            raise ValueError(f"{layer_name!r} is not an ungrouped nn.Conv2d layer")
        if not 0 < num_filters_to_prune < layer.out_channels:
            raise ValueError(
                f"num_filters_to_prune must be in [1, {layer.out_channels - 1}], got {num_filters_to_prune}")
        dependents = self._find_dependents(layer_name)
        if dependents is None:
            raise ValueError(f"cannot find an adjustable layer consuming the output of {layer_name!r}")
        norm_names, consumer_name = dependents

        # Drop the weakest filters, then re-sort the survivors so they keep their original
        # channel order; downstream layers are sliced with the same indices.
        filter_norms = torch.norm(layer.weight.detach(), p=2, dim=[1, 2, 3])
        kept = torch.sort(torch.argsort(filter_norms)[num_filters_to_prune:]).values
        keep_filters = kept.numel()

        replaced = {layer_name: layer}
        self.replace_layer(self.model, layer_name, self._build_pruned_conv(layer, kept))
        for norm_name in norm_names:
            norm = self._get_module(self.model, norm_name)
            replaced[norm_name] = norm
            self.replace_layer(self.model, norm_name, self._build_pruned_batchnorm(norm, kept))
        consumer = self._get_module(self.model, consumer_name)
        replaced[consumer_name] = consumer
        self.replace_layer(self.model, consumer_name, self._update_next_layer(consumer, keep_filters, kept))
        return replaced

    def _find_dependents(self, layer_name):
        """
        Locate the modules that must be resized when ``layer_name`` loses output channels.

        Walks the leaf modules registered after ``layer_name``. Returns
        ``(batchnorm_names, consumer_name)``, or ``None`` if no safely adjustable
        consumer is found.
        """
        channels = self._get_module(self.model, layer_name).out_channels
        norm_names = []
        seen = False
        for name, module in self.model.named_modules():
            if not seen:
                seen = name == layer_name
                continue
            if next(module.children(), None) is not None:
                continue  # container; its leaves follow in traversal order
            if isinstance(module, nn.BatchNorm2d):
                if module.num_features != channels:
                    return None
                norm_names.append(name)
            elif isinstance(module, nn.Conv2d):
                if module.groups != 1 or module.in_channels != channels:
                    return None
                return norm_names, name
            elif isinstance(module, nn.Linear):
                return (norm_names, name) if module.in_features == channels else None
            elif (next(module.parameters(recurse=False), None) is not None
                  or next(module.buffers(recurse=False), None) is not None):
                # An unknown layer with its own state may depend on the channel count.
                return None
        return None

    @staticmethod
    def _build_pruned_conv(layer, kept):
        """Return a copy of ``layer`` restricted to the output filters listed in ``kept``."""
        new_layer = nn.Conv2d(
            in_channels=layer.in_channels,
            out_channels=kept.numel(),
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            groups=layer.groups,
            bias=layer.bias is not None,
            padding_mode=layer.padding_mode,
        ).to(device=layer.weight.device, dtype=layer.weight.dtype)
        with torch.no_grad():
            new_layer.weight.copy_(layer.weight[kept])
            if layer.bias is not None:
                new_layer.bias.copy_(layer.bias[kept])
        return new_layer

    @staticmethod
    def _build_pruned_batchnorm(norm, kept):
        """Return a copy of BatchNorm2d ``norm`` restricted to the channels in ``kept``."""
        new_norm = nn.BatchNorm2d(
            kept.numel(),
            eps=norm.eps,
            momentum=norm.momentum,
            affine=norm.affine,
            track_running_stats=norm.track_running_stats,
        )
        reference = norm.weight if norm.affine else norm.running_mean
        if reference is not None:
            new_norm = new_norm.to(device=reference.device, dtype=reference.dtype)
        with torch.no_grad():
            if norm.affine:
                new_norm.weight.copy_(norm.weight[kept])
                new_norm.bias.copy_(norm.bias[kept])
            if norm.track_running_stats and norm.running_mean is not None:
                new_norm.running_mean.copy_(norm.running_mean[kept])
                new_norm.running_var.copy_(norm.running_var[kept])
                new_norm.num_batches_tracked.copy_(norm.num_batches_tracked)
        return new_norm

    def _update_next_layer(self, next_layer, keep_filters, kept_indices=None):
        """
        Builds a replacement for the layer that consumes the pruned output channels.

        Parameters
        ----------
        next_layer : torch.nn.Conv2d or torch.nn.Linear
            The layer whose input channels / features must shrink.
        keep_filters : int
            The number of output filters kept from the pruned layer.
        kept_indices : torch.Tensor, optional
            Indices of the kept channels; defaults to the first ``keep_filters`` channels.

        Returns
        -------
        torch.nn.Module
            The new layer. The caller is responsible for splicing it into the model.
        """
        if kept_indices is None:
            kept_indices = torch.arange(keep_filters, device=next_layer.weight.device)
        if isinstance(next_layer, nn.Linear):
            new_next_layer = nn.Linear(keep_filters, next_layer.out_features,
                                       bias=next_layer.bias is not None)
        else:
            new_next_layer = nn.Conv2d(
                in_channels=keep_filters,
                out_channels=next_layer.out_channels,
                kernel_size=next_layer.kernel_size,
                stride=next_layer.stride,
                padding=next_layer.padding,
                dilation=next_layer.dilation,
                groups=next_layer.groups,
                bias=next_layer.bias is not None,
                padding_mode=next_layer.padding_mode,
            )
        new_next_layer = new_next_layer.to(device=next_layer.weight.device, dtype=next_layer.weight.dtype)
        with torch.no_grad():
            # Input channels live on dim 1 for both Conv2d and Linear weights.
            new_next_layer.weight.copy_(next_layer.weight[:, kept_indices])
            if next_layer.bias is not None:
                new_next_layer.bias.copy_(next_layer.bias)
        return new_next_layer

    def _finetune_current_model(self):
        """Finetune the current model with an optimizer over its current parameters."""
        # Pruning swaps in freshly built layers, so an optimizer created earlier would
        # still point at the removed parameters; rebuild it before every finetune.
        self.finetuner.model = self.model
        self.finetuner.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.finetune_lr, weight_decay=self.weight_decay)
        self.finetuner.finetune(self.finetune_steps)

    def _clone_state(self):
        """Return a detached deep copy of the model's state_dict."""
        return {key: value.detach().clone() for key, value in self.model.state_dict().items()}

    def _restore(self, replaced, state):
        """Put back the modules swapped out by ``_prune`` and reload the saved weights."""
        for name, module in replaced.items():
            self.replace_layer(self.model, name, module)
        self.model.load_state_dict(state)

    @staticmethod
    def _relative_drop(reference, value):
        """Return (reference - value) / reference, or the absolute drop when reference is 0."""
        if reference == 0:
            return reference - value
        return (reference - value) / reference

    @staticmethod
    def _get_module(model, name):
        """Return the submodule of ``model`` at dotted path ``name`` ('' is the model itself)."""
        module = model
        if name:
            for part in name.split('.'):
                module = getattr(module, part)
        return module

    def replace_layer(self, model, layer_name, new_layer):
        """
        Replaces a layer in the model with a new layer.

        Parameters
        ----------
        model : torch.nn.Module
            The model in which the layer is to be replaced.
        layer_name : str
            The (possibly dotted, e.g. ``'features.0'``) name of the layer to be replaced.
        new_layer : torch.nn.Module
            The new layer that will replace the old layer.
        """
        parent_name, _, child_name = layer_name.rpartition('.')
        parent = self._get_module(model, parent_name)
        setattr(parent, child_name, new_layer)

논문 레시피에서는 proposal마다 파인튜닝을 진행한다. 그래서 layerwise pruning 과정에서 각 proposal을 만든 직후에 파인튜닝을 수행하도록 추가하였다.

profile
미친 취준생

0개의 댓글