侧边栏壁纸
博主头像
昂洋编程 博主等级

鸟随鸾凤飞腾远,人伴贤良品自高

  • 累计撰写 71 篇文章
  • 累计创建 79 个标签
  • 累计收到 0 条评论

目 录CONTENT

文章目录

过滤器&拦截器&全局异常处理中获取Post请求body

Administrator
2023-04-12 / 0 评论 / 0 点赞 / 75 阅读 / 0 字 / 正在检测是否收录...
温馨提示:
本文最后更新于2024-06-14,若内容或图片失效,请留言反馈。 部分素材来自网络,若不小心影响到您的利益,请联系我们删除。

我们都知道在SpringMVC中,如果前端是POST请求,并且参数再请求body中,也就是以流的形式传参的话,这个流是只能读取一次的,首次读取之后流就会关闭

通常情况下我们在请求处理的前后都会有一些自定义的处理,都会拿请求参数做记录,比如xss过滤,权限拦截、IP限制、全局异常处理等等,所以有这么多的地方会用到,那么我们必须自己封装请求body对象然后流转

过滤器

过滤器这里以xss过滤为例,xss过滤主要是防止xss攻击和防止SQL注入
关于这块的实现,可以参考之前的文章: Springboot配置防XSS攻击&Sql注入(含Post请求、跳过文件上传)

拦截器

拦截器的优先级是低于过滤器的,这里以IP限制拦截器为例

定义请求封装

package com.zone.kinglims.config;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;

/**
 * @author zhangmy
 * @date 2023/3/22 11:34
 * @description Ip限制请求包装
 */
@Slf4j
public class IpLimitHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /**
     * post请求体
     */
    private byte[] body;

    public byte[] getBody() {
        return body;
    }

    private String requestBodyString;
    private ByteArrayInputStream requestInputStream;

    public IpLimitHttpServletRequestWrapper() {
        super(null);
    }

    /**
     * 构造函数 - 获取post请求体
     * @param httpservletrequest
     * @throws IOException
     */
    public IpLimitHttpServletRequestWrapper(HttpServletRequest httpservletrequest) throws IOException {
        super(httpservletrequest);
        body = IOUtils.toByteArray(httpservletrequest.getInputStream());
        requestBodyString = new String(body, httpservletrequest.getCharacterEncoding());
        requestInputStream = new ByteArrayInputStream(body);
    }

    /**
     * 过滤请求体 json 格式的
     * @return
     */
    @Override
    public ServletInputStream getInputStream() {
        // 将请求体参数流转 -- 流读取一次就会消失,所以我们事先读取之后就存在byte数组里边方便流转
        return new ServletInputStream() {

            @Override
            public int read() {
                return requestInputStream.read();
            }

            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }
        };
    }

    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(requestInputStream, getCharacterEncoding()));
    }
}

上面请求体封装类中构造方法里读取了原Request中的流数据,然后存入变量方便在其他地方使用

编写拦截器

package com.zone.kinglims.config;

import com.alibaba.fastjson.JSONObject;
import com.zone.kinglims.common.security.entity.User;
import com.zone.kinglims.common.system.dict.entity.SysDictCache;
import com.zone.kinglims.common.system.dict.entity.SysDictDetail;
import com.zone.kinglims.utils.EmptyUtil;
import com.zone.kinglims.utils.IpCheckUtil;
import com.zone.kinglims.utils.system.SysBasic;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * @author zhangmy
 * @date 2023/03/22 9:17
 * @description 全局IP限制拦截器
 */
@Slf4j
@Component
public class IpLimitInterceptor extends HandlerInterceptorAdapter {

    @Autowired
    private SysDictCache dictCache;
    @Autowired
    private IpLimitConfigCache ipLimitConfigCache;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        log.info("全局IP限制拦截器start...");
        String ipLimitSwitch = dictCache.getDictValue("sys_conf", "ip_limit_switch");
        if ("OFF".equals(ipLimitSwitch)) {
            log.info("全局IP限制拦截器end,开关未打开...");
            return true;
        }
        String uri = request.getRequestURI();
        log.info("全局IP限制拦截器url: " + uri);
        // 不校验IP黑白名单的url -- 注意此url的配置尽量保持xss放行的url同步,xss过滤器用的requestWrapper是在此之前,如果没有包装请求,此处会拿不到请求体
        List<SysDictDetail> ipLimitConfigExcludeUrls = dictCache.getDetails("ip_limit_config_exclude_urls");
        if (EmptyUtil.isNotEmpty(ipLimitConfigExcludeUrls)) {
            Set<String> excludeUrls = ipLimitConfigExcludeUrls.stream().map(SysDictDetail::getValue).collect(Collectors.toSet());
            for (String excludeUrl : excludeUrls) {
                if (uri.contains(excludeUrl)) {
                    log.info("全局IP限制拦截器end,【" + uri + "】放行...");
                    return true;
                }
            }
        }
        // 使用封装号的servletRequestWrapper进行请求体中inputStream的读取和返还到HttpServletRequest中
        IpLimitHttpServletRequestWrapper requestWrapper = new IpLimitHttpServletRequestWrapper(request);
        boolean goOn;
        // 请求方式
        String requestMethod = request.getMethod();
        String USER_LOGIN_NAME = null;

