KMP를 사용하지만 KMP의 일부만 사용한다고 해야하나...
문자열의 패턴을 찾아 이 패턴이 반복되는 지를 확인하는 점에서 KMP임을 확인할 수 있다.
그러나 평범한 KMP문제는 아니다.
KMP문제는 주어진 패턴이 주어진 string내에 있는지를 확인하는, (또는 index를 반환하는) O(n)의 알고리즘이지만
해당 "문자열제곱"문제는 패턴이 따로 주어지지 않거니와 하나의 패턴이 문자열에 한번만 등장하는 것이 아닌 최대 몇번 등장할 수 있는지를 세어야 한다.
얼핏보면 KMP를 패턴 가지수마다 적용해야할 것 같은데
(이경우 n^2가 된다) n이 백만이하이기 때문에 반드시 O(N)내에 해결되어야 한다.
+) string 길이 / 패턴 길이 임을 어느정도 생각하고 있었지만 이부분에서 나도 해설을 참고했다.
KMP알고리즘을 구현하기 위해서는 pattern과 string을 비교하기 전에 먼저 패턴에 대해 table을 작성하는 과정이 필요하다
private static int [] findPattern(char[] list) {
int n = list.length;
int [] table = new int[n];
int idx=0;
for(int i=1; i<n; i++){
while(idx>0&&list[i] != list[idx]){
idx = table[idx-1];
}
if(list[i]==list[idx]){
idx+=1;
table[i]=idx;
}
}
return table;
}
위의 코드가 바로 그에 해당하는 부분이다.
요약하자면 pattern의 접두사==접미사가 일치하는 최대의 길이를 table[i]에 담은 것으로, table[i]에 들은 크기의 접두사, 접미사가 똑같은 형태인 패턴인 것을 확인할 수 있다.
ex) abc 'd' abc 의 경우, 접두사==접미사==abc이므로 table[i]=3
위 문제는 이 패턴 찾는 부분을 이용하여 풀 수 있다.
table[length-1]에는 주어진 string의 최대 접두사 길이를 알 수 있다.
string(0)부터 string(length-1) - table[length-1] 에는 우리가 원하던 string의 최대길이 패턴이 나타나는 것이다.
(패턴이 시작할때, table에서는 0, 0, ..으로 나타남. 즉 이부분의 길이를 구하는 것)
그렇다면 패턴의 반복횟수는 string length / (string length - table[length-1]) 로 구할 수 있게 된다.
물론 "", a^0이 들어오는 경우, aaaaaab 같이 길이가 홀수, table은 012345(0)이되는 경우에는 i/i=0이므로 예외처리를 해야한다.
전체 코드 :
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Scanner;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String temp = "";
while (!(temp = br.readLine()).equals(".")) {
char[] list = temp.toCharArray();
int i = list.length;
if (i == 0) System.out.println(0);
else {
int p[] = findPattern(list);
// System.out.println(temp);
// for (int j = 0; j < i; j++) {
// System.out.print(p[j] + " ");
// }
// System.out.println();
if(i%(i - p[i - 1])!=0){
//aaaaaab : 7%(7-6) = 1
System.out.println(1);
}else {
System.out.println(i / (i - p[i - 1]));
}
}
}
}
private static int [] findPattern(char[] list) {
int n = list.length;
int [] table = new int[n];
int idx=0;
for(int i=1; i<n; i++){
while(idx>0&&list[i] != list[idx]){
idx = table[idx-1];
}
if(list[i]==list[idx]){
idx+=1;
table[i]=idx;
}
}
return table;
}
}