정답을 모르는 추론 데이터가 배치(예: 512개) 단위로 들어오면 다음의 사이클이 반복됨.
GalaxyDataset은 TTA 과정에서 정답(Label) 없이 피처(X)와 도메인 지식(P_NOMERGER)만 반환하도록 구성함.import os
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# (경로 설정 및 하이퍼파라미터 생략 - 기존 코드와 동일)
FTAT_CONFIG = {
"batch_size" : 512,
"learning_rates": [1e-3, 1e-4, 5e-4, 1e-5],
"alpha" : 0.10,
"epsilon_p" : 0.70,
"beta" : 0.30,
}
class GalaxyDataset(Dataset):
"""레이블 없이 피처와 도메인 확률(P_NOMERGER)을 반환하는 TTA 전용 데이터셋"""
def __init__(self, X, p_nomerger):
self.X = torch.tensor(X, dtype=torch.float32)
self.p_nomerger = torch.tensor(p_nomerger, dtype=torch.float32)
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return self.X[idx], self.p_nomerger[idx]
configure_for_tta 함수를 통해 전체 파라미터는 동결(Freeze)하고, 데이터의 분포를 맞추는 데 핵심적인 역할을 하는 LayerNorm의 파라미터만 학습되도록 해제함.def configure_for_tta(model: nn.Module) -> nn.Module:
"""모든 파라미터를 동결하고 LayerNorm의 파라미터만 업데이트 허용"""
model.train()
for param in model.parameters():
param.requires_grad_(False)
for module in model.modules():
if isinstance(module, nn.LayerNorm):
module.train()
for param in module.parameters():
param.requires_grad_(True)
return model
def tta_params(model: nn.Module):
return [p for p in model.parameters() if p.requires_grad]
P_hat)를 시간적 EMA(지수 이동 평균) 방식으로 갱신함.def batch_entropy(probs: torch.Tensor) -> torch.Tensor:
return -(probs * (probs.clamp(min=1e-8)).log()).sum(dim=1)
class ConfidentDistributionOptimizer:
def __init__(self, num_classes, P0, alpha, epsilon_p):
self.K = num_classes
self.P0 = P0.clone()
self.P_hat = P0.clone()
self.alpha = alpha
self.epsilon = -(epsilon_p * np.log(epsilon_p) + (1 - epsilon_p) * np.log(1 - epsilon_p))
def update(self, probs: torch.Tensor) -> None:
ent = batch_entropy(probs)
mask = ent < self.epsilon
if mask.sum() == 0: return
conf = probs[mask]
P_tilde = conf.mean(dim=0)
C_hat = torch.zeros(self.K, self.K, device=probs.device)
pred_cls = probs.argmax(dim=1)
for k in range(self.K):
km = pred_cls == k
C_hat[k] = probs[km].mean(dim=0) if km.sum() > 0 else torch.eye(self.K, device=probs.device)[k]
try: unbiased = torch.linalg.solve(C_hat, P_tilde)
except Exception: unbiased = P_tilde
self.P_hat = (1 - self.alpha) * self.P_hat + self.alpha * unbiased
self.P_hat = torch.softmax(self.P_hat, dim=0)
def adjust(self, probs: torch.Tensor) -> torch.Tensor:
ratio = self.P_hat / (self.P0.to(probs.device) + 1e-8)
adjusted = probs * ratio.unsqueeze(0)
return adjusted / (adjusted.sum(dim=1, keepdim=True) + 1e-8)
class LocalConsistentWeighter:
def __init__(self, beta: float = 0.30):
self.beta = beta
def compute_weights(self, X: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
N = X.shape[0]
diff = X.unsqueeze(1) - X.unsqueeze(0)
dist_mat = diff.pow(2).sum(dim=2).clamp(min=0).sqrt()
triu_idx = torch.triu_indices(N, N, offset=1)
dist_thresh = dist_mat[triu_idx[0], triu_idx[1]].mean()
margin = probs.max(dim=1).values - probs.min(dim=1).values
weights = torch.zeros(N, device=X.device)
for k in range(N):
nbr_mask = (dist_mat[k] < dist_thresh)
nbr_mask[k] = False
if nbr_mask.sum() == 0:
weights[k] = margin[k]
continue
nbr_mean = probs[nbr_mask].mean(dim=0)
consistency = torch.norm(probs[k] - nbr_mean)
is_consistent = (consistency < self.beta).float()
weights[k] = margin[k] * is_consistent
return weights
learning_rates)을 가진 기본 모델들을 유지하며 동시에 업데이트함.class DynamicModelEnsembler:
def __init__(self, source_model, learning_rates):
self.M = len(learning_rates)
self.models, self.optims = [], []
for lr in learning_rates:
m = configure_for_tta(copy.deepcopy(source_model))
opt = optim.AdamW(tta_params(m), lr=lr, weight_decay=1e-4)
self.models.append(m)
self.optims.append(opt)
def _model_weights(self, X):
losses = []
with torch.no_grad():
for m in self.models:
m.eval()
losses.append(batch_entropy(torch.softmax(m(X), dim=1)).mean().item())
m.train()
losses = np.array(losses, dtype=np.float64)
span = losses.max() - losses.min()
norm = (losses - losses.min()) / span if span > 1e-8 else np.zeros(self.M)
w = 1.0 - norm
return w / (w.sum() + 1e-8)
def step(self, X, cdo, lcw):
mw = self._model_weights(X)
self.models[0].eval()
with torch.no_grad():
proxy_probs = torch.softmax(self.models[0](X), dim=1)
self.models[0].train()
sample_w = lcw.compute_weights(X.detach(), proxy_probs.detach())
cdo.update(proxy_probs.detach())
all_adj_probs = []
for m, opt in zip(self.models, self.optims):
m.train()
probs = torch.softmax(m(X), dim=1)
adj_probs = cdo.adjust(probs)
ent = batch_entropy(adj_probs)
loss = (sample_w * ent).sum() / sample_w.sum().clamp(min=1e-8)
opt.zero_grad()
loss.backward()
opt.step()
m.eval()
with torch.no_grad():
all_adj_probs.append(cdo.adjust(torch.softmax(m(X), dim=1)))
ensemble = torch.zeros_like(all_adj_probs[0])
for w, p in zip(mw, all_adj_probs):
ensemble += w * p
return ensemble, mw