        String token = request.getHeader("Authorization");
        ServletInputStream requestInputStream = requestWrapper.getInputStream();

        // token不为空
        if (EmptyUtil.isNotEmpty(token)) {
            User user = SysBasic.getUserByTokenRedis(token);
            if (EmptyUtil.isNotEmpty(user)) {
                goOn = checkUserIpConfig(user.getUSER_LOGIN_NAME(), requestWrapper);
            } else {
                goOn = checkIpConfig(requestMethod, requestWrapper, requestInputStream, USER_LOGIN_NAME);
            }
        // token为空,判断是否有传用户名
        } else {
            goOn = checkIpConfig(requestMethod, requestWrapper, requestInputStream, USER_LOGIN_NAME);
        }
        if (!goOn) {
            response.setHeader("Content-type", "application/json;charset=UTF-8");
            // 设置状态码
            response.setCharacterEncoding("UTF-8");
            response.setStatus(HttpStatus.OK.value());
            // 返回自定义的错误信息
            JSONObject respJson = new JSONObject();
            respJson.put("code", -1);
            respJson.put("status", -1);
            respJson.put("message", "当前网络禁止访问系统,请联系管理员");
            response.getWriter().write(respJson.toJSONString());
            return goOn;
        }
        return super.preHandle(requestWrapper, response, handler);
    }

    /**
     * 内部方法--读取请求参数判断ip是否限制
     * @param requestMethod
     * @param request
     * @param requestInputStream
     * @param USER_LOGIN_NAME
     * @return
     * @throws IOException
     */
    private Boolean checkIpConfig(String requestMethod, HttpServletRequest request, ServletInputStream requestInputStream, String USER_LOGIN_NAME) throws IOException {
        boolean checkResult = true;
        if ("GET".equals(requestMethod)) {
            USER_LOGIN_NAME = SysBasic.toTranStringByObject(request.getParameter("USER_LOGIN_NAME"));
        } else {
            String bodyStr = getBodyString(requestInputStream);
            if (EmptyUtil.isNotEmpty(bodyStr)) {
                if (bodyStr.startsWith("{") && bodyStr.endsWith("}")) {
                    JSONObject bodyObj = JSONObject.parseObject(bodyStr);
                    USER_LOGIN_NAME = bodyObj.getString("USER_LOGIN_NAME");
                }
            }
        }

        if (EmptyUtil.isNotEmpty(USER_LOGIN_NAME)) {
            checkResult = checkUserIpConfig(USER_LOGIN_NAME, request);
        }
        return checkResult;
    }

    /**
     * 获取请求体
     * @param requestInputStream
     * @return
     * @throws IOException
     */
    private String getBodyString(ServletInputStream requestInputStream) throws IOException {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader isr = new BufferedReader(new InputStreamReader(requestInputStream, StandardCharsets.UTF_8));) {
            String line = "";
            while ((line = isr.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            throw e;
        }
        return sb.toString();
    }

    /**
     * 校验用户ip登录限制是否通过
     * @param USER_LOGIN_NAME
     * @return
     */
    private Boolean checkUserIpConfig(String USER_LOGIN_NAME, HttpServletRequest request) {
        // 查询用户绑定的ip黑白名单
        Set<Map<String, Object>> userIpConfigList = ipLimitConfigCache.getUserIpLimitCache(USER_LOGIN_NAME);
        // 没有配置黑白名单直接通行
        if (EmptyUtil.isEmpty(userIpConfigList)) {
            return true;
        }
        // 当前请求的ip
        String requestIp = SysBasic.getIpAddr(request);
        log.info("当前请求IP地址为:[{}]", requestIp);
        // 黑名单ip
        Set<String> excludeIpSet = userIpConfigList.stream().
                filter(map -> "拒绝".equals(SysBasic.toTranStringByObject(map.get("STRATEGY"))))
                .map(map2 -> SysBasic.toTranStringByObject(map2.get("IP_ADDRESS")))
                .collect(Collectors.toSet());
        log.info("当前用户[{}]限制的黑名单ip为:{}", USER_LOGIN_NAME, excludeIpSet.toString());
        // 白名单ip
        Set<String> includeIpSet = userIpConfigList.stream().
                filter(map -> "允许".equals(SysBasic.toTranStringByObject(map.get("STRATEGY"))))
                .map(map2 -> SysBasic.toTranStringByObject(map2.get("IP_ADDRESS")))
                .collect(Collectors.toSet());
        log.info("当前用户[{}]允许的白名单ip为:{}", USER_LOGIN_NAME, includeIpSet.toString());
        // 黑名单不为空
        if (EmptyUtil.isNotEmpty(excludeIpSet)) {
            // 在黑名单中 直接拒绝
            if (IpCheckUtil.isInclude(requestIp, IpCheckUtil.getAvailableIpList(excludeIpSet))) {
                return false;
                // 不在黑名单中
            } else {
                // 判断白名单是否为空
                if (EmptyUtil.isEmpty(includeIpSet)) {
                    return true;
                    // 黑名单不为空,并且在黑名单之中,白名单也不为空,那么必须在白名单中才能通过
                } else {
                    if (IpCheckUtil.isInclude(requestIp, IpCheckUtil.getAvailableIpList(includeIpSet))) {
                        return true;
                    } else {
                        return false;
                    }
                }
            }
            // 黑名单为空,这里一定有白名单
        } else {
            // 不在白名单中 直接拒绝
            if (!IpCheckUtil.isInclude(requestIp, IpCheckUtil.getAvailableIpList(includeIpSet))) {
                return false;
                // 在白名单中,通过
            } else {
                return true;
            }
        }
    }
}

