Tetrachromacy Project

seongyun·2025년 5월 18일

Neural Network

목록 보기
6/8

이번에 나라뤼와 얘기를 하다가 알게된 사실이 있다.
어떤 사람들은 4번째 원추세포를 갖고 있어서, RGB 외의 추가적인 스펙트럼도 감지할 수 있는 능력이 있다는 사실이다.
그 사실을 알게되고나서 굉장히 흥미가 생긴 나는 이 이론을 이용해서 직접 신경망을 만들고 싶어졌다.
그로 인해 시작된 테트라크로매시(tetrachromacy) 프로젝트.

처음으로 논문을 먼저 찾아보니 실제 연구사례가 있는 것을 확인.

주요 연구 논문 소개

1. A Computational Framework for Modeling Emergence of Color Vision in the Human Brain

출처: arXiv (2024)

요약: 이 연구는 인간의 망막에 존재하는 원추세포의 수에 따라 색각 차원이 어떻게 자연스럽게 학습되는지를 시뮬레이션합니다. 특히, 4번째 원추세포를 추가하면 인공신경망이 4차원 색공간을 학습하여 테트라크로매시를 모사할 수 있음을 보여줍니다. 또한, 유전자 치료를 통해 색각이 향상된 원숭이 실험을 기반으로, 성인에서도 색각 차원의 향상이 가능함을 시뮬레이션으로 입증하였습니다.

2. V4α: A New Method for CNNs Inspired by Trichromacy Perception

출처: SpringerLink (2022)

요약: 이 논문은 인간의 삼색성(trichromacy) 색각 시스템을 모사하여 CNN(합성곱 신경망)의 구조를 개선하는 방법을 제안합니다. 색 자극에 대한 인간 뇌의 반응을 기반으로 네트워크 구조를 조정하여 분류 작업의 효율성을 4% 이상 향상시키고 계산 비용을 10% 감소시켰습니다. 이 방법은 AlexNet에 적용되었으며, 다른 복잡한 네트워크에도 확장 가능합니다

3. Neural Network Models for Normal and Dichromatic Color Vision

출처: SpringerLink (1995)

요약: 이 연구는 정상 색각과 이색성(dichromacy)을 가진 시각 시스템을 모델링한 신경망을 제안합니다. 세 가지 유형의 원추세포 손실을 각각 모델링하여, 스펙트럼 순수광에 대한 반응 특성과 색상 표현을 분석하였습니다. 이는 테트라크로매시와는 직접적인 관련은 없지만, 색각의 다양성을 신경망으로 모델링하는 기초를 제공합니다.

나는 이 논문들 중 제일 최신에 나온 2024년도 논문을 기준으로 신경망 개발을 진행하려고 한다.

논문 개요

제목: A Computational Framework for Modeling Emergence of Color Vision in the Human Brain

저자: Atsunobu Kotani, Ren Ng (UC Berkeley)

초록: 이 논문은 인간의 색각이 어떻게 망막 신호로부터 뇌에서 자연스럽게 형성되는지를 설명하는 계산 모델을 제안합니다. 기존 연구들이 색각의 차원을 사전에 가정하는 반면, 이 연구는 시각 피질이 시신경 신호의 변동만을 통해 색각 차원을 추론할 수 있음을 주장합니다. 이를 위해 생물학적 눈의 시뮬레이션 엔진과 자기지도 학습 기반의 피질 학습 모델을 도입하여, N개의 원추세포가 있을 때 N차원의 색각이 자연스럽게 형성됨을 보여줍니다. 또한, 다람쥐 원숭이의 유전자 치료 사례를 시뮬레이션하여 인간의 색각을 3차원에서 4차원으로 향상시킬 가능성을 입증합니다.

주요 내용 및 기여

1. 시각 피질의 색각 차원 추론 능력

기존 연구들은 색각 차원을 사전에 가정하지만, 이 논문은 시각 피질이 시신경 신호의 변동만을 통해 색각 차원을 추론할 수 있음을 주장합니다.

2. 생물학적 눈의 시뮬레이션 엔진 개발

자연 이미지로부터 생성된 시신경 신호를 기반으로, 생물학적 눈의 시뮬레이션 엔진을 개발하여 다양한 원추세포 조합에 따른 색각 형성을 분석합니다.

3. 자기지도 학습 기반의 피질 학습 모델

자기지도 학습 원리를 기반으로 한 피질 학습 모델을 통해, 망막 신호로부터 색각이 자연스럽게 형성되는 과정을 시뮬레이션합니다.

4. 색각 차원의 향상 시뮬레이션

다람쥐 원숭이의 유전자 치료 사례를 시뮬레이션하여, 인간의 색각을 3차원에서 4차원으로 향상시킬 수 있는 가능성을 제시합니다.

논문 결과를 본 이후 어떤 형태의 신경망을 사용해야될지 정했다.

핵심 요건 요약

