
이번에 나라뤼와 얘기를 하다가 알게된 사실이 있다.
어떤 사람들은 4번째 원추세포를 갖고 있어서, RGB 외의 추가적인 스펙트럼도 감지할 수 있는 능력이 있다는 사실이다.
그 사실을 알게되고나서 굉장히 흥미가 생긴 나는 이 이론을 이용해서 직접 신경망을 만들고 싶어졌다.
그로 인해 시작된 테트라크로매시(tetrachromacy) 프로젝트.
처음으로 논문을 먼저 찾아보니 실제 연구사례가 있는 것을 확인.
출처: arXiv (2024)
요약: 이 연구는 인간의 망막에 존재하는 원추세포의 수에 따라 색각 차원이 어떻게 자연스럽게 학습되는지를 시뮬레이션합니다. 특히, 4번째 원추세포를 추가하면 인공신경망이 4차원 색공간을 학습하여 테트라크로매시를 모사할 수 있음을 보여줍니다. 또한, 유전자 치료를 통해 색각이 향상된 원숭이 실험을 기반으로, 성인에서도 색각 차원의 향상이 가능함을 시뮬레이션으로 입증하였습니다.
출처: SpringerLink (2022)
요약: 이 논문은 인간의 삼색성(trichromacy) 색각 시스템을 모사하여 CNN(합성곱 신경망)의 구조를 개선하는 방법을 제안합니다. 색 자극에 대한 인간 뇌의 반응을 기반으로 네트워크 구조를 조정하여 분류 작업의 효율성을 4% 이상 향상시키고 계산 비용을 10% 감소시켰습니다. 이 방법은 AlexNet에 적용되었으며, 다른 복잡한 네트워크에도 확장 가능합니다
출처: SpringerLink (1995)
요약: 이 연구는 정상 색각과 이색성(dichromacy)을 가진 시각 시스템을 모델링한 신경망을 제안합니다. 세 가지 유형의 원추세포 손실을 각각 모델링하여, 스펙트럼 순수광에 대한 반응 특성과 색상 표현을 분석하였습니다. 이는 테트라크로매시와는 직접적인 관련은 없지만, 색각의 다양성을 신경망으로 모델링하는 기초를 제공합니다.
나는 이 논문들 중 제일 최신에 나온 2024년도 논문을 기준으로 신경망 개발을 진행하려고 한다.
저자: Atsunobu Kotani, Ren Ng (UC Berkeley)
초록: 이 논문은 인간의 색각이 어떻게 망막 신호로부터 뇌에서 자연스럽게 형성되는지를 설명하는 계산 모델을 제안합니다. 기존 연구들이 색각의 차원을 사전에 가정하는 반면, 이 연구는 시각 피질이 시신경 신호의 변동만을 통해 색각 차원을 추론할 수 있음을 주장합니다. 이를 위해 생물학적 눈의 시뮬레이션 엔진과 자기지도 학습 기반의 피질 학습 모델을 도입하여, N개의 원추세포가 있을 때 N차원의 색각이 자연스럽게 형성됨을 보여줍니다. 또한, 다람쥐 원숭이의 유전자 치료 사례를 시뮬레이션하여 인간의 색각을 3차원에서 4차원으로 향상시킬 가능성을 입증합니다.
기존 연구들은 색각 차원을 사전에 가정하지만, 이 논문은 시각 피질이 시신경 신호의 변동만을 통해 색각 차원을 추론할 수 있음을 주장합니다.
자연 이미지로부터 생성된 시신경 신호를 기반으로, 생물학적 눈의 시뮬레이션 엔진을 개발하여 다양한 원추세포 조합에 따른 색각 형성을 분석합니다.
자기지도 학습 원리를 기반으로 한 피질 학습 모델을 통해, 망막 신호로부터 색각이 자연스럽게 형성되는 과정을 시뮬레이션합니다.
다람쥐 원숭이의 유전자 치료 사례를 시뮬레이션하여, 인간의 색각을 3차원에서 4차원으로 향상시킬 수 있는 가능성을 제시합니다.
논문 결과를 본 이후 어떤 형태의 신경망을 사용해야될지 정했다.
| 요소 | 사용 이유 |
|---|---|
| Self-supervised Learning (자기지도 학습) | 외부 정답 없이 시신경 패턴만으로 색 분류를 학습하기 위함 |
| Biologically-inspired Vision Frontend | 생물학적으로 유사한 "원추세포 수용기" 시뮬레이션 |
| Linear + Nonlinear Layers | 시각 피질의 반응을 모사하여 색차원을 점진적으로 추출 |
| Contrastive Learning (SimCLR 류) | 입력 자극 간 상대적 차이를 기준으로 색각 차원 구성 |
| Dimensionality Emergence Test | 주성분분석(PCA) 또는 선형 분류기로 "자연 발생된 색차원" 확인 |
input image (natural scene) → retina simulation → cone signals
Cone Signals → Encoder Network (ResNet-18 또는 3-layer MLP) → Embedding Vector (latent space)
Loss = contrastive_distance(view_1, view_2)
이제 적합한 형태를 지키면서 진행해보도록 하겠다.
tetrachromacy_project
│
├── main.py # 학습 실행 진입점
├── model.py # TetrachromaticEncoder 정의
├── dataset.py # 4채널 CIFAR10 Dataset 정의
├── augmentations.py # SimCLR augmentation 정의
├── loss.py # Contrastive Loss 정의
├── visualize.py # PCA 시각화 코드
├── utils.py # 기타 도우미 함수
├── requirements.txt # 필요한 패키지 목록
└── saved_models
└── encoder_epoch10.pth # 학습된 모델 저장 위치
기존 RGB 이미지에 가상의 "X-cone" 채널을 추가하여 4채널 이미지로 만듭니다.
import torch
def add_x_cone(img_tensor):
r, g, b = img_tensor[0], img_tensor[1], img_tensor[2]
x_cone = (0.6 * r + 0.4 * g).unsqueeze(0)
return torch.cat([img_tensor, x_cone], dim=0) # [4, H, W]
4채널 이미지 입력을 받아서 128차원의 latent vector를 출력하는 신경망입니다.
import torch.nn as nn
import torch.nn.functional as F
class TetrachromaticEncoder(nn.Module):
def __init__(self, latent_dim=128):
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(4, 32, 3, padding=1), nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
nn.AdaptiveAvgPool2d((1, 1))
)
self.projector = nn.Sequential(
nn.Flatten(),
nn.Linear(128, latent_dim)
)
def forward(self, x):
h = self.encoder(x)
z = self.projector(h)
return F.normalize(z, dim=1)
SimCLR에서 사용되는 contrastive loss입니다.
import torch
def contrastive_loss(z1, z2, temperature=0.5):
batch_size = z1.size(0)
z = torch.cat([z1, z2], dim=0) # [2B, D]
sim_matrix = torch.exp(torch.mm(z, z.t()) / temperature)
mask = ~torch.eye(2 * batch_size, dtype=bool).to(z.device)
sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)
positives = torch.exp((z1 * z2).sum(dim=-1) / temperature)
positives = torch.cat([positives, positives], dim=0)
loss = -torch.log(positives / sim_matrix.sum(dim=-1))
return loss.mean()
CIFAR-10을 로드하고, 각 이미지에 X-cone 채널을 추가한 후 return합니다.
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
from augmentations import SimCLRAugmentation
from utils import add_x_cone
class TetrachromaticCIFAR10(Dataset):
def __init__(self, train=True):
self.base_dataset = CIFAR10(root='./data', train=train, download=True)
self.transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor()
])
self.train = train
self.augmentation = SimCLRAugmentation()
def __len__(self):
return len(self.base_dataset)
def __getitem__(self, idx):
img, label = self.base_dataset[idx]
if self.train:
img = self.transform(img) # RGB -> [3, H, W]
img_4ch = add_x_cone(img) # -> [4, H, W]
return img_4ch, label
else:
img = self.transform(img) # RGB -> [3, H, W]
img_4ch = add_x_cone(img) # -> [4, H, W]
return img_4ch, label
이미지에서 2개의 서로 다른 뷰(augmentation)를 생성합니다.
from torchvision import transforms
class SimCLRAugmentation:
def __init__(self):
self.base_transform = transforms.Compose([
transforms.RandomResizedCrop(128, scale=(0.5, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
])
def __call__(self, img):
return self.base_transform(img), self.base_transform(img)
모델 학습 루프. contrastive learning을 수행합니다.
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import os
from PIL import Image
from torchvision import transforms
from dataset import TetrachromaticCIFAR10
from model import TetrachromaticEncoder
from loss import contrastive_loss
from visualize import visualize_latent_space
from augmentations import SimCLRAugmentation
from utils import add_x_cone
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
def train_model():
dataset = TetrachromaticCIFAR10(train=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
augmentation = SimCLRAugmentation()
model = TetrachromaticEncoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
save_dir = './saved_models'
os.makedirs(save_dir, exist_ok=True)
for epoch in range(10):
model.train()
total_loss = 0
for batch_idx, (x, labels) in enumerate(dataloader):
x = x.to(device)
batch_size = x.size(0)
x1_list, x2_list = [], []
for i in range(batch_size):
img = x[i].cpu().permute(1, 2, 0)[:, :, :3] # [H, W, 3] 형태로 변환 (RGB만)
img = img.numpy()
img = (img * 255).astype('uint8')
img = Image.fromarray(img)
img1, img2 = augmentation.base_transform(img), augmentation.base_transform(img)
img1 = transforms.ToTensor()(img1)
img2 = transforms.ToTensor()(img2)
img1_4ch = add_x_cone(img1)
img2_4ch = add_x_cone(img2)
x1_list.append(img1_4ch)
x2_list.append(img2_4ch)
x1 = torch.stack(x1_list).to(device)
x2 = torch.stack(x2_list).to(device)
z1, z2 = model(x1), model(x2)
loss = contrastive_loss(z1, z2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
if (batch_idx + 1) % 10 == 0:
print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.4f}")
print(f"Epoch {epoch+1}, Avg Loss: {total_loss / len(dataloader):.4f}")
torch.save(model.state_dict(), os.path.join(save_dir, 'tetrachromatic_encoder.pth'))
print(f"Model saved to {os.path.join(save_dir, 'tetrachromatic_encoder.pth')}")
return model
def evaluate_model(model):
test_dataset = TetrachromaticCIFAR10(train=False)
visualize_latent_space(model, test_dataset, device)
if __name__ == "__main__":
trained_model = train_model()
evaluate_model(trained_model)
학습된 벡터를 시각화합니다. 클래스는 시각화 용도로만 사용합니다.
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import torch
def visualize_latent_space(model, dataset, device, num_samples=1000):
model.eval()
zs, labels = [], []
with torch.no_grad():
for i in range(num_samples):
img, label = dataset[i]
img = img.unsqueeze(0).to(device)
z = model(img).cpu().squeeze(0)
zs.append(z.numpy())
labels.append(label)
zs = np.stack(zs)
pca = PCA(n_components=2)
zs_2d = pca.fit_transform(zs)
plt.figure(figsize=(8, 6))
scatter = plt.scatter(zs_2d[:, 0], zs_2d[:, 1], c=labels, cmap='tab10', alpha=0.7)
plt.title("Latent Space (PCA Projection)")
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.legend(*scatter.legend_elements(), title="Class")
plt.show()
이렇게 코드적인 구현을 마무리했지만 아무래도 epochs 수가 굉장히 높기 때문에 나중에 실행하게되면 그때 다시 이미지로 넣도록 하겠다.