Dropout의 학습(drop neurons) 및 추론(rescale)에 대한 부분 최적화.
rescale operation을 학습 단으로 옮기고 추론시에는 Dropout레이어를 완전삭제하도록 연구중(정확히는 rescale = 1)
학습시 Dropout 시간(랜덤함수 소요비용 상당함) 감소 성능 테스트.
(Fast dropout training 논문은 정규화 기법이지 실제 Dropout이 아님. 따라서 고려하지 않음)
5가지 방법으로 테스트 진행하였으며 MKL의 베르누이 분포 생성 함수와 약간의 트릭을 사용하여 Naive 구현 대비 6~10배의 성능향상을 가져옴.
std::vector<float> Dropout1(const std::vector<float>& input, float prob) {
std::vector<float> output(input.size(), 0);
for (size_t i = 0; i < input.size(); i++) {
if (static_cast<float>(rand()) / RAND_MAX > prob)
output[i] = input[i];
}
return output;
}
std::vector<float> Dropout2(const std::vector<float>& input, float prob) {
std::vector<float> output(input.size(), 0);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> d(0.f, 1.f);
for (size_t i = 0; i < input.size(); i++) {
if (d(gen) > prob)
output[i] = input[i];
}
return output;
}
std::vector<float> Dropout3(const std::vector<float>& input, float prob) {
std::vector<float> output(input.size(), 0);
std::random_device rd;
std::mt19937 gen(rd());
std::bernoulli_distribution d(prob);
for (size_t i = 0; i < input.size(); i++) {
if (!d(gen))
output[i] = input[i];
}
return output;
}
uint32_t xor128(void) {
static uint32_t x = 123456789u;
static uint32_t y = 362436069u;
static uint32_t z = 521288629u;
static uint32_t w = 88675123u;
uint32_t t;
t = x ^ (x << 11);
x = y; y = z; z = w;
return w = w ^ (w >> 19) ^ (t ^ (t >> 8));
}
std::vector<float> Dropout4(const std::vector<float>& input, float prob) {
std::vector<float> output(input.size(), 0.0f);
for (int i = 0; i < input.size(); i++) {
if (static_cast<float>(xor128()) / UINT_MAX > prob)
output[i] = input[i];
}
return output;
}
std::vector<float> Dropout5(const std::vector<float>& input, float prob) {
std::vector<float> output(input.size(), 0.f);
VSLStreamStatePtr stream;
vslNewStream(&stream, VSL_BRNG_MCG31, 2025);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, input.size(), reinterpret_cast<int*>(output.data()), 1. - prob);
vsCeil(input.size(), output.data(), output.data());
vsMul(input.size(), input.data(), output.data(), output.data());
return output;
}
결과
5번을 제외한 나머지 방법은 확률안정성이 떨어지고 5번은 매우 정확한 확률안정성을 보여준다.