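In Spring MVC, when the front end sends a POST request with the parameters in the request body (that is, passed as a stream), that stream can be read only once. After the first read its contents are consumed and cannot be read again.
Requests usually go through custom processing before and after the handler, such as XSS filtering, permission interception, IP restriction and global exception handling, and much of it needs to inspect or record the request parameters. Because the body is needed in so many places, we have to cache it in our own request wrapper and pass that wrapper along.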
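A minimal, servlet-free illustration of the problem. It assumes a `ByteArrayInputStream` as a stand-in for `request.getInputStream()`, and `ReadOnceDemo` is a hypothetical class name. The code never calls `reset()`, mirroring how the body stream is consumed front to back: the second read finds nothing left, while a fresh stream over cached bytes can be read as often as needed.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

class ReadOnceDemo {

    // Reads whatever is left in the stream and decodes it as UTF-8,
    // the way a JSON body parser would consume the request body
    static String drain(InputStream in) {
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[256];
            int n;
            while ((n = in.read(buf)) != -1) {
                out.write(buf, 0, n);
            }
            return new String(out.toByteArray(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        // Stand-in for request.getInputStream(): one stream, one read position
        InputStream body = new ByteArrayInputStream(
                "{\"USER_LOGIN_NAME\":\"admin\"}".getBytes(StandardCharsets.UTF_8));

        String first = drain(body);  // e.g. the XSS filter reads the body
        String second = drain(body); // e.g. the controller reads again: nothing is left
        System.out.println("first:  " + first);
        System.out.println("second: '" + second + "'");

        // The approach used in this article: cache the bytes once and
        // give every reader a fresh stream over the cached copy
        byte[] cached = first.getBytes(StandardCharsets.UTF_8);
        System.out.println("again:  " + drain(new ByteArrayInputStream(cached)));
        System.out.println("again:  " + drain(new ByteArrayInputStream(cached)));
    }
}
```

Running `main` prints the JSON for `first`, an empty string for `second`, and the JSON again for both reads over the cached bytes.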
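Filter
For the filter, XSS filtering serves as the example; it mainly guards against XSS attacks and SQL injection.
For its implementation, see the earlier article: Spring Boot: configuring protection against XSS attacks & SQL injection (including POST requests, skipping file uploads)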
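Interceptor
Interceptors run after filters: servlet filters process the request before it reaches Spring's DispatcherServlet, while HandlerInterceptors are invoked by the DispatcherServlet around the handler. The example here is an IP restriction interceptor.
Define the request wrapper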
package com.zone.kinglims.config;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * @author zhangmy
 * @date 2023/3/22 11:34
 * @description Request wrapper for the IP restriction interceptor
 */
@Slf4j
public class IpLimitHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /**
     * Cached POST request body
     */
    private final byte[] body;

    /**
     * Charset used to decode the body
     */
    private final Charset charset;

    /**
     * Constructor: read the POST body from the original request once and cache it.
     * (A no-arg constructor calling super(null) is not possible: the wrapper rejects a null request.)
     * @param httpservletrequest the original request
     * @throws IOException if the body cannot be read
     */
    public IpLimitHttpServletRequestWrapper(HttpServletRequest httpservletrequest) throws IOException {
        super(httpservletrequest);
        body = IOUtils.toByteArray(httpservletrequest.getInputStream());
        // getCharacterEncoding() returns null when the Content-Type has no charset; fall back to UTF-8
        String encoding = httpservletrequest.getCharacterEncoding();
        charset = encoding != null ? Charset.forName(encoding) : StandardCharsets.UTF_8;
    }

    public byte[] getBody() {
        return body;
    }

    /**
     * Serve the cached (JSON) request body as a stream
     */
    @Override
    public ServletInputStream getInputStream() {
        // The original stream can be read only once, so the body was cached in a byte array.
        // Create a NEW stream over it on every call so each reader starts from the first byte.
        final ByteArrayInputStream requestInputStream = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() {
                return requestInputStream.read();
            }

            @Override
            public boolean isFinished() {
                return requestInputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported by this wrapper
            }
        };
    }

    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }
}
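The wrapper's constructor reads the stream data from the original request once and stores it in a byte array field, so the body can be reused elsewhere.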
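The core of the wrapper can be modeled without the Servlet API. In this sketch, `CachedBody` and its methods are hypothetical names, `readAllBytes()` needs Java 9+ (the article uses `IOUtils.toByteArray`), and the UTF-8 fallback for a missing charset is an assumption that suits JSON bodies. The key point is that every call to `newInputStream()` returns a new stream over the cached bytes, which is what lets several readers each start from the first byte.

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

// Hypothetical, servlet-free model of IpLimitHttpServletRequestWrapper's caching
class CachedBody {
    private final byte[] body;
    private final Charset charset;

    // encoding mirrors request.getCharacterEncoding(), which may be null
    CachedBody(InputStream original, String encoding) {
        try {
            this.body = original.readAllBytes(); // read the original stream exactly once
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        this.charset = encoding != null ? Charset.forName(encoding) : StandardCharsets.UTF_8;
    }

    // A new stream on every call, so each reader starts at the first byte
    ByteArrayInputStream newInputStream() {
        return new ByteArrayInputStream(body);
    }

    // Counterpart of getReader(), decoding with the resolved charset
    BufferedReader newReader() {
        return new BufferedReader(new InputStreamReader(newInputStream(), charset));
    }

    String asString() {
        return new String(body, charset);
    }
}
```

By contrast, a wrapper that creates a single `ByteArrayInputStream` in its constructor and returns it from every `getInputStream()` call would let the body be read only once through the wrapper, which defeats the purpose of caching it.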
Write the interceptor
package com.zone.kinglims.config;

import com.alibaba.fastjson.JSONObject;
import com.zone.kinglims.common.security.entity.User;
import com.zone.kinglims.common.system.dict.entity.SysDictCache;
import com.zone.kinglims.common.system.dict.entity.SysDictDetail;
import com.zone.kinglims.utils.EmptyUtil;
import com.zone.kinglims.utils.IpCheckUtil;
import com.zone.kinglims.utils.system.SysBasic;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * @author zhangmy
 * @date 2023/03/22 9:17
 * @description Global IP restriction interceptor
 * (HandlerInterceptorAdapter is deprecated as of Spring 5.3; implementing HandlerInterceptor directly replaces it)
 */
@Slf4j
@Component
public class IpLimitInterceptor extends HandlerInterceptorAdapter {

    @Autowired
    private SysDictCache dictCache;

    @Autowired
    private IpLimitConfigCache ipLimitConfigCache;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        log.info("Global IP restriction interceptor start...");
        String ipLimitSwitch = dictCache.getDictValue("sys_conf", "ip_limit_switch");
        if ("OFF".equals(ipLimitSwitch)) {
            log.info("Global IP restriction interceptor end, switch is off...");
            return true;
        }
        String uri = request.getRequestURI();
        log.info("Global IP restriction interceptor url: {}", uri);
        // URLs exempt from the IP blacklist/whitelist check. Keep this list in sync with the URLs the XSS
        // filter lets through: the XSS filter's requestWrapper is applied before this interceptor, and if
        // the request was not wrapped there, the request body cannot be reliably read here
        List<SysDictDetail> ipLimitConfigExcludeUrls = dictCache.getDetails("ip_limit_config_exclude_urls");
        if (EmptyUtil.isNotEmpty(ipLimitConfigExcludeUrls)) {
            Set<String> excludeUrls = ipLimitConfigExcludeUrls.stream().map(SysDictDetail::getValue).collect(Collectors.toSet());
            for (String excludeUrl : excludeUrls) {
                if (uri.contains(excludeUrl)) {
                    log.info("Global IP restriction interceptor end, [{}] let through...", uri);
                    return true;
                }
            }
        }
        // Wrap the request so the body's inputStream is read once and served from a cached copy
        IpLimitHttpServletRequestWrapper requestWrapper = new IpLimitHttpServletRequestWrapper(request);
        boolean goOn;
        // HTTP method
        String requestMethod = request.getMethod();
        String USER_LOGIN_NAME = null;
        String token = request.getHeader("Authorization");
        ServletInputStream requestInputStream = requestWrapper.getInputStream();
        if (EmptyUtil.isNotEmpty(token)) {
            // Token present: resolve the logged-in user from it
            User user = SysBasic.getUserByTokenRedis(token);
            if (EmptyUtil.isNotEmpty(user)) {
                goOn = checkUserIpConfig(user.getUSER_LOGIN_NAME(), requestWrapper);
            } else {
                goOn = checkIpConfig(requestMethod, requestWrapper, requestInputStream, USER_LOGIN_NAME);
            }
        } else {
            // No token: check whether a user name was passed in the request
            goOn = checkIpConfig(requestMethod, requestWrapper, requestInputStream, USER_LOGIN_NAME);
        }
        if (!goOn) {
            // JSON response; setContentType also sets the character encoding
            response.setContentType("application/json;charset=UTF-8");
            // The HTTP status stays 200; the rejection is reported in the JSON body
            response.setStatus(HttpStatus.OK.value());
            // Return a custom error message
            JSONObject respJson = new JSONObject();
            respJson.put("code", -1);
            respJson.put("status", -1);
            respJson.put("message", "Access to the system from the current network is forbidden. Please contact the administrator.");
            response.getWriter().write(respJson.toJSONString());
            return false;
        }
        // An interceptor cannot replace the request handed to the controller; this simply returns true
        return super.preHandle(requestWrapper, response, handler);
    }

    /**
     * Internal method: read the request parameters and decide whether the IP is restricted
     * @param requestMethod HTTP method
     * @param request the (wrapped) request
     * @param requestInputStream stream over the cached request body
     * @param USER_LOGIN_NAME user login name, resolved here from the parameters or the body
     * @return whether the request may proceed
     * @throws IOException if the body cannot be read
     */
    private Boolean checkIpConfig(String requestMethod, HttpServletRequest request, ServletInputStream requestInputStream, String USER_LOGIN_NAME) throws IOException {
        boolean checkResult = true;
        if ("GET".equals(requestMethod)) {
            USER_LOGIN_NAME = SysBasic.toTranStringByObject(request.getParameter("USER_LOGIN_NAME"));
        } else {
            String bodyStr = getBodyString(requestInputStream);
            // Only a JSON object body can carry USER_LOGIN_NAME
            if (EmptyUtil.isNotEmpty(bodyStr) && bodyStr.startsWith("{") && bodyStr.endsWith("}")) {
                JSONObject bodyObj = JSONObject.parseObject(bodyStr);
                USER_LOGIN_NAME = bodyObj.getString("USER_LOGIN_NAME");
            }
        }
        if (EmptyUtil.isNotEmpty(USER_LOGIN_NAME)) {
            checkResult = checkUserIpConfig(USER_LOGIN_NAME, request);
        }
        return checkResult;
    }

    /**
     * Read the request body into a string
     * @param requestInputStream stream over the cached request body
     * @return the body as a string
     * @throws IOException if the stream cannot be read
     */
    private String getBodyString(ServletInputStream requestInputStream) throws IOException {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(requestInputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        }
        return sb.toString();
    }

    /**
     * Check whether the user's IP login restriction passes
     * @param USER_LOGIN_NAME user login name
     * @param request current request, used to resolve the client IP
     * @return whether the request may proceed
     */
    private Boolean checkUserIpConfig(String USER_LOGIN_NAME, HttpServletRequest request) {
        // Look up the IP blacklist/whitelist bound to the user
        Set<Map<String, Object>> userIpConfigList = ipLimitConfigCache.getUserIpLimitCache(USER_LOGIN_NAME);
        // No blacklist/whitelist configured: let the request through
        if (EmptyUtil.isEmpty(userIpConfigList)) {
            return true;
        }
        // IP of the current request
        String requestIp = SysBasic.getIpAddr(request);
        log.info("Current request IP address: [{}]", requestIp);
        // STRATEGY values are stored in Chinese: "拒绝" = deny (blacklist), "允许" = allow (whitelist)
        Set<String> excludeIpSet = userIpConfigList.stream()
                .filter(map -> "拒绝".equals(SysBasic.toTranStringByObject(map.get("STRATEGY"))))
                .map(map -> SysBasic.toTranStringByObject(map.get("IP_ADDRESS")))
                .collect(Collectors.toSet());
        log.info("Blacklisted IPs for user [{}]: {}", USER_LOGIN_NAME, excludeIpSet);
        Set<String> includeIpSet = userIpConfigList.stream()
                .filter(map -> "允许".equals(SysBasic.toTranStringByObject(map.get("STRATEGY"))))
                .map(map -> SysBasic.toTranStringByObject(map.get("IP_ADDRESS")))
                .collect(Collectors.toSet());
        log.info("Whitelisted IPs for user [{}]: {}", USER_LOGIN_NAME, includeIpSet);
        if (EmptyUtil.isNotEmpty(excludeIpSet)) {
            // On the blacklist: reject immediately
            if (IpCheckUtil.isInclude(requestIp, IpCheckUtil.getAvailableIpList(excludeIpSet))) {
                return false;
            }
            // Not on the blacklist and no whitelist: allow
            if (EmptyUtil.isEmpty(includeIpSet)) {
                return true;
            }
            // Not on the blacklist, but a whitelist exists: the IP must be on the whitelist to pass
            return IpCheckUtil.isInclude(requestIp, IpCheckUtil.getAvailableIpList(includeIpSet));
        }
        // Blacklist empty, so a whitelist must exist here: pass only if the IP is on it
        return IpCheckUtil.isInclude(requestIp, IpCheckUtil.getAvailableIpList(includeIpSet));
    }
}
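Some of this is project-specific business logic and not important here. Pay attention to the line IpLimitHttpServletRequestWrapper requestWrapper = new IpLimitHttpServletRequestWrapper(request); which wraps the original request so that this interceptor can read the body from a cached copy. Keep in mind that an interceptor cannot replace the request object that the DispatcherServlet passes on to the controller (super.preHandle(...) just returns true), so the controller can still read the body only because the XSS filter upstream has already replaced the request with its own body-caching wrapper. A useful trick: set a breakpoint and inspect the class of the incoming request to see whether an earlier component has already wrapped the body.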
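The blacklist/whitelist rules in `checkUserIpConfig` boil down to a small decision function. This sketch assumes exact string matching in place of the project's `IpCheckUtil` (which may also match ranges), and `IpDecision.allowed` is a hypothetical name. It makes the precedence explicit: a blacklisted IP is always rejected, and once a whitelist exists only IPs on it pass. It treats "no rules at all" as allowed.

```java
import java.util.Set;

// Hypothetical condensation of checkUserIpConfig; exact string matching
// stands in for IpCheckUtil.isInclude
class IpDecision {

    static boolean allowed(String requestIp, Set<String> blacklist, Set<String> whitelist) {
        boolean hasBlacklist = blacklist != null && !blacklist.isEmpty();
        boolean hasWhitelist = whitelist != null && !whitelist.isEmpty();
        // A blacklisted IP is rejected regardless of the whitelist
        if (hasBlacklist && blacklist.contains(requestIp)) {
            return false;
        }
        // Once a whitelist exists, only IPs on it may pass
        if (hasWhitelist) {
            return whitelist.contains(requestIp);
        }
        // Not blacklisted and no whitelist configured: allow
        return true;
    }
}
```

Flattening the nested `if`/`else` this way keeps the same outcomes as the interceptor for each combination of lists, while making the order of the checks easier to read.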
Register the interceptors
package com.zone.kinglims.common.security.config;

import com.zone.kinglims.common.security.handler.CorsHandlerInterceptor;
import com.zone.kinglims.config.IpLimitInterceptor;
import com.zone.kinglims.config.PermissionCheckInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    @Autowired
    private CorsHandlerInterceptor interceptor;

    @Autowired
    private PermissionCheckInterceptor permissionCheckInterceptor;

    @Autowired
    private IpLimitInterceptor ipLimitInterceptor;

    @Override
    public void addCorsMappings(CorsRegistry registry) {
        registry.addMapping("/**")
                // On Spring 5.3+, "*" combined with allowCredentials(true) is rejected at runtime;
                // use allowedOriginPatterns("*") there instead
                .allowedOrigins("*")
                .allowedMethods("POST", "GET", "PUT", "OPTIONS", "DELETE")
                .maxAge(3600)
                .allowCredentials(true);
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // preHandle runs in registration order
        // CORS interceptor
        registry.addInterceptor(interceptor).addPathPatterns("/**");
        // Permission check interceptor: excludePathPatterns lists the URLs that are not intercepted
        registry.addInterceptor(permissionCheckInterceptor)
                .excludePathPatterns("/login/sampleLaboratory");
        // IP address restriction interceptor (applies to all paths when no pattern is given)
        registry.addInterceptor(ipLimitInterceptor);
    }
}
Several interceptors are registered here, but those that do not read the request body do not need to wrap the original request.
Global exception handling
package com.zone.kinglims.common.security.handler;

import com.alibaba.fastjson.JSONObject;
import com.zone.kinglims.common.security.entity.User;
import com.zone.kinglims.common.system.jdbc.handler.query.QueryTables;
import com.zone.kinglims.config.IpLimitHttpServletRequestWrapper;
import com.zone.kinglims.config.XssHttpServletRequestWrapper;
import com.zone.kinglims.utils.EmptyUtil;
import com.zone.kinglims.utils.system.SysBasic;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

/**
 * @author zhangmy
 * @date 2023/3/23 16:19
 * @description Global exception handling
 * Note: {@code @RestControllerAdvice} can be limited to specific packages, e.g.
 * {@code @RestControllerAdvice(basePackages = {"com.hugo.system.controller", "com.hugo.auth.controller"})}
 */
@RestControllerAdvice
public class GlobalExceptionHandler {

    private final QueryTables queryTables;

    public GlobalExceptionHandler(QueryTables queryTables) {
        this.queryTables = queryTables;
    }

    /**
     * Handle access-denied (insufficient permission) exceptions
     */
    @ExceptionHandler(value = { AccessDeniedException.class })
    public JSONObject exception(ServletRequest servletRequest, AccessDeniedException e) {
        // Save the exception log
        this.saveExceptionLogs(servletRequest, e);
        JSONObject result = new JSONObject();
        result.put("code", -1);
        result.put("status", -1);
        result.put("message", "Insufficient permissions to access this resource.");
        return result;
    }

    /**
     * Handle all other exceptions
     * @param e the exception
     * @return error response
     */
    @ExceptionHandler(value = { Exception.class })
    public JSONObject exception(ServletRequest servletRequest, Exception e) {
        // Save the exception log
        this.saveExceptionLogs(servletRequest, e);
        JSONObject result = new JSONObject();
        result.put("code", -1);
        result.put("status", -1);
        result.put("message", e.getMessage());
        return result;
    }

    /**
     * Internal method: save the exception log
     * @param request the request as seen by Spring MVC (the outermost wrapper)
     * @param e the exception
     */
    private void saveExceptionLogs(ServletRequest request, Exception e) {
        // Print the stack trace
        e.printStackTrace();
        // Map holding the exception log record
        Map<String, Object> errorLogMap = new HashMap<>();
        // Casting to HttpServletRequest (not HttpServletRequestWrapper) avoids a ClassCastException
        // when the request has not been wrapped at all
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // Token
        String token = httpRequest.getHeader("Authorization");
        // Request URI
        errorLogMap.put("REQUEST_URI", httpRequest.getRequestURI());
        // Request parameters: only the body-caching wrappers still hold the body
        if (request instanceof XssHttpServletRequestWrapper) {
            byte[] body = ((XssHttpServletRequestWrapper) request).getBody();
            errorLogMap.put("REQUEST_PARAM", new String(body, StandardCharsets.UTF_8));
        } else if (request instanceof IpLimitHttpServletRequestWrapper) {
            byte[] body = ((IpLimitHttpServletRequestWrapper) request).getBody();
            errorLogMap.put("REQUEST_PARAM", new String(body, StandardCharsets.UTF_8));
        } else {
            // Other request types: the body has probably been consumed already, so record only the class name
            errorLogMap.put("REQUEST_PARAM", "Failed to read request parameters, request type: " + request.getClass().getName());
        }
        if (EmptyUtil.isNotEmpty(token)) {
            User user = SysBasic.getUserByTokenRedis(token);
            // User name
            String username = EmptyUtil.isNotEmpty(user) ? user.getNAME() : "";
            errorLogMap.put("USERNAME", username);
            // Laboratory
            String loginCompany = EmptyUtil.isNotEmpty(user) ? user.getLOGINCOMPANY() : "";
            errorLogMap.put("LABORATORY", loginCompany);
        }
        // Exception class name
        errorLogMap.put("ERROR_NAME", e.getClass().getName());
        // Exception details
        errorLogMap.put("ERROR_DETAIL", stackTraceToString(e));
        // Creation time
        errorLogMap.put("CREATE_TIME", SysBasic.getNowTime());
        // Primary key
        errorLogMap.put("ID", SysBasic.getUUID());
        // Write the log record
        queryTables.insertOne("error_log", errorLogMap);
    }

    /**
     * Convert the exception and its stack trace to a string
     * @param e the exception
     * @return message followed by one stack frame per line
     */
    private String stackTraceToString(Exception e) {
        StringBuilder sb = new StringBuilder();
        sb.append(e.getClass().getName()).append(':').append(e.getMessage()).append("\n </br>");
        for (StackTraceElement element : e.getStackTrace()) {
            sb.append(element).append("\n </br>");
        }
        return sb.toString();
    }
}
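Note the saveExceptionLogs method above: it obtains the request body according to the class of the incoming request and finally records the exception in the log table. In my setup the outermost wrapper comes from xssFilter, so it checks XssHttpServletRequestWrapper first and IpLimitHttpServletRequestWrapper second. For any other, unknown wrapper the request parameters are not recorded for now, because with other wrappers the request body has most likely already been consumed. As before, you can set a breakpoint to check the class of the incoming request; here I only record that class name.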
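One way to avoid adding an `instanceof` branch for every new wrapper is to have each body-caching wrapper implement a shared interface. The sketch below is a hypothetical refactoring, not the project's code: `CachedBodyRequest`, `FakeXssWrapper` and `ExceptionLogParams` are illustrative names, `FakeXssWrapper` stands in for a real `HttpServletRequestWrapper` subclass, and `Object` stands in for `ServletRequest` so the example runs without the Servlet API.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical shared interface that each body-caching wrapper would implement
interface CachedBodyRequest {
    byte[] getBody();
}

// Stand-in for a real wrapper such as XssHttpServletRequestWrapper
class FakeXssWrapper implements CachedBodyRequest {
    private final byte[] body;

    FakeXssWrapper(String json) {
        this.body = json.getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public byte[] getBody() {
        return body;
    }
}

class ExceptionLogParams {

    // One instanceof check covers every wrapper that implements the interface
    static String requestParam(Object request) {
        if (request instanceof CachedBodyRequest) {
            return new String(((CachedBodyRequest) request).getBody(), StandardCharsets.UTF_8);
        }
        return "Failed to read request parameters, request type: " + request.getClass().getName();
    }
}
```

With this design, a new wrapper only has to implement `getBody()`; the exception handler does not change.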
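If no custom requestWrapper has wrapped the request earlier and you still want the request body in the global exception handler, you can use Spring's built-in ContentCachingRequestWrapper, which needs only a little configuration; see: https://blog.csdn.net/u012946310/article/details/115367699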
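One caveat: `ContentCachingRequestWrapper` caches the body only as it is being read, so `getContentAsByteArray()` stays empty until something (typically the `@RequestBody` conversion) has consumed the input stream. The simplified model below, `CachingInputStream` (a hypothetical class, not Spring's source), copies bytes aside as they are read to show that behavior. It also suggests why this wrapper alone would not let an interceptor read the body ahead of the controller: the underlying stream is still read only once.

```java
import java.io.ByteArrayOutputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

// Simplified model of ContentCachingRequestWrapper's behavior, not Spring's source:
// bytes are copied aside only as the consumer reads them (skip() is not cached here)
class CachingInputStream extends FilterInputStream {
    private final ByteArrayOutputStream cache = new ByteArrayOutputStream();

    CachingInputStream(InputStream in) {
        super(in);
    }

    @Override
    public int read() throws IOException {
        int b = super.read();
        if (b != -1) {
            cache.write(b);
        }
        return b;
    }

    // FilterInputStream.read(byte[]) delegates here, so both array reads are cached
    @Override
    public int read(byte[] buf, int off, int len) throws IOException {
        int n = super.read(buf, off, len);
        if (n > 0) {
            cache.write(buf, off, n);
        }
        return n;
    }

    // Like getContentAsByteArray(): only what has been consumed so far
    byte[] getContentAsByteArray() {
        return cache.toByteArray();
    }

    // Reads to the end, as a @RequestBody conversion would
    static void consumeAll(InputStream in) {
        try {
            byte[] buf = new byte[64];
            while (in.read(buf) != -1) {
                // discard the data: reading is what fills the cache
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

In practice this means that if an exception is thrown before the body has been read, the global exception handler will find nothing cached.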