本文将深度解析一个基于 Spring Boot 的 WebSocket 功能实现,该实现具备用户认证、会话管理、事件驱动等特性,结构清晰且易于扩展。
核心组件概览
该 WebSocket 功能主要由以下几个核心组件构成:
**WebSocketConfig**
: WebSocket 的主配置类,负责注册处理器和拦截器。**WebSocketAuthInterceptor**
: 握手阶段的认证拦截器,用于验证用户身份。**WebSocketEventHandler**
: 核心事件处理器,处理连接建立、消息接收和连接关闭等生命周期事件。**WebSocketSessionManager**
: 会话管理中心,用于跟踪和管理所有活跃的 WebSocket 连接。**WebSocketEvent**
: 自定义事件模型,用于在 WebSocket 的不同生命周期阶段发布事件,实现业务逻辑的解耦。- 业务层监听器 (
**UserServiceImpl**
): 监听并处理WebSocketEvent
,执行具体的业务逻辑。
1. 配置入口 (WebSocketConfig
)
这是 WebSocket 功能的起点。通过 @EnableWebSocket
注解开启支持,并实现 WebSocketConfigurer
接口来配置处理器和拦截器。
**registerWebSocketHandlers**
:- 注册了
WebSocketEventHandler
作为核心处理器,并映射到路径/websocket
。 - 【核心】 添加了
WebSocketAuthInterceptor
拦截器,确保所有到/websocket
的连接请求都先经过认证。 - 设置了
setAllowedOriginPatterns("*")
来允许跨域连接。
- 注册了
package com.sf.springtemplate.common.config;import com.sf.springtemplate.common.interceptor.WebSocketAuthInterceptor;
import com.sf.springtemplate.common.handler.WebSocketEventHandler;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;@Configuration
@EnableWebSocket // 开启Spring对WebSocket的支持
public class WebSocketConfig implements WebSocketConfigurer {private final WebSocketSessionManager sessionManager;private final ApplicationEventPublisher eventPublisher;private final WebSocketAuthInterceptor authInterceptor; // 注入我们自定义的认证拦截器public WebSocketConfig(WebSocketSessionManager sessionManager, ApplicationEventPublisher eventPublisher, WebSocketAuthInterceptor authInterceptor) {this.sessionManager = sessionManager;this.eventPublisher = eventPublisher;this.authInterceptor = authInterceptor;}@Overridepublic void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {registry// 1. 注册我们的核心事件处理器,并指定处理的路径为"/websocket".addHandler(new WebSocketEventHandler(sessionManager, eventPublisher), "/websocket")// 2. 【核心配置】为这个路径添加认证拦截器,所有/websocket的连接请求都会先经过它处理.addInterceptors(authInterceptor)// 3. 设置允许的跨域来源,"*"表示允许所有来源,在生产环境中应配置为具体的前端域名.setAllowedOriginPatterns("*");}
}
2. 连接认证 (WebSocketAuthInterceptor
)
在 WebSocket 握手阶段进行拦截,实现用户身份认证,只有认证通过的连接才会被建立。
**beforeHandshake**
:- 从请求 URL 的参数中获取
token
。 - 使用
JwtUtils
对token
进行解析和验证。 - 从
token
中获取userId
,并查询数据库以确认用户存在且状态正常。 - 【核心】 认证成功后,将
userId
存入attributes
中。这个attributes
会被传递给后续的WebSocketEventHandler
,使其能在连接建立时获取到已认证的用户信息。 - 如果认证失败,返回
false
,中断连接。
- 从请求 URL 的参数中获取
package com.sf.springtemplate.common.interceptor;import com.sf.springtemplate.common.util.JwtUtils;
import com.sf.springtemplate.entity.User;
import com.sf.springtemplate.mapper.UserMapper;
import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;import java.util.Map;
import java.util.Objects;@Component
@Slf4j
public class WebSocketAuthInterceptor implements HandshakeInterceptor {private final JwtUtils jwtUtils;private final UserMapper userMapper;public WebSocketAuthInterceptor(JwtUtils jwtUtils, UserMapper userMapper) {this.jwtUtils = jwtUtils;this.userMapper = userMapper;}@Overridepublic boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {String token = UriComponentsBuilder.fromUri(request.getURI()).build().getQueryParams().getFirst("token");if (token == null || token.trim().isEmpty()) {log.warn("WebSocket握手失败: URL中缺少token参数。");return false;}try {Claims claims = jwtUtils.parseToken(token);if (Objects.isNull(claims)) {log.warn("WebSocket握手失败: Token无效。");return false;}Integer userId = claims.get("userId", Integer.class);if (userId == null) {log.warn("WebSocket握手失败: Token中缺少userId。");return false;}User user = userMapper.selectById(userId);if (user == null || !user.getStatus()) {log.warn("WebSocket握手失败: 用户不存在或已被禁用, userId: {}", userId);return false;}attributes.put("userId", String.valueOf(user.getId()));log.info("WebSocket认证成功,用户ID: {}", userId);return true;} catch (Exception e) {log.error("WebSocket握手认证时发生异常: {}", e.getMessage());return false;}}@Overridepublic void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {// 握手后不做任何处理}
}
3. 会话管理 (WebSocketSessionManager
)
这是一个单例组件 (@Component
),负责在内存中统一管理所有活跃的 WebSocket 连接。
- 使用了三个
ConcurrentHashMap
来分别存储sessionId -> session
、userId -> sessionId
和sessionId -> userId
的映射关系,确保线程安全和高效查找。 **addSession**
: 添加新连接。**removeSession**
: 移除连接。**sendMessageToUser**
: 向指定用户发送消息。**broadcastMessage**
: 向所有在线用户广播消息。
package com.sf.springtemplate.service;import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;import java.io.IOException;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;@Component
public class WebSocketSessionManager {private static final Logger log = LoggerFactory.getLogger(WebSocketSessionManager.class);private static final ConcurrentHashMap<String, WebSocketSession> SESSIONS = new ConcurrentHashMap<>();private static final ConcurrentHashMap<String, String> USER_SESSIONS = new ConcurrentHashMap<>();private static final ConcurrentHashMap<String, String> SESSION_USERS = new ConcurrentHashMap<>();public void addSession(String userId, WebSocketSession session) {SESSIONS.put(session.getId(), session);USER_SESSIONS.put(userId, session.getId());SESSION_USERS.put(session.getId(), userId);}public String removeSession(WebSocketSession session) {String userId = SESSION_USERS.remove(session.getId());if (userId != null) {USER_SESSIONS.remove(userId);}SESSIONS.remove(session.getId());return userId;}public void sendMessageToUser(String userId, String message) {String sessionId = USER_SESSIONS.get(userId);if (sessionId != null) {WebSocketSession session = SESSIONS.get(sessionId);if (session != null && session.isOpen()) {try {session.sendMessage(new TextMessage(message));} catch (IOException e) {log.error("向用户 {} 发送消息失败: {}", userId, e.getMessage());}}}}public void broadcastMessage(String message) {log.info("开始广播消息: {}", message);int successCount = 0;for (WebSocketSession session : SESSIONS.values()) {if (session.isOpen()) {try {session.sendMessage(new TextMessage(message));successCount++;} catch (IOException e) {log.error("向会话 {} 广播消息失败: {}", session.getId(), e.getMessage());}}}log.info("消息广播完成,成功发送给 {} 个客户端", successCount);}
}
4. 事件处理与发布 (WebSocketEventHandler
& WebSocketEvent
)
WebSocketEventHandler
继承自 TextWebSocketHandler
,负责处理 WebSocket 的核心生命周期,并通过 ApplicationEventPublisher
将这些活动发布为 Spring 事件。
**afterConnectionEstablished**
: 连接成功后,从session.getAttributes()
中获取userId
(由WebSocketAuthInterceptor
存入),将会话添加到sessionManager
中,并发布USER_ONLINE
事件。**handleTextMessage**
: 收到消息后,发布MESSAGE_RECEIVED
事件。**afterConnectionClosed**
: 连接关闭后,从sessionManager
中移除会话,并发布USER_OFFLINE
事件。
WebSocketEvent
是一个自定义的 ApplicationEvent
,用于封装事件信息(如事件类型、用户ID、会话等),实现了业务逻辑与 WebSocket 底层处理的解耦。
package com.sf.springtemplate.common.handler;import com.sf.springtemplate.common.model.WebSocketEvent;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;public class WebSocketEventHandler extends TextWebSocketHandler {private static final Logger log = LoggerFactory.getLogger(WebSocketEventHandler.class);private final WebSocketSessionManager sessionManager;private final ApplicationEventPublisher eventPublisher;public WebSocketEventHandler(WebSocketSessionManager sessionManager, ApplicationEventPublisher eventPublisher) {this.sessionManager = sessionManager;this.eventPublisher = eventPublisher;}@Overridepublic void afterConnectionEstablished(WebSocketSession session) {String userId = (String) session.getAttributes().get("userId");if (userId != null) {sessionManager.addSession(userId, session);log.info("用户 {} 连接成功, 会话ID: {}, 当前总连接数: {}", userId, session.getId(), sessionManager.getActiveConnectionCount());eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.USER_ONLINE, userId, session));} else {// ...}}@Overrideprotected void handleTextMessage(WebSocketSession session, TextMessage message) {String userId = sessionManager.getUserId(session.getId());if (userId != null) {log.info("收到来自用户 {} 的消息: {}", userId, message.getPayload());eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.MESSAGE_RECEIVED, userId, session, message.getPayload()));}}@Overridepublic void afterConnectionClosed(WebSocketSession session, CloseStatus status) {String userId = sessionManager.removeSession(session);if (userId != null) {log.info("用户 {} 连接关闭, 原因: {}, 当前总连接数: {}", userId, status.getReason(), sessionManager.getActiveConnectionCount());eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.USER_OFFLINE, userId, session));}}
}
package com.sf.springtemplate.common.model;import lombok.Getter;
import org.springframework.context.ApplicationEvent;
import org.springframework.web.socket.WebSocketSession;@Getter
public class WebSocketEvent extends ApplicationEvent {private final Type type;private final String userId;private final WebSocketSession session;private final String message;public enum Type {USER_ONLINE,USER_OFFLINE,MESSAGE_RECEIVED}public WebSocketEvent(Object source, Type type, String userId, WebSocketSession session) {this(source, type, userId, session, null);}public WebSocketEvent(Object source, Type type, String userId, WebSocketSession session, String message) {super(source);this.type = type;this.userId = userId;this.session = session;this.message = message;}
}
5. 业务逻辑处理 (UserServiceImpl
)
在业务层(例如 UserServiceImpl
)中,可以非常方便地通过 @EventListener
注解来监听并处理前面发布的 WebSocketEvent
。
**handleWebSocketEvents**
:- 监听
WebSocketEvent
。 - 判断事件类型是
USER_ONLINE
还是USER_OFFLINE
。 - 执行相应的业务逻辑,例如,当用户上线时,调用
webSocketSessionManager.sendMessageToUser
发送一条欢迎消息。
- 监听
// 在 UserServiceImpl.java 中
import com.sf.springtemplate.common.model.WebSocketEvent;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.event.EventListener;
// ...@Service
@Slf4j
public class UserServiceImpl extends ServiceImpl<UserMapper, User> implements UserService {@Autowiredprivate WebSocketSessionManager webSocketSessionManager;//... 其他代码/*** 监听WebSocket事件的示例方法。*/@EventListenerpublic void handleWebSocketEvents(WebSocketEvent event) {if (event.getType() == WebSocketEvent.Type.USER_ONLINE) {String userId = event.getUserId();log.info("【业务层】监听到用户 {} 上线了!", userId);String welcomeMessage = "欢迎回来!您已成功连接到实时通知服务。";webSocketSessionManager.sendMessageToUser(userId, welcomeMessage);} else if (event.getType() == WebSocketEvent.Type.USER_OFFLINE) {log.info("【业务层】监听到用户 {} 离线了。", event.getUserId());// 可在此处添加用户离线后的业务处理}}
}
总结
这个 WebSocket 实现方案通过责任链模式(拦截器处理认证)和观察者模式(事件发布/监听机制)实现了高度的模块化和解耦。开发者可以轻松地在业务层监听 WebSocket 事件,而无需关心底层的连接管理和生命周期,从而实现干净、可维护的代码结构。