04-websocket并发场景下发送消息出错的解决方案

04-websocket并发场景下发送消息出错的解决方案

前言:上一篇文章,主要演示了websocket并发场景下发送消息出错的问题,本文首先填上一篇的埋下的坑,也会给出解决方案

1 填坑-为什么调用的是 getBasicRemote().sendText方法

1.1 getBasicRemote().sendText 与 getAsyncRemote().sendText

上一篇提到,spring封装的websocket在发送消息的时候,调用的是javax.websocket的getBasicRemote().sendText方法,但是javax.websocket是支持异步的,因为它提供了异步发送消息的方法

...
    //org.apache.tomcat.websocket;
    //WsSession类
@Override
    public RemoteEndpoint.Async getAsyncRemote() {
        checkState();
        return remoteEndpointAsync;
    }
...
  	//org.apache.tomcat.websocket;
    //WsRemoteEndpointAsync类
    @Override
    public Future<Void> sendText(String text) {
        return base.sendStringByFuture(text);
    }

...
    //代码示例
    session.getAsyncRemote().sendText(content);
                  

那为什么spring内部在封装websocket的时候,没有是使用这个异步调用的方法呢?而是采用了基于同步发送的方法。如果采用的是基于异步调用的方法,可以避免并发出错的问题吗?

首先,javax.websocket提供的异步调用的方法,是有特定的条件的,特定条件就是:如果某一时刻来的100消息,他们对应的客户端都是不用一样的,也就是一条消息对应一个特定的客户端,那么在这种并发情况下,使用这个异步调用的方法是没有任何问题的,因为不同的客户端对应的session都不一样,里面的状态机也不一样,都是独立的,不会互相影响,没有共享数据,自然也就没有并发安全问题。

但是 ,但是,在实际的业务中,来了100条消息,很有可能其中某几条,比如10条是发给同一个客户端的,这个时候就涉及到共享同一个session的状态了,就会出现并发场景下,发送消息出错,也就是说,有的消息,能正常 发送出去,客户端能接收到消息,但是有的客户端收不到消息,或者只收到了几条消息,没有收到全部的消息。显然,这就是问题所在

结论:由于实际的业务场景中,一条消息并不能对应一个唯一的客户端(一对一),并且这种对应关系是十分复杂的,是很容易出错的(除非把所有的消息按照对应的客户端分组,每个分组一个线程,不同的分组不同的线程…没错,这已经是一种解决方案了!!!),所以虽然javax.websocket本身提供了异步发送消息的方法,但是spring并没有采用,而是采用的同步调用的方法。

另外一个问题:虽然spring用的是同步调用的方法,但是还是会出现并发发送消息出错的情况,为什么呢?因为这个同步调用并没有考虑并发,也就是没有使用synchronized等手段来保证并发同步,所以,即使是javax.websocket的同步调用方法,在并发场景下,还是会出错。

2 解决方案

2.1 synchronized锁住发送消息的方法

这种解决方案效果是很明显的,道理也是显而易见的,只是这样的话,全局的消息都会阻塞住,某些场景下性能会十分槽糕。需要结合实际的业务场景考虑,是否适合采用这种解决方案

spring框架下使用websocket

... 
if (session.isOpen()) {
                synchronized (session){
                    try {
                        session.sendMessage(new TextMessage(message));
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                }
            }
...

javax.websocket

 public synchronized void sendMessage(String id, String content) {
        Session session = (Session) clients.get(id);
        if (session == null) {
            log.error("服务端给客户端发送消息失败 ==> toid = {} 不存在, content = {}", id, content);
        } else {

            try {
                session.getBasicRemote().sendText(content);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }

            log.info("服务端给客户端发送消息 ==> toid = {}, content = {}", id, content);
        }
    }

2.2 spring websocketSession的加强版

org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator

这个目前是最优解,兼顾并发和性能。不过仅限于spring框架下的websocket,如果是使用javax下的websocket,需要自己按照spring的解决方案封装一些东西(把spring的源码复制过来,简化即可使用),可以参考我的代码

使用方法就是:用ConcurrentWebSocketSessionDecorator替换websocketSession

@Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // 当WebSocket连接建立成功时调用

        // 获取url指定的参数信息:ws://127.0.0.1:10010/ws?scanPoint=01&userId=123
        //String scanPoint = extractParams(session, "scanPoint");
        //if (Objects.isNull(scanPoint)) {
        //    return;
        //}
        
      	//原来是WebSocketSession,换成ConcurrentWebSocketSessionDecorator
        //sessions.put(scanPoint, session);
        //ConcurrentWebSocketSessionDecorator的构造方法后面2个参数的含义是:发送超时时间限制,发送内容大小的限制
        //看到这个你就会发现,不愧是spring啊,强的一批
        sessions.put(scanPoint, new ConcurrentWebSocketSessionDecorator(session, 10*1000, 10*1024));
      	//...

    }