这里边有一些自己的业务逻辑,这不重要,注意IpLimitHttpServletRequestWrapper requestWrapper = new IpLimitHttpServletRequestWrapper(request); 这一句代码,是将原来的request封装成新的request继续流转,这里有个技巧,可以通过断点的方式查看到原request是什么class,即可知道之前是否有其他请求体封装。

注册拦截器

package com.zone.kinglims.common.security.config;

import com.zone.kinglims.common.security.handler.CorsHandlerInterceptor;
import com.zone.kinglims.config.IpLimitInterceptor;
import com.zone.kinglims.config.PermissionCheckInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    @Autowired
    private CorsHandlerInterceptor interceptor;
    @Autowired
    private PermissionCheckInterceptor permissionCheckInterceptor;
    @Autowired
    private IpLimitInterceptor ipLimitInterceptor;

    @Override
    public void addCorsMappings(CorsRegistry registry) {
        registry.addMapping("/**")
                .allowedOrigins("*")
                .allowedMethods("POST", "GET", "PUT", "OPTIONS", "DELETE")
                .maxAge(3600)
                .allowCredentials(true);
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // 跨域拦截器
        registry.addInterceptor(interceptor).addPathPatterns("/**");
        // 权限校验拦截器 -- excludePathPatterns标识不拦截的url
        registry.addInterceptor(permissionCheckInterceptor)
                .excludePathPatterns("/login/sampleLaboratory");
        // IP地址限制拦截器
        registry.addInterceptor(ipLimitInterceptor);
    }
}

可以看到这里有多个拦截器,但是其他拦截如果不读取request的请求体的,不需要封装原request

全局异常处理

package com.zone.kinglims.common.security.handler;

import com.alibaba.fastjson.JSONObject;
import com.zone.kinglims.common.security.entity.User;
import com.zone.kinglims.common.system.jdbc.handler.query.QueryTables;
import com.zone.kinglims.config.IpLimitHttpServletRequestWrapper;
import com.zone.kinglims.config.XssHttpServletRequestWrapper;
import com.zone.kinglims.utils.EmptyUtil;
import com.zone.kinglims.utils.system.SysBasic;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * @author zhangmy
 * @date 2023/3/23 16:19
 * @description 全局异常处理
 * 其中 @RestControllerAdvice 可以指定包,比如@RestControllerAdvice(basePackages = {"com.hugo.system.controller", "com.hugo.auth.controller"})
 */
@RestControllerAdvice
public class GlobalExceptionHandler {

    private final QueryTables queryTables;

    public GlobalExceptionHandler(QueryTables queryTables) {
        this.queryTables = queryTables;
    }

