feat(websocket): 重构 WebSocket 实现数据推送

- 新增 DataPushTask 类实现定时推送数据- 重构 WebSocketConfig 配置类
- 新增 DeviceWebSocketHandler、RecordWebSocketHandler 和 ProductWebSocketHandler 处理不同类型的 WebSocket 连接
- 移除 GenericWebSocketHandler 和 WebSocketInterceptor 类
- 更新 IotGatewayApplication 启用定时任务
- 调整 SpringSecurityConfig 允许 WebSocket 请求通过
This commit is contained in:
zhuangtianxiang 2025-03-24 20:06:43 +08:00
parent 42a9f9490f
commit 11d972e875
9 changed files with 131 additions and 150 deletions

View File

@ -3,9 +3,11 @@ package com.zsc.edu.gateway;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.EnableAspectJAutoProxy;
import org.springframework.scheduling.annotation.EnableScheduling;
@SpringBootApplication
@EnableAspectJAutoProxy
@EnableScheduling
public class IotGatewayApplication {
public static void main(String[] args) {

View File

@ -1,61 +0,0 @@
package com.zsc.edu.gateway.framework.message.websocket;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
/**
* @author zhuang
*/
@Component
public class GenericWebSocketHandler extends TextWebSocketHandler {
private final ExecutorService executorService = Executors.newCachedThreadPool();
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
// session 中获取业务逻辑的 Supplier
Supplier<Object> dataSupplier = (Supplier<Object>) session.getAttributes().get("dataSupplier");
if (dataSupplier == null) {
session.close();
return;
}
AtomicBoolean isCompleted = new AtomicBoolean(false);
executorService.execute(() -> {
try {
while (!isCompleted.get()) {
Object data = dataSupplier.get();
if (data == null) {
break;
}
if (!isCompleted.get()) {
session.sendMessage(new TextMessage(data.toString()));
}
Thread.sleep(5000); // 每隔 5 秒发送一次数据
}
} catch (IOException | InterruptedException e) {
try {
session.close();
} catch (IOException ex) {
ex.printStackTrace();
}
} finally {
try {
session.close();
} catch (IOException e) {
e.printStackTrace();
}
}
});
}
}

View File

@ -1,27 +0,0 @@
package com.zsc.edu.gateway.framework.message.websocket;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
* @author zhuang
*/
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
private final GenericWebSocketHandler genericWebSocketHandler;
public WebSocketConfig(GenericWebSocketHandler genericWebSocketHandler) {
this.genericWebSocketHandler = genericWebSocketHandler;
}
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(genericWebSocketHandler, "/api/rest/device/ws/device/status", "/api/rest/device/ws/record/status", "/api/rest/product/ws/product/status")
.setAllowedOrigins("*")
.addInterceptors(new WebSocketInterceptor());
}
}

View File

@ -1,58 +0,0 @@
package com.zsc.edu.gateway.framework.message.websocket;
import com.zsc.edu.gateway.modules.iot.device.service.DeviceService;
import com.zsc.edu.gateway.modules.iot.product.service.ProductService;
import com.zsc.edu.gateway.modules.iot.record.service.RecordDataService;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import java.util.Map;
import java.util.function.Supplier;
/**
* @author lenovo
*/
public class WebSocketInterceptor implements HandshakeInterceptor {
@Resource
private DeviceService deviceService;
@Resource
private RecordDataService recordService;
@Resource
private ProductService productService;
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) {
if (request instanceof ServletServerHttpRequest) {
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
// 获取完整路径
String path = servletRequest.getRequestURI();
// 根据路径设置不同的业务逻辑 Supplier
switch (path) {
case "/api/rest/device/ws/device/status":
attributes.put("dataSupplier", (Supplier<String>) () -> String.valueOf(deviceService.status()));
break;
case "/api/rest/device/ws/record/status":
attributes.put("dataSupplier", (Supplier<String>) () -> String.valueOf(recordService.recordDataStatus()));
break;
case "/api/rest/product/ws/product/status":
attributes.put("dataSupplier", (Supplier<String>) () -> String.valueOf(productService.status()));
break;
default:
attributes.put("dataSupplier", (Supplier<String>) () -> "Unknown path: " + path);
break;
}
}
return true;
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
// 握手完成后不需要额外操作
}
}

View File