这里先埋个坑,后面打算写一篇文章来解析spring的ConcurrentWebSocketSessionDecorator

javax.websocket

(用到了spring的其他依赖,Component的什么的)

下面是我把spring源码复制过来,简化后的代码。sendMessage是对外提供的发送消息的方法

import java.io.IOException;
import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

@ServerEndpoint("/ws/{id}")
@Component
public class MyWebsocketDecorator {


    private static final Logger log = LoggerFactory.getLogger(MyWebsocketDecorator.class);
    private static final AtomicInteger onlineClientCount = new AtomicInteger(0);
    private static final ConcurrentMap<String, Session> clients = new ConcurrentHashMap<>();

    private String id;
    private Session session;
    //=======================================================
    /**
     * 发送时间限制:毫秒
     */
    private final int sendTimeLimit = 10 * 1000;

    /**
     * websocket的缓存消息队列:所有的消息全部先放入这里:我这里写3000足够了
     */
    private final Queue<String> buffer = new LinkedBlockingQueue<>(3000);
    /**
     * 开始发送的时间戳
     */
    private volatile long sendStartTime;
    /**
     * 是否达到限制条件
     */
    private volatile boolean limitExceeded;
    /**
     * 关闭websocket
     */
    private volatile boolean closeInProgress;
    /**
     * 发送消息时需要获取的锁
     */
    private final Lock flushLock = new ReentrantLock();
    /**
     * 检查websocket状态时需要获取的锁
     */
    private final Lock closeLock = new ReentrantLock();


    @OnOpen
    public void open(@PathParam("id") String id, Session session) {
        this.id = id;
        this.session = session;
        clients.put(id, session);
        onlineClientCount.incrementAndGet();

        log.info("连接建立成功,当前在线数为:{} ==> 开始监听新连接:session_id = {}, id = {}", new Object[]{onlineClientCount, session.getId(), id});
    }

    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("服务端接收到客户端消息 ==> id = {}, content = {}", this.id, message);
    }

    @OnClose
    public void close(@PathParam("id") String id, Session session) {
        clients.remove(id);
        onlineClientCount.decrementAndGet();
        log.info("连接关闭成功,当前在线数为:{} ==> 关闭该连接信息:session_id = {}, id = {}", new Object[]{onlineClientCount, session.getId(), id});
    }
	
    	
       public void sendMessage(String id, String message) throws IOException {
        //把id和message拼接在一起,发消息的时候,在拆开
        this.buffer.add(id + "|" + message);

        do {
            if (!tryFlushMessageBuffer()) {
                //if (logger.isTraceEnabled()) {
                //    logger.trace(String.format("Another send already in progress: " +
                //                    "session id '%s':, "in-progress" send time %d (ms), buffer size %d bytes",
                //            getId(), getTimeSinceSendStarted(), getBufferSize()));
                //}
                log.info("================>有线程正在发送消息,当前线程检查是否超时!");
                checkSessionLimits();
                break;
            }
        }
        while (!this.buffer.isEmpty());
    }
    

    public void sendSynchronize(String id, String content) {
        Session session = (Session) clients.get(id);
        if (session == null) {
            log.error("服务端给客户端发送消息失败 ==> toid = {} 不存在, content = {}", id, content);
        } else {
            try {
                session.getBasicRemote().sendText(content);
            } catch (Exception e) {
                log.info("异常信息:{}", e.getMessage());
                log.error("服务端给客户端发送消息失败 ==> toid = {}, content = {}", id, content);
            }
            log.info("服务端给客户端发送消息 ==> toid = {}, content = {}", id, content);
        }
    }

    public void sendSynchronize(String content) {
        clients.forEach((onlineid, session) -> {
            if (!this.id.equalsIgnoreCase(onlineid)) {
                try {
                    session.getBasicRemote().sendText(content);
                    log.info("服务端给客户端群发消息 ==> id = {}, toid = {}, content = {}", new Object[]{this.id, onlineid, content});
                } catch (Exception e) {
                    log.info("异常信息:{}", e.getMessage());
                    log.error("服务端给客户端发送消息失败 ==> toid = {}, content = {}", id, content);
                }
            }

        });
    }

 


    private boolean tryFlushMessageBuffer() throws IOException {
        if (this.flushLock.tryLock()) {
            try {

                while (true) {

                    String message = this.buffer.poll();
                    if (message == null) {
                        break;
                    }
                    this.sendStartTime = System.currentTimeMillis();
                    //发送消息
                    String[] split = message.split("\|");
                    String key = split[0];
                    if (split.length != 2) {
                        sendSynchronize(key);
                    } else {
                        sendSynchronize(key, split[1]);
                    }

                    this.sendStartTime = 0;

                }

            } finally {
                this.sendStartTime = 0;
                this.flushLock.unlock();
            }
            return true;
        }
        return false;
    }


    private void checkSessionLimits() {
        if (this.closeLock.tryLock()) {
            try {
                if (getTimeSinceSendStarted() > getSendTimeLimit()) {
                    //String format = "Send time %d (ms) for session '%s' exceeded the allowed limit %d";
                    //String reason = String.format(format, getTimeSinceSendStarted(), getId(), getSendTimeLimit());
                    //limitExceeded(reason);
                    //超时异常处理
                    throw new RuntimeException("ws消息超时");
                }

            } finally {
                this.closeLock.unlock();
            }
        }
    }

    public long getTimeSinceSendStarted() {
        long start = this.sendStartTime;
        return (start > 0 ? (System.currentTimeMillis() - start) : 0);
    }

    public int getSendTimeLimit() {
        return this.sendTimeLimit;
    }


}