요소사용 이유
Self-supervised Learning (자기지도 학습)외부 정답 없이 시신경 패턴만으로 색 분류를 학습하기 위함
Biologically-inspired Vision Frontend생물학적으로 유사한 "원추세포 수용기" 시뮬레이션
Linear + Nonlinear Layers시각 피질의 반응을 모사하여 색차원을 점진적으로 추출
Contrastive Learning (SimCLR 류)입력 자극 간 상대적 차이를 기준으로 색각 차원 구성
Dimensionality Emergence Test주성분분석(PCA) 또는 선형 분류기로 "자연 발생된 색차원" 확인

신경망 구조 (요약)

1. Retina Simulation (입력 전처리)

  • 생물학적 망막처럼 작동
  • RGB 대신 n개의 원추세포 채널 (ex. 4채널)
  • 색수용체의 민감도 곡선 기반 weight 적용

input image (natural scene) → retina simulation → cone signals

2. Visual Cortex-like Encoder (CNN or MLP)

  • CNN 또는 MLP 계열 encoder
  • 자기지도 학습 방식으로 입력 cone signal에서 latent representation 추출

Cone Signals → Encoder Network (ResNet-18 또는 3-layer MLP) → Embedding Vector (latent space)

  • 학습 목표는 "같은 장면에서 유사한 색 공간 유지 + 다른 장면 구분"

3. Contrastive Loss (SimCLR, BYOL 등)

  • Augmented view들 사이의 색 정보를 정렬하는 방식으로 학습
  • 색을 분류하지 않고 색 공간 구조를 스스로 "형성"하도록 유도

Loss = contrastive_distance(view_1, view_2)

4. 색각 차원 추출 및 평가

  • 학습 후 Latent Vector에 대해 PCA 수행 → 주요 색각 차원 시각화
  • Linear Probe 또는 K-means 등으로 색 구분 성능 확인

적합한 형태

  • 시각 정보 전용 Encoder (CNN/MLP)
  • Contrastive Self-Supervised Framework (SimCLR, MoCo, BYOL, Barlow Twins 등)
  • 입력단에 생물학 기반 전처리 계층 (n-채널 "cone array" 기반)
  • Latent space 분석 가능한 구조

적합하지 않은 형태

  • 단순 MSE 기반 회귀
  • fully supervised classification (색 이름 레이블 주입)
  • colorization GAN (목적이 다름)

이제 적합한 형태를 지키면서 진행해보도록 하겠다.

전체 구조 요약

  1. 데이터 전처리
    -일반 RGB 이미지 → 가상의 4번째 채널 추가 → 4채널 이미지 생성
  2. 인공망막 시뮬레이션
    -4개의 cone sensitivity를 가정 (예: L, M, S, X-cone)
  3. Encoder 모델 정의 (CNN 또는 MLP)
  4. Projection Head + Contrastive 학습 손실 적용
  5. Latent space 분석 (PCA 등)

디렉토리 형태

tetrachromacy_project
│
├── main.py                  # 학습 실행 진입점
├── model.py                 # TetrachromaticEncoder 정의
├── dataset.py               # 4채널 CIFAR10 Dataset 정의
├── augmentations.py         # SimCLR augmentation 정의
├── loss.py                  # Contrastive Loss 정의
├── visualize.py             # PCA 시각화 코드
├── utils.py                 # 기타 도우미 함수
├── requirements.txt         # 필요한 패키지 목록
└── saved_models
    └── encoder_epoch10.pth  # 학습된 모델 저장 위치

Tetrachromacy 구현 흐름

1. 4채널 이미지 생성 함수

기존 RGB 이미지에 가상의 "X-cone" 채널을 추가하여 4채널 이미지로 만듭니다.

import torch

def add_x_cone(img_tensor):
    r, g, b = img_tensor[0], img_tensor[1], img_tensor[2]
    x_cone = (0.6 * r + 0.4 * g).unsqueeze(0)
    return torch.cat([img_tensor, x_cone], dim=0)  # [4, H, W]

2. TetrachromaticEncoder 모델 정의

4채널 이미지 입력을 받아서 128차원의 latent vector를 출력하는 신경망입니다.

import torch.nn as nn
import torch.nn.functional as F

class TetrachromaticEncoder(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.projector = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, latent_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projector(h)
        return F.normalize(z, dim=1)

3. Contrastive Loss 정의 (NT-Xent)

SimCLR에서 사용되는 contrastive loss입니다.

import torch

def contrastive_loss(z1, z2, temperature=0.5):
    batch_size = z1.size(0)
    z = torch.cat([z1, z2], dim=0)  # [2B, D]
    sim_matrix = torch.exp(torch.mm(z, z.t()) / temperature)
    mask = ~torch.eye(2 * batch_size, dtype=bool).to(z.device)
    sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)

    positives = torch.exp((z1 * z2).sum(dim=-1) / temperature)
    positives = torch.cat([positives, positives], dim=0)

    loss = -torch.log(positives / sim_matrix.sum(dim=-1))
    return loss.mean()

