[Codeforces Round 127 (Div. 1)] - Guess That Car! (삼분 탐색, 수학, C++, Python)

SHSHSH·2023년 7월 5일

CODEFORCES

목록 보기
17/26

Codeforces Round 127 (Div. 1) - Guess That Car! 링크
(2023.07.05 기준 Difficulty *1800)

문제

길이가 4m인 정사각형 모양의 칸이 n * m 모양의 격자가 있다.
각 칸마다 자동차가 있으며 자동차의 희귀도가 각각 주어진다.
격자 선의 교차점을 아무 곳이나 선택하여 모든 자동차를 추측해야 하는데, 각 자동차마다 추측하는 시간은 (유클리드 거리의 제곱 * 자동차의 희귀도)이다. 자동차의 기준점은 각 칸의 중심이라고 하였을 때, 모든 자동차를 추측하는 시간이 가장 빠른 교차점 출력

알고리즘

식을 전개하여 삼분 탐색

풀이

식을 전개하면 위와 같이 된다.
결국, x에 대한 값과 y에 대한 값은 독립적이란 것을 알 수 있다. 그러므로 x에 대한 값과 y에 대한 값을 따로 구해주면 된다.

식은 거리의 제곱 * 희귀도이므로 기준이 되는 교차점이 중간으로 들어갈수록 값이 낮아짐을 직관적으로 알 수 있다. 이는 곧 볼록 함수라는 것이고 결국 삼분 탐색을 이용하면 된다.

코드

  • C++
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef pair<ll, int> pli;

int c[1000][1000], n, m;

// 교차점 좌표는 (i * 4, j * 4)
// 자동차의 중심 좌표는 (i * 4 + 2, j * 4 + 2)

ll f(int x){ // x에 대한 값
    x *= 4;
    ll result = 0;
    for (int i = 0; i < n; i++){
        ll d = (i * 4 + 2 - x) * (i * 4 + 2 - x);
        for (int j = 0; j < m; j++) result += d * c[i][j];
    }
    return result;
}

pli ternary_x(){ // x에 대한 삼분 탐색
    int st = 0, en = n + 1;
    while (en - st >= 3){
        int mid_1 = (st * 2 + en) / 3;
        ll res_1 = f(mid_1);
        int mid_2 = (st + en * 2) / 3;
        ll res_2 = f(mid_2);

        if (res_1 <= res_2) en = mid_2;
        else st = mid_1;
    }

    pli result = {f(st), st};
    for (int x = st + 1; x <= en; x++) result = min(result, {f(x), x});
    return result;
}

ll g(int y){ // y에 대한 값
    y *= 4;
    ll result = 0;
    for (int j = 0; j < m; j++){
        ll d = (j * 4 + 2 - y) * (j * 4 + 2 - y);
        for (int i = 0; i < n; i++) result += d * c[i][j];
    }
    return result;
}

pli ternary_y(){ // y에 대한 삼분 탐색
    int st = 0, en = m + 1;
    while (en - st >= 3){
        int mid_1 = (st * 2 + en) / 3;
        ll res_1 = g(mid_1);
        int mid_2 = (st + en * 2) / 3;
        ll res_2 = g(mid_2);

        if (res_1 <= res_2) en = mid_2;
        else st = mid_1;
    }

    pli result = {g(st), st};
    for (int y = st + 1; y <= en; y++) result = min(result, {g(y), y});
    return result;
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 0; i < n; i++) for (int j = 0; j < m; j++) cin >> c[i][j];

    /**
    선택한 교차점이 (x, y)면 sum(c[i][j] * ((x - xi) ** 2 + (y - yi) ** 2))
    이는 sum(c[i][j] * (x - xi) ** 2) + sum(c[i][j] * (y - yi) ** 2) 로 분리가 가능하다.
    결국 x와 y에 대한 값은 독립적이므로 x와 y에 대한 최적 값은 따로 삼분 탐색으로 구하자.
    **/

    auto [d1, x] = ternary_x();
    auto [d2, y] = ternary_y();
    cout << d1 + d2 << '\n' << x << ' ' << y;
}
  • Python (PyPy3)
import sys; input = sys.stdin.readline

# 교차점 좌표는 (i * 4, j * 4)
# 자동차의 중심 좌표는 (i * 4 + 2, j * 4 + 2)

def f(x): # x에 대한 값
    x *= 4
    result = 0
    for i in range(n):
        d = (i * 4 + 2 - x) ** 2
        for j in range(m):
            result += c[i][j] * d
    return result

def ternary_x(): # x에 대한 삼분 탐색
    st = 0; en = n + 1
    while en - st >= 3:
        mid_1 = (st * 2 + en) // 3
        res_1 = f(mid_1)
        mid_2 = (st + en * 2) // 3
        res_2 = f(mid_2)
        
        if res_1 <= res_2:
            en = mid_2
        else:
            st = mid_1

    return min((f(x), x) for x in range(st, en + 1))

def g(y): # y에 대한 값
    y *= 4
    result = 0
    for j in range(m):
        d = (j * 4 + 2 - y) ** 2
        for i in range(n):
            result += c[i][j] * d
    return result

def ternary_y(): # y에 대한 삼분 탐색
    st = 0; en = m + 1
    while en - st >= 3:
        mid_1 = (st * 2 + en) // 3
        res_1 = g(mid_1)
        mid_2 = (st + en * 2) // 3
        res_2 = g(mid_2)
        
        if res_1 <= res_2:
            en = mid_2
        else:
            st = mid_1

    return min((g(y), y) for y in range(st, en + 1))

n, m = map(int, input().split())
c = [list(map(int, input().split())) for _ in range(n)]

'''
선택한 교차점이 (x, y)면 sum(c[i][j] * ((x - xi) ** 2 + (y - yi) ** 2))
이는 sum(c[i][j] * (x - xi) ** 2) + sum(c[i][j] * (y - yi) ** 2) 로 분리가 가능하다.
결국 x와 y에 대한 값은 독립적이므로 x와 y에 대한 최적 값은 따로 삼분 탐색으로 구하자.
'''

d1, x = ternary_x()
d2, y = ternary_y()
print(d1 + d2)
print(x, y)
profile
GNU 16 statistics & computer science

0개의 댓글