얼마 전 GPT의 실수 비교 방식이 화제가 된 적이 있었다.
질문) "3.9와 3.11 중에 뭐가 더 커?" / 답변) "3.11이 더 큽니다."
수학 시간에 졸지 않은 사람들은 3.9가 3.11보다 크다고 생각하지만, GPT의 눈으로 보면 Python 3.9와 Python 3.11 중 후자를 더 크게 보는 학습 데이터가 많아 저렇게 생각할 수 있다. GPT의 세상에서 3.1은 3보다 크고, 마찬가지로 3.9는 3.2보다 크지만, 3.10은 3.2보다 큰 값으로 처리된다.
구체적으로, 소수점을 기준으로 왼쪽을 수로 읽은 값을 x, 오른쪽을 수로 읽은 값을 y라고 할 때 두 수의 비교가 다음과 같이 이루어진다:
x값이 더 작으면 더 작은 수이다.
x값이 같을 경우 y값이 더 작으면 더 작은 수이다.
소수점이 없는 경우는 같은 수의 소수점이 있는 경우보다 항상 작게 취급된다. (다시 말해, GPT에게 3은 3.0보다 작다.)
N개의 수들이 주어졌을 때, 이를 GPT의 기준에 따라 비내림차순으로 정렬해보자.
[문제 제약 조건]
[조건 1] N은 1 이상 1,000 이하이다.
[조건 2] 각 수는 실수 혹은 정수로 표현되고, 0 이상 100 이하이며, 소수점이 없거나 소수점 아래 최대 3자리까지 주어진다.
[서브 태스크별 제약 조건]
별도의 서브 태스크가 존재하지 않습니다.
첫 번째 줄에 N이 주어진다.
두 번째 줄부터 N개의 줄에 걸쳐, 각 수가 한 줄에 하나씩 주어진다.
01.23, 3.001과 같이 소수점을 기준으로 각 파트의 0이 아닌 수 이전에 leading zero가 나오는 경우는 주어지지 않는다.
추가로, 3.00, 3.000, 또는 00.1과 같이 소수점을 기준으로 한 파트에 두 개 이상의 0만 주어지는 입력은 없다.
첫 번째 줄부터 N개의 줄에 걸쳐, 입력으로 주어진 수들을 GPT의 기준으로 비내림차순으로 정렬한 순서대로 한 줄에 하나씩 출력한다.
5
1.2
1.11
2.9
4.2
3
1.2
1.11
2.9
3
4.2
import sys
'''
그냥 정렬 문제 아님?
'''
def main():
N, *lst = sys.stdin.read().split('\n')
lst = list(filter(None, lst))
lst.sort(key=lambda x: [int(x.split(".")[0]), len(x.split(".")[1]) if len(x.split(".")) > 1 else 0,
int(x.split(".")[1]) if len(x.split(".")) > 1 else -1])
print(*lst, sep="\n")
if __name__ == '__main__':
main()
split 연산 중복 제거
parts = x.split(".")
if len(parts) == 1:
int_part = int(parts[0])
key_val = [int_part, 0, -1]
else:
int_part = int(parts[0])
dec_part_len = len(parts[1])
dec_part_val = int(parts[1])
key_val = [int_part, dec_part_len, dec_part_val]
return key_val
이런 식으로 미리 분리해놓으면, split(".")를 여러 번 하지 않아도 되고, 가독성도 조금 더 좋아집니다.
import sys
def parse_gpt_number(num_str: str):
"""
'GPT 기준'으로 수를 비교하기 위한 튜플을 반환한다.
규칙 요약:
1) 정수부(int_part)가 작은 것이 먼저.
2) 정수부가 같다면, 소수점이 없는 수가 더 작다.
3) 둘 다 소수점이 있을 때는, 소수점 뒷부분을 '정수'로 본 크기 비교.
반환되는 튜플의 구조:
( int_part, decimal_length, decimal_value )
- decimal_length: 소수점이 없으면 0, 있으면 len(소수점 뒷부분)
- decimal_value: 소수점이 없으면 -1, 그렇지 않으면 int(소수점 뒷부분)
"""
parts = num_str.split(".")
int_part = int(parts[0]) # 소수점 앞부분
# 소수점이 없는 경우
if len(parts) == 1:
return (int_part, 0, -1)
else:
# 소수점이 있는 경우
decimal_part = parts[1]
decimal_length = len(decimal_part)
decimal_value = int(decimal_part)
return (int_part, decimal_length, decimal_value)
def main():
data = sys.stdin.read().splitlines()
# 첫 줄은 수의 개수 N
N = int(data[0])
# 그 이후 줄들은 실제 수들
numbers = data[1:]
# 혹시나 공백 라인이 섞여있다면 제거 (문제 조건상 없어야 하지만 방어적 처리)
numbers = [num for num in numbers if num.strip() != ""]
# GPT 기준에 맞춰 정렬
numbers.sort(key=parse_gpt_number)
# 결과 출력
print("\n".join(numbers))
if __name__ == '__main__':
main()