From 8e14a4fc8f204f8db1e99107c2cbb4d5be34687d Mon Sep 17 00:00:00 2001 From: sousoiki Date: Wed, 10 Mar 2021 17:01:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=8F=92=E4=BB=B6WebSocket?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../basic-example/basic-example-main/pom.xml | 5 + .../basic/example/main/config/WebSocketConfig | 15 + .../plugins/basic-example-plugin3/pom.xml | 26 ++ .../plugins/basic-example-plugin3/readme.md | 10 + .../basic/example/plugin3/DefinePlugin.java | 31 ++ .../com/basic/example/plugin3/WebSocket1.java | 52 +++ .../com/basic/example/plugin3/WebSocket2.java | 53 +++ .../src/main/resources/plugin.properties | 4 + example/basic-example/plugins/pom.xml | 1 + springboot-plugin-framework/pom.xml | 6 + .../starblues/factory/PluginRegistryInfo.java | 13 + .../pipe/classs/PluginClassProcess.java | 1 + .../pipe/classs/group/WebSocketGroup.java | 31 ++ .../post/PluginPostProcessorFactory.java | 1 + .../post/bean/PluginWebSocketProcessor.java | 431 ++++++++++++++++++ 15 files changed, 680 insertions(+) create mode 100644 example/basic-example/basic-example-main/src/main/java/com/basic/example/main/config/WebSocketConfig create mode 100644 example/basic-example/plugins/basic-example-plugin3/pom.xml create mode 100644 example/basic-example/plugins/basic-example-plugin3/readme.md create mode 100644 example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/DefinePlugin.java create mode 100644 example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket1.java create mode 100644 example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket2.java create mode 100644 example/basic-example/plugins/basic-example-plugin3/src/main/resources/plugin.properties create mode 100644 springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/group/WebSocketGroup.java create mode 100644 springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/bean/PluginWebSocketProcessor.java diff --git a/example/basic-example/basic-example-main/pom.xml b/example/basic-example/basic-example-main/pom.xml index 28fb781..24b0743 100644 --- a/example/basic-example/basic-example-main/pom.xml +++ b/example/basic-example/basic-example-main/pom.xml @@ -66,6 +66,11 @@ spring-boot-starter-test test + + + org.springframework.boot + spring-boot-starter-websocket + diff --git a/example/basic-example/basic-example-main/src/main/java/com/basic/example/main/config/WebSocketConfig b/example/basic-example/basic-example-main/src/main/java/com/basic/example/main/config/WebSocketConfig new file mode 100644 index 0000000..8bb35fe --- /dev/null +++ b/example/basic-example/basic-example-main/src/main/java/com/basic/example/main/config/WebSocketConfig @@ -0,0 +1,15 @@ +package com.basic.example.main.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.server.standard.ServerEndpointExporter; + +@Configuration +public class WebSocketConfig { + + @Bean + public ServerEndpointExporter serverEndpointExporter() { + return new ServerEndpointExporter(); + } + +} diff --git a/example/basic-example/plugins/basic-example-plugin3/pom.xml b/example/basic-example/plugins/basic-example-plugin3/pom.xml new file mode 100644 index 0000000..c856c3d --- /dev/null +++ b/example/basic-example/plugins/basic-example-plugin3/pom.xml @@ -0,0 +1,26 @@ + + + + 4.0.0 + + + com.gitee.starblues + basic-example-plugin-parent + 2.4.1-RELEASE + ../pom.xml + + + basic-example-plugin3 + 2.4.1-RELEASE + jar + + + basic-example-plugin3 + com.basic.example.plugin3.DefinePlugin + ${project.version} + sousouki + + + \ No newline at end of file diff --git a/example/basic-example/plugins/basic-example-plugin3/readme.md b/example/basic-example/plugins/basic-example-plugin3/readme.md new file mode 100644 index 0000000..51896ed --- /dev/null +++ b/example/basic-example/plugins/basic-example-plugin3/readme.md @@ -0,0 +1,10 @@ +# 插件WebSocket说明 + +## WebSocket类定义 + WebSocket类只需要使用@ServerEndpoint注解即可,无需其他额外注解 + +## WebSocket访问路径 +### 无参数路径 + ws://ip:port/basic-example-plugin3/test/no_path_param +### 有参数路径 + ws://ip:port/basic-example-plugin3/test/has_path_param/xxx \ No newline at end of file diff --git a/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/DefinePlugin.java b/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/DefinePlugin.java new file mode 100644 index 0000000..5017cb5 --- /dev/null +++ b/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/DefinePlugin.java @@ -0,0 +1,31 @@ +package com.basic.example.plugin3; + +import com.gitee.starblues.realize.BasePlugin; +import org.pf4j.PluginWrapper; + +/** + * 插件定义类 + * + * @author starBlues + * @version 1.0 + */ +public class DefinePlugin extends BasePlugin { + public DefinePlugin(PluginWrapper wrapper) { + super(wrapper); + } + + @Override + protected void startEvent() { + + } + + @Override + protected void deleteEvent() { + + } + + @Override + protected void stopEvent() { + + } +} diff --git a/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket1.java b/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket1.java new file mode 100644 index 0000000..24a7c50 --- /dev/null +++ b/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket1.java @@ -0,0 +1,52 @@ +package com.basic.example.plugin3; + +import java.util.concurrent.atomic.AtomicInteger; +import javax.websocket.OnClose; +import javax.websocket.OnError; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.Session; +import javax.websocket.server.ServerEndpoint; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@ServerEndpoint(value = "/test/no_path_param") +public class WebSocket1 { + + private static final Logger log = LoggerFactory.getLogger(WebSocket1.class); + + private static AtomicInteger onlineCount = new AtomicInteger(0); + + @OnOpen + public void onOpen(Session session) { + onlineCount.incrementAndGet(); // 在线数加1 + log.info("有新连接加入:{},当前在线人数为:{}", session.getId(), onlineCount.get()); + } + + @OnClose + public void onClose(Session session) { + onlineCount.decrementAndGet(); // 在线数减1 + log.info("有一连接关闭:{},当前在线人数为:{}", session.getId(), onlineCount.get()); + } + + @OnMessage + public void onMessage(String message, Session session) { + log.info("服务端收到客户端[{}]的消息:{}", session.getId(), message); + this.sendMessage("Hello, " + message, session); + } + + @OnError + public void onError(Session session, Throwable error) { + log.error("发生错误"); + error.printStackTrace(); + } + + private void sendMessage(String message, Session toSession) { + try { + log.info("服务端给客户端[{}]发送消息{}", toSession.getId(), message); + toSession.getBasicRemote().sendText(message); + } catch (Exception e) { + log.error("服务端发送消息给客户端失败", e); + } + } +} diff --git a/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket2.java b/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket2.java new file mode 100644 index 0000000..e3f56dc --- /dev/null +++ b/example/basic-example/plugins/basic-example-plugin3/src/main/java/com/basic/example/plugin3/WebSocket2.java @@ -0,0 +1,53 @@ +package com.basic.example.plugin3; + +import java.util.concurrent.atomic.AtomicInteger; +import javax.websocket.OnClose; +import javax.websocket.OnError; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.Session; +import javax.websocket.server.PathParam; +import javax.websocket.server.ServerEndpoint; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@ServerEndpoint(value = "/test/has_path_param/{testName}") +public class WebSocket2 { + + private static final Logger log = LoggerFactory.getLogger(WebSocket2.class); + + private static AtomicInteger onlineCount = new AtomicInteger(0); + + @OnOpen + public void onOpen(Session session, @PathParam("testName") String testName) { + onlineCount.incrementAndGet(); // 在线数加1 + log.info("有新连接加入:{},当前在线人数为:{},用户名:{}", session.getId(), onlineCount.get(), testName); + } + + @OnClose + public void onClose(Session session) { + onlineCount.decrementAndGet(); // 在线数减1 + log.info("有一连接关闭:{},当前在线人数为:{}", session.getId(), onlineCount.get()); + } + + @OnMessage + public void onMessage(String message, Session session) { + log.info("服务端收到客户端[{}]的消息:{}", session.getId(), message); + this.sendMessage("Hello, " + message, session); + } + + @OnError + public void onError(Session session, Throwable error) { + log.error("发生错误"); + error.printStackTrace(); + } + + private void sendMessage(String message, Session toSession) { + try { + log.info("服务端给客户端[{}]发送消息{}", toSession.getId(), message); + toSession.getBasicRemote().sendText(message); + } catch (Exception e) { + log.error("服务端发送消息给客户端失败", e); + } + } +} diff --git a/example/basic-example/plugins/basic-example-plugin3/src/main/resources/plugin.properties b/example/basic-example/plugins/basic-example-plugin3/src/main/resources/plugin.properties new file mode 100644 index 0000000..f35c749 --- /dev/null +++ b/example/basic-example/plugins/basic-example-plugin3/src/main/resources/plugin.properties @@ -0,0 +1,4 @@ +plugin.id=basic-example-plugin3 +plugin.class=com.basic.example.plugin3.DefinePlugin +plugin.version=2.4.1-RELEASE +plugin.provider=sousouki \ No newline at end of file diff --git a/example/basic-example/plugins/pom.xml b/example/basic-example/plugins/pom.xml index 146dc36..e22a0b8 100644 --- a/example/basic-example/plugins/pom.xml +++ b/example/basic-example/plugins/pom.xml @@ -13,6 +13,7 @@ basic-example-plugin1 basic-example-plugin2 + basic-example-plugin3 diff --git a/springboot-plugin-framework/pom.xml b/springboot-plugin-framework/pom.xml index 0acd88d..1f5eaf3 100644 --- a/springboot-plugin-framework/pom.xml +++ b/springboot-plugin-framework/pom.xml @@ -151,6 +151,12 @@ test + + org.springframework.boot + spring-boot-starter-websocket + ${spring-boot-version} + provided + diff --git a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/PluginRegistryInfo.java b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/PluginRegistryInfo.java index 338ad1d..3523103 100644 --- a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/PluginRegistryInfo.java +++ b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/PluginRegistryInfo.java @@ -70,6 +70,11 @@ public class PluginRegistryInfo { */ private final Map processorInfo = new ConcurrentHashMap<>(8); + /** + * websocket路径 + */ + private final List websocketPaths = new ArrayList<>(); + private PluginRegistryInfo(PluginWrapper pluginWrapper, PluginManager pluginManager, GenericApplicationContext mainApplicationContext, @@ -323,4 +328,12 @@ public class PluginRegistryInfo { } } + public void addWebsocketPath(String path) { + websocketPaths.add(path); + } + + public List getWebsocketPaths() { + return websocketPaths; + } + } diff --git a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/PluginClassProcess.java b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/PluginClassProcess.java index b3018fe..6e6c03b 100644 --- a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/PluginClassProcess.java +++ b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/PluginClassProcess.java @@ -49,6 +49,7 @@ public class PluginClassProcess implements PluginPipeProcessor { pluginClassGroups.add(new SupplierGroup()); pluginClassGroups.add(new CallerGroup()); pluginClassGroups.add(new OneselfListenerGroup()); + pluginClassGroups.add(new WebSocketGroup()); // 添加扩展 pluginClassGroups.addAll(ExtensionInitializer.getClassGroupExtends()); } diff --git a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/group/WebSocketGroup.java b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/group/WebSocketGroup.java new file mode 100644 index 0000000..cdf3d7b --- /dev/null +++ b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/pipe/classs/group/WebSocketGroup.java @@ -0,0 +1,31 @@ +package com.gitee.starblues.factory.process.pipe.classs.group; + +import com.gitee.starblues.factory.process.pipe.classs.PluginClassGroup; +import com.gitee.starblues.realize.BasePlugin; +import com.gitee.starblues.utils.AnnotationsUtils; +import javax.websocket.server.ServerEndpoint; + +/** + * 分组存在注解: @ServerEndpoint + * + * @author sousouki + */ +public class WebSocketGroup implements PluginClassGroup { + + public static final String GROUP_ID = "websocket"; + + @Override + public String groupId() { + return GROUP_ID; + } + + @Override + public void initialize(BasePlugin basePlugin) { + + } + + @Override + public boolean filter(Class aClass) { + return AnnotationsUtils.haveAnnotations(aClass, false, ServerEndpoint.class); + } +} diff --git a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/PluginPostProcessorFactory.java b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/PluginPostProcessorFactory.java index f3b9781..750b1dc 100644 --- a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/PluginPostProcessorFactory.java +++ b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/PluginPostProcessorFactory.java @@ -32,6 +32,7 @@ public class PluginPostProcessorFactory implements PluginPostProcessor { // 以下顺序不能更改 pluginPostProcessors.add(new PluginInvokePostProcessor(mainApplicationContext)); pluginPostProcessors.add(new PluginControllerPostProcessor(mainApplicationContext)); + pluginPostProcessors.add(new PluginWebSocketProcessor(mainApplicationContext)); // 主要触发启动监听事件,因此在最后一个执行。配合 OneselfListenerStopEventProcessor 该类触发启动、停止事件。 pluginPostProcessors.add(new PluginOneselfStartEventProcessor()); // 添加扩展 diff --git a/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/bean/PluginWebSocketProcessor.java b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/bean/PluginWebSocketProcessor.java new file mode 100644 index 0000000..2e7b0c1 --- /dev/null +++ b/springboot-plugin-framework/src/main/java/com/gitee/starblues/factory/process/post/bean/PluginWebSocketProcessor.java @@ -0,0 +1,431 @@ +package com.gitee.starblues.factory.process.post.bean; + +import com.gitee.starblues.factory.PluginRegistryInfo; +import com.gitee.starblues.factory.process.pipe.classs.group.WebSocketGroup; +import com.gitee.starblues.factory.process.post.PluginPostProcessor; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; +import javax.servlet.ServletContext; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; +import javax.websocket.server.PathParam; +import javax.websocket.server.ServerContainer; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.asm.AnnotationVisitor; +import org.springframework.asm.ClassReader; +import org.springframework.asm.ClassVisitor; +import org.springframework.asm.ClassWriter; +import org.springframework.asm.MethodVisitor; +import org.springframework.asm.Opcodes; +import org.springframework.asm.Type; +import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.socket.server.standard.ServerEndpointExporter; + +/** + * 插件中websocket处理者 + * + * @author sousouki + */ +public class PluginWebSocketProcessor implements PluginPostProcessor { + + private static final Logger log = LoggerFactory.getLogger(PluginWebSocketProcessor.class); + + public static final String KEY = "PluginWsConfigProcessor"; + + private static final int ASM_API_VERSION = Opcodes.ASM7; + + private final ApplicationContext applicationContext; + + public PluginWebSocketProcessor(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + } + + @Override + public void initialize() throws Exception { + + } + + @Override + public void registry(List pluginRegistryInfos) throws Exception { + try { + applicationContext.getBean(ServerEndpointExporter.class); + } catch (BeansException e) { + log.debug("The required bean of {} not found, if you want to use plugin websocket, please create it.", ServerEndpointExporter.class.getName()); + return; + } + pluginRegistryInfos.forEach(pluginRegistryInfo -> { + String pluginId = pluginRegistryInfo.getPluginWrapper().getPluginId(); + if (applicationContext instanceof WebApplicationContext) { + WebApplicationContext webApplicationContext = (WebApplicationContext) applicationContext; + ServletContext servletContext = webApplicationContext.getServletContext(); + if (servletContext == null) { + log.warn("Servlet context is null."); + return; + } + Object obj = servletContext.getAttribute("javax.websocket.server.ServerContainer"); + if (obj instanceof ServerContainer) { + ServerContainer serverContainer = (ServerContainer) obj; + List> websocketClasses = pluginRegistryInfo.getGroupClasses(WebSocketGroup.GROUP_ID); + websocketClasses.forEach(websocketClass -> { + ServerEndpoint serverEndpoint = websocketClass.getDeclaredAnnotation(ServerEndpoint.class); + if (serverEndpoint == null) { + log.warn("WebSocket class {} doesn't has annotation {}", websocketClass.getName(), ServerEndpoint.class.getName()); + return; + } + String websocketPath = serverEndpoint.value(); + UriTemplate uriTemplate; + try { + uriTemplate = new UriTemplate(websocketPath); + } catch (DeploymentException e) { + log.error("Websocket path validate failed.", e); + return; + } + String websocketTemplatePath = uriTemplate.getPath(); + Map pathParam = uriTemplate.getParamMap(); + String newWebsocketPath = "/".concat(pluginId).concat(websocketTemplatePath); + String pluginPath = pluginRegistryInfo.getPluginWrapper().getPluginPath().toString(); + Class proxyServerEndpoint = createProxyClass(pluginRegistryInfo, pluginPath, websocketClass, newWebsocketPath, pathParam); + if (proxyServerEndpoint == null) { + log.warn("Proxy class for websocket class {} is null.", websocketClass.getName()); + return; + } + try { + serverContainer.addEndpoint(proxyServerEndpoint); + pluginRegistryInfo.addWebsocketPath(newWebsocketPath); + log.info("Succeed to create websocket service for path {}", newWebsocketPath); + } catch (DeploymentException e) { + log.error("Create websocket service for websocket class " + websocketClass.getName() + " failed.", e); + } + }); + } + } + }); + } + + @Override + public void unRegistry(List pluginRegistryInfos) throws Exception { + pluginRegistryInfos.forEach(pluginRegistryInfo -> { + List websocketPaths = pluginRegistryInfo.getWebsocketPaths(); + websocketPaths.forEach(websocketPath -> { + try { + if (applicationContext instanceof WebApplicationContext) { + WebApplicationContext webApplicationContext = (WebApplicationContext) applicationContext; + ServletContext servletContext = webApplicationContext.getServletContext(); + if (servletContext == null) { + log.warn("Servlet context is null."); + return; + } + Object obj = servletContext.getAttribute("javax.websocket.server.ServerContainer"); + if (obj instanceof ServerContainer) { + ServerContainer serverContainer = (ServerContainer) obj; + Map configExactMatchMap = (Map) reflectFieldValue(serverContainer, "configExactMatchMap"); + configExactMatchMap.remove(websocketPath); + log.debug("Removed websocket config for path {}", websocketPath); + + Map> configTemplateMatchMap = (Map>) reflectFieldValue(serverContainer, "configTemplateMatchMap"); + configTemplateMatchMap.forEach((key, value) -> { + value.remove(websocketPath); + }); + + Map endpointSessionMap = (Map) reflectParentFieldValue(serverContainer, "endpointSessionMap"); + endpointSessionMap.remove(websocketPath); + log.debug("Removed websocket session for path {}", websocketPath); + + Map sessions = (Map) reflectParentFieldValue(serverContainer, "sessions"); + for (Map.Entry entry : sessions.entrySet()) { + Session session = entry.getKey(); + EndpointConfig endpointConfig = (EndpointConfig) reflectFieldValue(session, "endpointConfig"); + ServerEndpointConfig perEndpointConfig = (ServerEndpointConfig) reflectFieldValue(endpointConfig, "perEndpointConfig"); + String path = (String) reflectFieldValue(perEndpointConfig, "path"); + if (path.equals(websocketPath)) { + session.close(); + log.debug("Closed websocket session {} for path {}", session.getId(), websocketPath); + sessions.remove(session); + log.debug("Removed websocket session {} for path {}", session.getId(), websocketPath); + } + } + log.info("Remove websocket for path {} success.", websocketPath); + } + } + } catch (IllegalAccessException | NoSuchFieldException | IOException e) { + log.error("Remove websocket failed for path " + websocketPath, e); + } + }); + }); + } + + private Object reflectFieldValue(Object obj, String fieldName) throws NoSuchFieldException, IllegalAccessException { + Field field = obj.getClass().getDeclaredField(fieldName); + if (!field.isAccessible()) { + field.setAccessible(true); + } + return field.get(obj); + } + + private Object reflectParentFieldValue(Object obj, String fieldName) throws NoSuchFieldException, IllegalAccessException { + Field field = obj.getClass().getSuperclass().getDeclaredField(fieldName); + if (!field.isAccessible()) { + field.setAccessible(true); + } + return field.get(obj); + } + + private Class createProxyClass(PluginRegistryInfo pluginRegistryInfo, String pluginPath, Class websocketClass, String newWebsocketPath, Map pathParam) { + String simpleName = websocketClass.getSimpleName(); + String className = websocketClass.getName(); + String basePackage = className.substring(0, className.lastIndexOf(simpleName) - 1); + try (JarFile jarFile = new JarFile(pluginPath)) { + Enumeration jarEntries = jarFile.entries(); + while (jarEntries.hasMoreElements()) { + JarEntry entry = jarEntries.nextElement(); + String jarEntryName = entry.getName(); + if (jarEntryName.endsWith(simpleName.concat(".class")) && jarEntryName.replaceAll("/", ".").startsWith(basePackage)) { + InputStream inputStream = jarFile.getInputStream(entry); + if (inputStream == null) { + log.warn("Class stream for {} is null.", websocketClass.getName()); + return null; + } + Class proxyClass = createProxyClass(pluginRegistryInfo, inputStream, websocketClass, newWebsocketPath, pathParam); + log.debug("Created proxy class {} for websocket class {}", proxyClass.getName(), className); + return proxyClass; + } + } + } catch (IOException | NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { + log.error("Create proxy class for websocket class " + className + "error", e); + return null; + } + return null; + } + + private Class createProxyClass(PluginRegistryInfo pluginRegistryInfo, InputStream inputStream, Class websocketClass, String newWebsocketPath, Map pathParam) throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, IOException { + ClassReader cr = new ClassReader(inputStream); + ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + String className = websocketClass.getName(); + String proxyClassName = className.concat("$PluginProxyServerEndpoint$".concat(String.valueOf(System.currentTimeMillis()))); + ClassVisitor cv = new ServerEndpointProxyClassVisitor(ASM_API_VERSION, cw, websocketClass, proxyClassName, newWebsocketPath, pathParam); + cr.accept(cv, 0); + byte[] classData = cw.toByteArray(); + return defineClass(pluginRegistryInfo, proxyClassName, classData); + } + + private Class defineClass(PluginRegistryInfo pluginRegistryInfo, String proxyClassName, byte[] classData) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException { + ClassLoader classLoader = pluginRegistryInfo.getPluginClassLoader(); + Method defineClassMethod = ClassLoader.class.getDeclaredMethod("defineClass", String.class, byte[].class, int.class, int.class); + if (!defineClassMethod.isAccessible()) { + defineClassMethod.setAccessible(true); + } + return (Class) defineClassMethod.invoke(classLoader, proxyClassName, classData, 0, classData.length); + } + + /** + * 修改类名及类注解的值 + */ + private static class ServerEndpointProxyClassVisitor extends ClassVisitor { + + private final String proxyClassName; + private final String proxyClassInternalName; + private final String classInternalName; + private final String newPath; + private final Class websocketClass; + private final Map pathParam; + + private ServerEndpointProxyClassVisitor(int api, ClassVisitor classVisitor, Class websocketClass, String proxyClassName, String newPath, Map pathParam) { + super(api, classVisitor); + String classInternalName = Type.getInternalName(websocketClass); + this.proxyClassName = proxyClassName; + this.proxyClassInternalName = proxyClassName.replaceAll("\\.", "/"); + this.classInternalName = classInternalName; + this.newPath = newPath; + this.websocketClass = websocketClass; + this.pathParam = pathParam; + } + + @Override + public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { + super.visit(version, access, proxyClassInternalName, signature, superName, interfaces); + log.debug("Changed class name from {} to {}", classInternalName, proxyClassInternalName); + } + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + AnnotationVisitor superAnnotationVisitor = super.visitAnnotation(descriptor, visible); + if (descriptor.equals(Type.getDescriptor(ServerEndpoint.class))) { + return new AnnotationVisitor(ASM_API_VERSION, superAnnotationVisitor) { + @Override + public void visit(String name, Object value) { + if ("value".equals(name)) { + value = newPath; + log.debug("Changed websocket path from {} to {} in for annotation {}", value, newPath, ServerEndpoint.class.getName()); + } + super.visit(name, value); + } + }; + } + return superAnnotationVisitor; + } + + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { + MethodVisitor superMethodVisitor = super.visitMethod(access, name, descriptor, signature, exceptions); + return new ServerEndpointProxyMethodVisitor(ASM_API_VERSION, superMethodVisitor, websocketClass, proxyClassName, pathParam); + } + + } + + /** + * 替换成员变量及方法所属 + */ + private static class ServerEndpointProxyMethodVisitor extends MethodVisitor { + + private final String classInternalName; + private final String proxyClassInternalName; + private final Map pathParam; + + private ServerEndpointProxyMethodVisitor(int api, MethodVisitor methodVisitor, Class websocketClass, String proxyClassName, Map pathParam) { + super(api, methodVisitor); + String classInternalName = Type.getInternalName(websocketClass); + this.proxyClassInternalName = proxyClassName.replaceAll("\\.", "/"); + this.classInternalName = classInternalName; + this.pathParam = pathParam; + } + + @Override + public void visitFieldInsn(int opcode, String owner, String name, String descriptor) { + // 替换成员变量所属 + if (owner.equals(classInternalName)) { + super.visitFieldInsn(opcode, proxyClassInternalName, name, descriptor); + log.debug("Changed owner from {} to {} for field {}", classInternalName, proxyClassInternalName, name); + } else { + super.visitFieldInsn(opcode, owner, name, descriptor); + } + } + + @Override + public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) { + // 替换方法所属 + if (owner.equals(classInternalName)) { + super.visitMethodInsn(opcode, proxyClassInternalName, name, descriptor, isInterface); + log.debug("Changed owner from {} to {} in method {}", classInternalName, proxyClassInternalName, name); + } else { + super.visitMethodInsn(opcode, owner, name, descriptor, isInterface); + } + } + + @Override + public AnnotationVisitor visitParameterAnnotation(int parameter, String descriptor, boolean visible) { + AnnotationVisitor superAnnotationVisitor = super.visitParameterAnnotation(parameter, descriptor, visible); + // 替换@PathParam注解中value的值 + if (descriptor.equals(Type.getDescriptor(PathParam.class))) { + return new ServerEndpointProxyParameterAnnotationVisitor(ASM_API_VERSION, superAnnotationVisitor, pathParam); + } + return super.visitParameterAnnotation(parameter, descriptor, visible); + } + + } + + /** + * 将@PathParam("name")替换为@PathParam("0"),使其与uri对应 + */ + private static class ServerEndpointProxyParameterAnnotationVisitor extends AnnotationVisitor { + + private final Map pathParam; + + private ServerEndpointProxyParameterAnnotationVisitor(int api, AnnotationVisitor annotationVisitor, Map pathParam) { + super(api, annotationVisitor); + this.pathParam = pathParam; + } + + @Override + public void visit(String name, Object value) { + if ("value".equals(name)) { + Integer index = pathParam.get(String.valueOf(value)); + if (index != null) { + value = String.valueOf(index); + log.debug("Changed path parameter {} to {}", value, index); + } + } + super.visit(name, value); + } + } + + /** + * websocket路径解析类,主要用于处理参数 + */ + private static class UriTemplate { + + private final Map paramMap = new ConcurrentHashMap<>(); + private final String path; + + private UriTemplate(String path) throws DeploymentException { + if (path == null || path.length() == 0 || !path.startsWith("/") || path.contains("/../") || path.contains("/./") || path.contains("//")) { + throw new DeploymentException(String.format("The path [%s] is not valid.", path)); + } + StringBuilder normalized = new StringBuilder(path.length()); + Set paramNames = new HashSet<>(); + + // Include empty segments. + String[] segments = path.split("/", -1); + int paramCount = 0; + + for (int i = 0; i < segments.length; i++) { + String segment = segments[i]; + if (segment.length() == 0) { + if (i == 0 || (i == segments.length - 1 && paramCount == 0)) { + // Ignore the first empty segment as the path must always + // start with '/' + // Ending with a '/' is also OK for instances used for + // matches but not for parameterised templates. + continue; + } else { + // As per EG discussion, all other empty segments are + // invalid + throw new DeploymentException(String.format("The path [%s] contains one or more empty segments which is not permitted", path)); + } + } + normalized.append('/'); + if (segment.startsWith("{") && segment.endsWith("}")) { + segment = segment.substring(1, segment.length() - 1); + normalized.append('{'); + normalized.append(paramCount++); + normalized.append('}'); + if (!paramNames.add(segment)) { + throw new DeploymentException(String.format("The parameter [%s] appears more than once in the path which is not permitted", segment)); + } + paramMap.put(segment, paramCount - 1); + } else { + if (segment.contains("{") || segment.contains("}")) { + throw new DeploymentException(String.format("The segment [%s] is not valid in the provided path [%s]", segment, path)); + } + normalized.append(segment); + } + } + this.path = normalized.toString(); + } + + public String getPath() { + return path; + } + + public Map getParamMap() { + return paramMap; + } + } + +} -- Gitee