@ -88,6 +88,7 @@ public class SpringSecurityConfig {
.requestMatchers(HttpMethod.GET, "/api/rest/user/menu","/api/rest/user/register","/api/rest/user/send-email").permitAll()
.requestMatchers(HttpMethod.POST, "/api/rest/user/login","/api/rest/user/register").permitAll()
.requestMatchers("/api/rest/user/me").permitAll()
.requestMatchers("/api/rest/ws/**").permitAll()
.requestMatchers("/api/**").authenticated()
)
// 不用注解直接通过判断路径实现动态访问权限
@ -144,7 +145,7 @@ public class SpringSecurityConfig {
.rememberMe(rememberMe -> rememberMe
.userDetailsService(userDetailsService)
.tokenRepository(persistentTokenRepository()))
.csrf(csrf -> csrf.ignoringRequestMatchers("/api/internal/**", "/api/rest/user/logout","/api/rest/user/register"))
.csrf(csrf -> csrf.ignoringRequestMatchers("/api/internal/**", "/api/rest/user/logout", "/api/rest/user/register", "/api/rest/ws/**"))
.sessionManagement(session -> session
.maximumSessions(3)
.sessionRegistry(sessionRegistry)

View File

@ -0,0 +1,46 @@
package com.zsc.edu.gateway.framework.websocket;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
public abstract class BaseWebSocketHandler extends TextWebSocketHandler {
private final Set<WebSocketSession> sessions = new CopyOnWriteArraySet<>();
@Override
public void afterConnectionEstablished(WebSocketSession session) {
sessions.add(session);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
sessions.remove(session);
}
public void broadcast(String message) {
sessions.removeIf(session -> {
if (!session.isOpen()) return true;
try {
session.sendMessage(new TextMessage(message));
} catch (IOException e) {
// 生产环境应使用日志记录
e.printStackTrace();
}
return false;
});
}
}
class DeviceWebSocketHandler extends BaseWebSocketHandler {
}
class ProductWebSocketHandler extends BaseWebSocketHandler {
}
class RecordWebSocketHandler extends BaseWebSocketHandler {
}

View File

@ -0,0 +1,41 @@
package com.zsc.edu.gateway.framework.websocket;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zsc.edu.gateway.modules.iot.device.service.DeviceService;
import com.zsc.edu.gateway.modules.iot.product.service.ProductService;
import com.zsc.edu.gateway.modules.iot.record.service.RecordDataService;
import lombok.AllArgsConstructor;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.function.Supplier;
@Component
@AllArgsConstructor
public class DataPushTask {
private final DeviceService deviceService;
private final RecordDataService recordDataService;
private final ProductService productService;
private final ObjectMapper objectMapper;
private final DeviceWebSocketHandler deviceHandler;
private final RecordWebSocketHandler recordHandler;
private final ProductWebSocketHandler productHandler;
@Scheduled(fixedRate = 20000)
public void pushData() {
pushAndBroadcast(deviceService::status, deviceHandler);
pushAndBroadcast(recordDataService::recordDataStatus, recordHandler);
pushAndBroadcast(productService::status, productHandler);
}
private void pushAndBroadcast(Supplier<Object> dataSupplier, BaseWebSocketHandler handler) {
try {
Object status = dataSupplier.get();
String json = objectMapper.writeValueAsString(status);
handler.broadcast(json);
} catch (Exception e) {
// 生产环境应使用日志记录
e.printStackTrace();
}
}
}

View File

@ -0,0 +1,37 @@
package com.zsc.edu.gateway.framework.websocket;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(deviceWebSocketHandler(), "/api/rest/ws/device/status")
.setAllowedOrigins("*");
registry.addHandler(recordWebSocketHandler(), "/api/rest/ws/record/status")
.setAllowedOrigins("*");
registry.addHandler(productWebSocketHandler(), "/api/rest/ws/product/status")
.setAllowedOrigins("*");
}
@Bean
public DeviceWebSocketHandler deviceWebSocketHandler() {
return new DeviceWebSocketHandler();
}
@Bean
public RecordWebSocketHandler recordWebSocketHandler() {
return new RecordWebSocketHandler();
}
@Bean
public ProductWebSocketHandler productWebSocketHandler() {
return new ProductWebSocketHandler();
}
}

View File

@ -12,13 +12,13 @@ import java.util.List;
* @author zhuang
*/
public interface RecordDataRepository extends BaseMapper<RecordData> {
@Select("SELECT COUNT(*) FROM iot_record_data WHERE content->>'warning' IS NOT NULL AND (content->>'warning')::int & 1 = 0 AND ((content->>'warning')::int & ~1) > 0")
@Select("SELECT COUNT(*) FROM iot_record_data WHERE content::jsonb->>'warning' IS NOT NULL AND (content::jsonb->>'warning')::int & 1 = 0 AND ((content::jsonb->>'warning')::int & ~1) > 0")
long countWarnings();
@Select("SELECT COUNT(*) FROM iot_record_data WHERE content->>'warning' IS NOT NULL AND ((content->>'warning')::int & 1 = 0) AND ((content->>'warning')::int & ~1) > 0 AND record_time >= #{todayStart}")
@Select("SELECT COUNT(*) FROM iot_record_data WHERE content::jsonb->>'warning' IS NOT NULL AND ((content::jsonb->>'warning')::int & 1 = 0) AND ((content::jsonb->>'warning')::int & ~1) > 0 AND record_time >= #{todayStart}")
long countTodayWarnings(@Param("todayStart") LocalDateTime todayStart);
@Select("SELECT COUNT(*) FROM iot_record_data WHERE content->>'warning' IS NOT NULL AND (content->>'warning')::int & #{bitPosition} = #{bitPosition}")
@Select("SELECT COUNT(*) FROM iot_record_data WHERE content::jsonb->>'warning' IS NOT NULL AND (content::jsonb->>'warning')::int & #{bitPosition} = #{bitPosition}")
long countWarningsByBit(@Param("bitPosition") int bitPosition);
List<RecordData> selectByClientId(@Param("clientId") String clientId);