<!-- Registers the servlet filter named "XSS" and maps it to every request path (/*). -->
<filter>
<filter-name>XSS</filter-name>
<!-- Location (fully qualified class name) of the filter class -->
<filter-class>com.example.common.XSSFilter</filter-class>
</filter>
<filter-mapping>
<filter-name>XSS</filter-name>
<url-pattern>/*</url-pattern>
</filter-mapping>
/**
 * Servlet filter that hands every HTTP request to the rest of the chain wrapped in an
 * {@link XSSFilterWrapper}.
 */
public class XSSFilter implements Filter {

    /** Configuration received in {@link #init(FilterConfig)}; cleared again in {@link #destroy()}. */
    private FilterConfig filterConfig;

    /**
     * Stores the filter configuration supplied by the container.
     *
     * @param filterConfig configuration of this filter
     * @throws ServletException declared by the {@code Filter} contract; not thrown here
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        this.filterConfig = filterConfig;
    }

    /** Drops the reference to the filter configuration. */
    @Override
    public void destroy() {
        this.filterConfig = null;
    }

    /**
     * Continues the chain with the request wrapped in an {@link XSSFilterWrapper}.
     *
     * <p>A request that is not an {@link HttpServletRequest} cannot be wrapped, so it is passed
     * on unchanged instead of failing with a {@code ClassCastException}.
     *
     * @param request  incoming request
     * @param response outgoing response, passed through untouched
     * @param chain    remaining filter chain
     * @throws IOException      if the rest of the chain fails with an I/O error
     * @throws ServletException if the rest of the chain fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        if (request instanceof HttpServletRequest) {
            chain.doFilter(new XSSFilterWrapper((HttpServletRequest) request), response);
        } else {
            chain.doFilter(request, response);
        }
    }
}
/**
 * Request wrapper that replaces the characters {@code < > ( )} with HTML entities in the
 * query string, the request parameters and the request body before downstream code sees them.
 *
 * <p>The body is buffered and escaped the first time {@link #getInputStream()} or
 * {@link #getReader()} is called, not when the wrapper is created.
 */
public class XSSFilterWrapper extends HttpServletRequestWrapper {

    /** Escaped copy of the request body; {@code null} until the body is first requested. */
    private byte[] requestBody;

    /**
     * Wraps the given request without reading anything from it.
     *
     * <p>The body is deliberately not consumed here: per the Servlet specification, form
     * parameters carried in the body are only parsed by the container while the input stream
     * is still unread, so an eager read would hide them from {@code getParameter*}.
     *
     * @param request request to wrap
     */
    public XSSFilterWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns the HTML entity for a character that must be escaped.
     *
     * @param ch character (or byte value) to look up
     * @return the entity text, or {@code null} when {@code ch} is passed through unchanged
     */
    private static String entityFor(int ch) {
        switch (ch) {
            case '<':
                return "&lt;";
            case '>':
                return "&gt;";
            case '(':
                return "&#40;";
            case ')':
                return "&#41;";
            default:
                return null;
        }
    }

    /**
     * Escapes {@code < > ( )} in raw body bytes.
     *
     * <p>The replacement works on bytes rather than on a decoded string, so the result does not
     * depend on the platform default charset and bytes outside the ASCII range are copied
     * through untouched.
     *
     * @param data raw request body
     * @return a new array holding the escaped body
     */
    private byte[] replaceXSS(byte[] data) {
        java.io.ByteArrayOutputStream escaped = new java.io.ByteArrayOutputStream(data.length + 16);
        for (byte b : data) {
            String entity = entityFor(b);
            if (entity == null) {
                escaped.write(b);
            } else {
                // Entities are pure ASCII, so each char maps to exactly one byte.
                for (int i = 0; i < entity.length(); i++) {
                    escaped.write(entity.charAt(i));
                }
            }
        }
        return escaped.toByteArray();
    }

    /**
     * Escapes {@code < > ( )} in a single value.
     *
     * @param value text to escape; may be {@code null}
     * @return the escaped text, or {@code null} when {@code value} is {@code null}
     */
    private String replaceXSS(String value) {
        if (value == null) {
            return null;
        }
        // Literal replace(): no regex is needed, and none of the entities contain a character
        // that a later replacement would touch again.
        return value.replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace("(", "&#40;")
                .replace(")", "&#41;");
    }

    /** Returns an escaped copy of {@code values}, or {@code null} when {@code values} is {@code null}. */
    private String[] escapeAll(String[] values) {
        if (values == null) {
            return null;
        }
        String[] escaped = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            escaped[i] = replaceXSS(values[i]);
        }
        return escaped;
    }

    /**
     * Returns a stream over the escaped request body, buffering the body on first use.
     *
     * @return a fresh stream positioned at the start of the escaped body
     * @throws IOException if the underlying request body cannot be read
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (this.requestBody == null) {
            // A read failure propagates to the caller instead of being swallowed.
            this.requestBody = replaceXSS(IOUtils.toByteArray(super.getInputStream()));
        }
        final ByteArrayInputStream buffered = new ByteArrayInputStream(this.requestBody);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return buffered.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return buffered.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return buffered.available();
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // The body is fully buffered in memory; no listener is ever notified.
            }

            @Override
            public boolean isReady() {
                // In-memory data can always be read without blocking.
                return true;
            }

            @Override
            public boolean isFinished() {
                return buffered.available() == 0;
            }
        };
    }

    /** Returns the query string of the wrapped request with {@code < > ( )} escaped. */
    @Override
    public String getQueryString() {
        return replaceXSS(super.getQueryString());
    }

    /** Returns the named parameter of the wrapped request with {@code < > ( )} escaped. */
    @Override
    public String getParameter(String name) {
        return replaceXSS(super.getParameter(name));
    }

    /**
     * Returns an escaped copy of the parameter map.
     *
     * <p>The arrays handed out by the wrapped request are not modified; the escaped values
     * live in a separate, unmodifiable map.
     *
     * @return escaped parameters, or {@code null} when the wrapped request returns {@code null}
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> source = super.getParameterMap();
        if (source == null) {
            return null;
        }
        Map<String, String[]> escaped = new java.util.LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : source.entrySet()) {
            escaped.put(entry.getKey(), escapeAll(entry.getValue()));
        }
        return java.util.Collections.unmodifiableMap(escaped);
    }

    /**
     * Returns an escaped copy of all values of the named parameter.
     *
     * @return escaped values, or {@code null} when the parameter does not exist
     */
    @Override
    public String[] getParameterValues(String name) {
        return escapeAll(super.getParameterValues(name));
    }

    /**
     * Returns a reader over the escaped body.
     *
     * <p>The body is decoded with the request's declared character encoding, falling back to
     * UTF-8 when the request declares none.
     *
     * @throws IOException if the body cannot be read or the declared encoding is unsupported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            encoding = "UTF-8";
        }
        return new BufferedReader(new InputStreamReader(this.getInputStream(), encoding));
    }
}
/**
 * Request wrapper that replaces the characters {@code < > ( )} with HTML entities in the
 * query string, the request parameters and the request body.
 *
 * <p>The body is read lazily: it is buffered and escaped on the first call to
 * {@link #getInputStream()} or {@link #getReader()}.
 */
public class XSSFilterWrapper extends HttpServletRequestWrapper {

    /** Escaped copy of the request body; only valid once {@link #hasReadRequestBody} is set. */
    private byte[] requestBody;

    /** Whether the body has already been buffered and escaped. */
    private boolean hasReadRequestBody = false;

    /**
     * Wraps the given request without reading anything from it.
     *
     * @param request request to wrap
     */
    public XSSFilterWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Buffers and escapes the body of the wrapped request; later calls do nothing.
     *
     * @throws IOException if the body cannot be read. The failure is reported to the caller
     *                     because the underlying stream may already be partly consumed, so
     *                     falling back to it would expose truncated, unescaped data.
     */
    private void readRequestBody() throws IOException {
        if (hasReadRequestBody) {
            return;
        }
        InputStream inputStream = super.getInputStream();
        this.requestBody = replaceXSS(IOUtils.toByteArray(inputStream));
        hasReadRequestBody = true;
    }

    /**
     * Returns the HTML entity for a character that must be escaped.
     *
     * @param ch character (or byte value) to look up
     * @return the entity text, or {@code null} when {@code ch} is passed through unchanged
     */
    private static String entityFor(int ch) {
        switch (ch) {
            case '<':
                return "&lt;";
            case '>':
                return "&gt;";
            case '(':
                return "&#40;";
            case ')':
                return "&#41;";
            default:
                return null;
        }
    }

    /**
     * Escapes {@code < > ( )} in raw body bytes.
     *
     * <p>Working on bytes keeps the result independent of the platform default charset and
     * copies every byte outside the ASCII range through unchanged.
     *
     * @param data raw request body
     * @return a new array holding the escaped body
     */
    private byte[] replaceXSS(byte[] data) {
        java.io.ByteArrayOutputStream sink = new java.io.ByteArrayOutputStream(data.length + 16);
        for (byte current : data) {
            String entity = entityFor(current);
            if (entity == null) {
                sink.write(current);
                continue;
            }
            // Entities are pure ASCII, so each char maps to exactly one byte.
            for (int i = 0; i < entity.length(); i++) {
                sink.write(entity.charAt(i));
            }
        }
        return sink.toByteArray();
    }

    /**
     * Escapes {@code < > ( )} in a single value.
     *
     * @param value text to escape; may be {@code null}
     * @return the escaped text, or {@code null} when {@code value} is {@code null}
     */
    private String replaceXSS(String value) {
        if (value == null) {
            return null;
        }
        // Literal replace(): no regex is needed, and the inserted entities contain none of
        // the characters replaced afterwards.
        return value.replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace("(", "&#40;")
                .replace(")", "&#41;");
    }

    /** Returns an escaped copy of {@code values}, or {@code null} when {@code values} is {@code null}. */
    private String[] escapeAll(String[] values) {
        if (values == null) {
            return null;
        }
        String[] escaped = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            escaped[i] = replaceXSS(values[i]);
        }
        return escaped;
    }

    /**
     * Returns a stream over the escaped request body.
     *
     * @return a fresh stream positioned at the start of the escaped body
     * @throws IOException if the underlying request body cannot be read
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        readRequestBody();
        final ByteArrayInputStream buffered = new ByteArrayInputStream(this.requestBody);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return buffered.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return buffered.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return buffered.available();
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // The body is fully buffered in memory; no listener is ever notified.
            }

            @Override
            public boolean isReady() {
                // In-memory data can always be read without blocking.
                return true;
            }

            @Override
            public boolean isFinished() {
                return buffered.available() == 0;
            }
        };
    }

    /** Returns the query string of the wrapped request with {@code < > ( )} escaped. */
    @Override
    public String getQueryString() {
        return replaceXSS(super.getQueryString());
    }

    /** Returns the named parameter of the wrapped request with {@code < > ( )} escaped. */
    @Override
    public String getParameter(String name) {
        return replaceXSS(super.getParameter(name));
    }

    /**
     * Returns an escaped copy of the parameter map.
     *
     * <p>The arrays handed out by the wrapped request are left untouched; the escaped values
     * are stored in a separate, unmodifiable map.
     *
     * @return escaped parameters, or {@code null} when the wrapped request returns {@code null}
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> source = super.getParameterMap();
        if (source == null) {
            return null;
        }
        Map<String, String[]> escaped = new java.util.LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : source.entrySet()) {
            escaped.put(entry.getKey(), escapeAll(entry.getValue()));
        }
        return java.util.Collections.unmodifiableMap(escaped);
    }

    /**
     * Returns an escaped copy of all values of the named parameter.
     *
     * @return escaped values, or {@code null} when the parameter does not exist
     */
    @Override
    public String[] getParameterValues(String name) {
        return escapeAll(super.getParameterValues(name));
    }

    /**
     * Returns a reader over the escaped body.
     *
     * <p>The body is decoded with the request's declared character encoding, falling back to
     * UTF-8 when the request declares none.
     *
     * @throws IOException if the body cannot be read or the declared encoding is unsupported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            encoding = "UTF-8";
        }
        return new BufferedReader(new InputStreamReader(this.getInputStream(), encoding));
    }
}
// Pre-compiled patterns for markup and script fragments that are treated as suspicious.
// NOTE(review): the array is never reassigned in the visible code; consider declaring it final.
private static Pattern[] patterns = new Pattern[] {
// A <script> element together with its content (no DOTALL, so '.' stops at line breaks).
Pattern.compile("<script>(.*?)</script>", Pattern.CASE_INSENSITIVE),
// src attribute with a single-quoted value; CR/LF is tolerated around '='.
Pattern.compile("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'",
Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
// src attribute with a double-quoted value; CR/LF is tolerated around '='.
Pattern.compile("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"",
Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
// A closing </script> tag on its own.
Pattern.compile("</script>", Pattern.CASE_INSENSITIVE),
// An opening <script ...> tag, with or without attributes.
Pattern.compile("<script(.*?)>", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
// eval(...) calls.
Pattern.compile("eval\\((.*?)\\)", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
// expression(...) calls.
Pattern.compile("expression\\((.*?)\\)", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
// javascript: scheme prefix.
Pattern.compile("javascript:", Pattern.CASE_INSENSITIVE),
// vbscript: scheme prefix.
Pattern.compile("vbscript:", Pattern.CASE_INSENSITIVE),
// "onload", any text, then '='.
Pattern.compile("onload(.*?)=", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL)
};
/**
 * Neutralizes markup in a single value: NUL characters are removed and the characters
 * {@code <}, {@code >} and {@code '} are replaced with HTML entities.
 *
 * @param value text to sanitize; may be {@code null}
 * @return the sanitized text, or {@code null} when {@code value} is {@code null}
 */
private String stripXSS(String value) {
    if (value == null) {
        return null;
    }
    // Literal replace(): the NUL character needs no regex.
    String cleaned = value.replace("\0", "");
    for (Pattern scriptPattern : patterns) {
        // find() rather than matches(): a payload embedded anywhere in the value must be
        // detected, whereas matches() only succeeds when the whole value is the payload.
        if (scriptPattern.matcher(cleaned).find()) {
            cleaned = cleaned.replace("<", "&lt;").replace(">", "&gt;");
            break; // one escaping pass already covers every remaining pattern
        }
    }
    // NOTE(review): this escaping runs for every value, so the pattern check above cannot
    // change the result; confirm whether matched fragments were meant to be handled differently.
    return cleaned.replace("<", "&lt;").replace(">", "&gt;").replace("'", "&#39;");
}
참고 자료
1. https://velog.io/@ch200203/%ED%94%84%EB%A1%9C%EC%A0%9D%ED%8A%B8%EC%97%90-XSS-%EC%A0%81%EC%9A%A9%ED%95%98%EA%B8%B0
2. https://hello-backend.tistory.com/168