这个是原始代码,没有引入spring的解决方案

import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

@ServerEndpoint("/ws/{id}")
@Component
public class WS {
    private static final Logger log = LoggerFactory.getLogger(WS.class);
    private static AtomicInteger onlineClientCount = new AtomicInteger(0);
    private static final ConcurrentMap<String, Session> clients = new ConcurrentHashMap<>();
    private String id;
    private Session session;

    public WS() {
    }

    @OnOpen
    public void open(@PathParam("id") String id, Session session) {
        this.id = id;
        this.session = session;
        clients.put(id, session);
        onlineClientCount.incrementAndGet();
        log.info("连接建立成功,当前在线数为:{} ==> 开始监听新连接:session_id = {}, id = {},。", new Object[]{onlineClientCount, session.getId(), id});
    }

    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("服务端接收到客户端消息 ==> id = {}, content = {}", this.id, message);
    }

    @OnClose
    public void close(@PathParam("id") String id, Session session) {
        clients.remove(id);
        onlineClientCount.decrementAndGet();
        log.info("连接关闭成功,当前在线数为:{} ==> 关闭该连接信息:session_id = {}, id = {},。", new Object[]{onlineClientCount, session.getId(), id});
    }

    public void send(String id, String content) {
        Session session = (Session)clients.get(id);
        if (session == null) {
            log.error("服务端给客户端发送消息 ==> toid = {} 不存在, content = {}", id, content);
        } else {
            session.getAsyncRemote().sendText(content);
            log.info("服务端给客户端发送消息 ==> toid = {}, content = {}", id, content);
        }
    }

    public void send(String content) {
        clients.forEach((onlineid, session) -> {
            if (!this.id.equalsIgnoreCase(onlineid)) {
                session.getAsyncRemote().sendText(content);
                log.info("服务端给客户端群发消息 ==> id = {}, toid = {}, content = {}", new Object[]{this.id, onlineid, content});
            }

        });
    }
}

2.3 其他博主的解决方案

刚才在上面已经提到了,不同的客户端不同的线程,已经有博主实现了,我就不造轮子了(后面有时间的话,看看能不能造一个,手动狗头)

这是文章地址:https://blog.csdn.net/qq_35634154/article/details/122576665

2.4 一些想法,还未实践

  1. 引入生产者和消费者队列,生产者只管往队列里面翻消息,消费者判断消息队列是否为空,不为空就取出消息,发送

  2. 引入事件驱动队列

    不过,感觉都有些过度解决问题了。。。

3 总结

本文主要给出了websocket的在并发场景下发送消息出错的几种解决方案,有些方案仅限于思路,并未实现


参考

https://blog.csdn.net/qq_35634154/article/details/122576665

https://blog.csdn.net/abu935009066/article/details/131218149