conf)에만 의존하지 않고, 물리적 도메인 지식(P_NOMERGER)을 필터링 조건으로 함께 사용하여 논리적으로 타당한 샘플만 학습에 사용하도록 로직을 수정함.기존 코드에서 CDO의 라벨 분포 갱신 로직과 DME의 모델 업데이트 로직에 도메인 지식 필터를 추가함.
P_NOMERGER 값이 논리적 기준(Non-merger는 높게, Pre/Post는 낮게)에 부합하는 샘플만 라벨 분포(P_hat) 추정에 사용함. # ConfidentDistributionOptimizer 모듈 내부
def update(self, probs: torch.Tensor, p_nom_batch: torch.Tensor) -> None:
ent = batch_entropy(probs)
conf_mask = ent < self.epsilon # 1차: 모델 확신도 조건
pred_cls = probs.argmax(dim=1)
logical_mask = torch.zeros_like(conf_mask, dtype=torch.bool)
is_non_merger = (pred_cls == 0)
is_merger = (pred_cls != 0)
# 2차: 도메인 지식(P_NOMERGER) 조건
logical_mask[is_non_merger] = p_nom_batch[is_non_merger] >= self.th_high
logical_mask[is_merger] = p_nom_batch[is_merger] <= self.th_low
final_mask = conf_mask & logical_mask # 이중 필터링 교집합
if final_mask.sum() == 0: return
conf = probs[final_mask]
# ... 이하 기존 C_hat 및 P_hat 갱신 로직 동일
masked_sample_w)시켜, 모델이 꼼수로 학습하는 것을 방어함. # DynamicModelEnsembler 모듈 내부 step 함수
def step(self, X: torch.Tensor, p_nom_batch: torch.Tensor, cdo, lcw):
# ... (proxy_probs 추출 및 lcw, cdo.update 적용)
is_non_merger = (proxy_pred_cls == 0)
is_merger = (proxy_pred_cls != 0)
# Loss 계산을 위한 도메인 논리 마스크 생성
loss_logical_mask = torch.zeros_like(proxy_pred_cls, dtype=torch.bool)
loss_logical_mask[is_non_merger] = p_nom_batch[is_non_merger] >= cdo.th_high
loss_logical_mask[is_merger] = p_nom_batch[is_merger] <= cdo.th_low
# LCW 가중치에 논리 마스크를 씌워 논리에 안 맞는 샘플 가중치를 0으로 만듦
masked_sample_w = sample_w * loss_logical_mask.float()
# ... (이하 모델별 루프)
ent = batch_entropy(adj_probs)
denom = masked_sample_w.sum().clamp(min=1e-8)
loss = (masked_sample_w * ent).sum() / denom # 마스킹된 샘플로만 학습
if masked_sample_w.sum() > 0:
opt.zero_grad()
loss.backward()
opt.step()
# ...

위와 같이 도메인 지식(P_NOMERGER)을 직접 Loss 필터링에 개입시키는 것은 결과적으로 모델의 분류 성능을 크게 높일 수 있으나, 연구 논리를 전개할 때 데이터 누수(Data Leakage) 또는 오라클(Oracle) 개입의 비판을 받을 소지가 있음. 따라서 연구의 '최종 목적'에 따라 두 가지 방향 중 하나로 논리를 디펜스해야 함.
방향성 1: 순수 알고리즘 중심 (Unsupervised TTA)
P_NOMERGER를 활용한 Loss 마스킹을 제거해야 함.방향성 2: 도메인 문제 해결 중심 (Physics-Informed / Weakly-Supervised)
P_NOMERGER)를 약지도 학습(Weakly-Supervised Learning)의 가이드라인으로 주입하여 모델을 올바른 물리적 공간으로 견인한 '물리 지식 융합 모델링'으로 디펜스함.