4. 4채널 Dataset 클래스 정의

CIFAR-10을 로드하고, 각 이미지에 X-cone 채널을 추가한 후 return합니다.

from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
from augmentations import SimCLRAugmentation
from utils import add_x_cone

class TetrachromaticCIFAR10(Dataset):
    def __init__(self, train=True):
        self.base_dataset = CIFAR10(root='./data', train=train, download=True)
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor()
        ])
        self.train = train
        self.augmentation = SimCLRAugmentation()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        img, label = self.base_dataset[idx]
        
        if self.train:
            img = self.transform(img)  # RGB -> [3, H, W]
            img_4ch = add_x_cone(img)  # -> [4, H, W]
            return img_4ch, label
        else:
            img = self.transform(img)  # RGB -> [3, H, W]
            img_4ch = add_x_cone(img)  # -> [4, H, W]
            return img_4ch, label

5. SimCLR용 Augmentation (두 view 생성)

이미지에서 2개의 서로 다른 뷰(augmentation)를 생성합니다.

from torchvision import transforms

class SimCLRAugmentation:
    def __init__(self):
        self.base_transform = transforms.Compose([
            transforms.RandomResizedCrop(128, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
        ])

    def __call__(self, img):
        return self.base_transform(img), self.base_transform(img)

6. 학습 루프 구성

모델 학습 루프. contrastive learning을 수행합니다.

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import os
from PIL import Image
from torchvision import transforms

from dataset import TetrachromaticCIFAR10
from model import TetrachromaticEncoder
from loss import contrastive_loss
from visualize import visualize_latent_space
from augmentations import SimCLRAugmentation
from utils import add_x_cone

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

def train_model():
    dataset = TetrachromaticCIFAR10(train=True)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    augmentation = SimCLRAugmentation()
    
    model = TetrachromaticEncoder().to(device)
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    
    save_dir = './saved_models'
    os.makedirs(save_dir, exist_ok=True)
    
    for epoch in range(10):
        model.train()
        total_loss = 0
        for batch_idx, (x, labels) in enumerate(dataloader):
            x = x.to(device)
            batch_size = x.size(0)
            x1_list, x2_list = [], []
            
            for i in range(batch_size):
                img = x[i].cpu().permute(1, 2, 0)[:, :, :3]  # [H, W, 3] 형태로 변환 (RGB만)
                img = img.numpy()
                img = (img * 255).astype('uint8')
                img = Image.fromarray(img)
                
                img1, img2 = augmentation.base_transform(img), augmentation.base_transform(img)
                
                img1 = transforms.ToTensor()(img1)
                img2 = transforms.ToTensor()(img2)
                
                img1_4ch = add_x_cone(img1)
                img2_4ch = add_x_cone(img2)
                
                x1_list.append(img1_4ch)
                x2_list.append(img2_4ch)
            
            x1 = torch.stack(x1_list).to(device)
            x2 = torch.stack(x2_list).to(device)
            
            z1, z2 = model(x1), model(x2)
            loss = contrastive_loss(z1, z2)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.4f}")
                
        print(f"Epoch {epoch+1}, Avg Loss: {total_loss / len(dataloader):.4f}")
    
    torch.save(model.state_dict(), os.path.join(save_dir, 'tetrachromatic_encoder.pth'))
    print(f"Model saved to {os.path.join(save_dir, 'tetrachromatic_encoder.pth')}")
    
    return model

def evaluate_model(model):
    test_dataset = TetrachromaticCIFAR10(train=False)
    
    visualize_latent_space(model, test_dataset, device)

if __name__ == "__main__":
    trained_model = train_model()
    
    evaluate_model(trained_model)

7. Latent Space 시각화 (PCA)

학습된 벡터를 시각화합니다. 클래스는 시각화 용도로만 사용합니다.

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_latent_space(model, dataset, device, num_samples=1000):
    model.eval()
    zs, labels = [], []

    with torch.no_grad():
        for i in range(num_samples):
            img, label = dataset[i]
            img = img.unsqueeze(0).to(device)
            z = model(img).cpu().squeeze(0)
            zs.append(z.numpy())
            labels.append(label)

    zs = np.stack(zs)
    pca = PCA(n_components=2)
    zs_2d = pca.fit_transform(zs)

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(zs_2d[:, 0], zs_2d[:, 1], c=labels, cmap='tab10', alpha=0.7)
    plt.title("Latent Space (PCA Projection)")
    plt.xlabel("PC 1")
    plt.ylabel("PC 2")
    plt.legend(*scatter.legend_elements(), title="Class")
    plt.show()

이렇게 코드적인 구현을 마무리했지만 아무래도 epochs 수가 굉장히 높기 때문에 나중에 실행하게되면 그때 다시 이미지로 넣도록 하겠다.

0개의 댓글