    @ExceptionHandler(value = { AccessDeniedException.class })
    public JSONObject exception(ServletRequest servletRequest, AccessDeniedException e) throws IOException {
        // 保存异常日志
        this.saveExceptionLogs(servletRequest, e);
        JSONObject result = new JSONObject();
        result.put("code", -1);
        result.put("status", -1);
        result.put("message", " 权限不足,无法访问该资源.");
        return result;
    }

    /**
     * 处理通用异常
     * @param e
     * @return
     */
    @ExceptionHandler(value = { Exception.class })
    public JSONObject exception(ServletRequest servletRequest, Exception e) {
        // 保存异常日志
        this.saveExceptionLogs(servletRequest, e);
        JSONObject result = new JSONObject();
        result.put("code", -1);
        result.put("status", -1);
        result.put("message", e.getMessage());
        return result;
    }

    /**
     * 内部方法 -- 保存异常日志
     * @param request
     * @param e
     */
    private void saveExceptionLogs(ServletRequest request, Exception e) {
        // 打印异常日志
        e.printStackTrace();
        // 异常日志对象map
        Map<String, Object> errorLogMap = new HashMap<>();
        String token;
        if (request instanceof XssHttpServletRequestWrapper) {
            XssHttpServletRequestWrapper requestWrapper = (XssHttpServletRequestWrapper) request;
            // 获取token
            token = requestWrapper.getHeader("Authorization");
            // 请求uri
            errorLogMap.put("REQUEST_URI", requestWrapper.getRequestURI());
            // 请求参数
            String requestBodyStr = new String(requestWrapper.getBody());
            errorLogMap.put("REQUEST_PARAM", requestBodyStr);
        } else if (request instanceof IpLimitHttpServletRequestWrapper) {
            IpLimitHttpServletRequestWrapper requestWrapper = (IpLimitHttpServletRequestWrapper) request;
            // 获取token
            token = requestWrapper.getHeader("Authorization");
            // 请求uri
            errorLogMap.put("REQUEST_URI", requestWrapper.getRequestURI());
            // 请求参数
            String requestBodyStr = new String(requestWrapper.getBody());
            errorLogMap.put("REQUEST_PARAM", requestBodyStr);
        } else {
            HttpServletRequestWrapper requestWrapper = (HttpServletRequestWrapper) request;
            // 获取token
            token = requestWrapper.getHeader("Authorization");
            // 请求uri
            errorLogMap.put("REQUEST_URI", requestWrapper.getRequestURI());
            errorLogMap.put("REQUEST_PARAM", "获取请求参数失败,request类型为:" + request.getClass().getName());
        }
        if (EmptyUtil.isNotEmpty(token)) {
            User user = SysBasic.getUserByTokenRedis(token);
            // 用户名
            String username = EmptyUtil.isNotEmpty(user) ? user.getNAME() : "";
            errorLogMap.put("USERNAME", username);
            // 实验室
            String loginCompany = EmptyUtil.isNotEmpty(user) ? user.getLOGINCOMPANY() : "";
            errorLogMap.put("LABORATORY", loginCompany);
        }
        // 异常名称
        errorLogMap.put("ERROR_NAME", e.getClass().getName());
        // 异常详情
        errorLogMap.put("ERROR_DETAIL", stackTraceToString(e));
        // 创建时间
        errorLogMap.put("CREATE_TIME", SysBasic.getNowTime());
        // 主键
        errorLogMap.put("ID", SysBasic.getUUID());
        // 记录日志
        queryTables.insertOne("error_log", errorLogMap);
    }

    /**
     * 转换异常信息为字符串
     * @param e
     * @return
     */
    private String stackTraceToString(Exception e) {
        StringBuffer stringBuffer = new StringBuffer();
        for (StackTraceElement stet : e.getStackTrace()) {
            stringBuffer.append(stet + "\n </br>");
        }
        String message = e.getClass().getName() + ":" + e.getMessage() + "\n </br>" + stringBuffer.toString();
        return message;
    }
}

注意上面的saveExceptionLogs方法,此方法里是根据原request的class类型从而获取requestBodu的参数,最后将异常记录到日志表中。因为我这里最前边是xssFilter的封装,所以先判断XssHttpServletRequestWrapper,然后再判断的IpLimitHttpServletRequestWrapper,对于其他的位置wrapper,这里暂不记录请求参数,因为有其他封装的这里很有可能是拿不到请求body的,同样的可以和上面一样打断点查看员request的class类型,这里我只是记录的原request的class类型

如果前面没有自定封装requestWrapper,想要在全局异常处理这里拿到请求body参数的话,也可以用spring自带的ContentCachingRequestWrapper,只需要简单的配置即可使用,参考: https://blog.csdn.net/u012946310/article/details/115367699

0

评论区