diff --git a/graffiti/websocket/client.go b/graffiti/websocket/client.go index 5192b4cbee5eaf9a9ab4b29a2d2cc34a7c04faf1..5b824cd08bb65004af82faac314fa6dad999e97b 100644 --- a/graffiti/websocket/client.go +++ b/graffiti/websocket/client.go @@ -199,3 +199,408 @@ func (d *DefaultSpeakerEventHandler) OnConnected(c Speaker) error { // OnDisconnected is called when the connection is closed or lost. func (d *DefaultSpeakerEventHandler) OnDisconnected(c Speaker) { } + +// GetHost returns the hostname/host-id of the connection. +func (c *Conn) GetHost() string { + return c.Host +} + +// GetAddrPort returns the address and the port of the remote end. +func (c *Conn) GetAddrPort() (string, int) { + return c.Addr, c.Port +} + +// GetURL returns the URL of the connection +func (c *Conn) GetURL() *url.URL { + return c.URL +} + +// IsConnected returns the connection status. +func (c *Conn) IsConnected() bool { + return c.State.Load() == service.RunningState +} + +// GetStatus returns the status of a WebSocket connection +func (c *Conn) GetStatus() ConnStatus { + c.RLock() + defer c.RUnlock() + + status := c.ConnStatus + status.State = new(ConnState) + *status.State = ConnState(c.State.Load()) + return c.ConnStatus +} + +// SpeakerStructMessageHandler interface used to receive Struct messages. +type SpeakerStructMessageHandler interface { + OnStructMessage(c Speaker, m *StructMessage) +} + +// SendMessage adds a message to sending queue. +func (c *Conn) SendMessage(m Message) error { + if !c.IsConnected() { + return errors.New("Not connected") + } + + b, err := m.Bytes(c.GetClientProtocol()) + if err != nil { + return err + } + + c.send <- b + + return nil +} + +// SendRaw adds raw bytes to sending queue. +func (c *Conn) SendRaw(b []byte) error { + if !c.IsConnected() { + return errors.New("Not connected") + } + + c.send <- b + + return nil +} + +// GetServiceType returns the client type. +func (c *Conn) GetServiceType() service.Type { + return c.ServiceType +} + +// GetClientProtocol returns the websocket protocol. +func (c *Conn) GetClientProtocol() Protocol { + return c.ClientProtocol +} + +// GetHeaders returns the client HTTP headers. +func (c *Conn) GetHeaders() http.Header { + return c.Headers +} + +// GetRemoteHost returns the hostname/host-id of the remote side of the connection. +func (c *Conn) GetRemoteHost() string { + return c.RemoteHost +} + +// GetRemoteServiceType returns the remote service type. +func (c *Conn) GetRemoteServiceType() service.Type { + return c.RemoteServiceType +} + +// write sends a message directly over the wire. +func (c *Conn) write(msg []byte) error { + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + c.conn.EnableWriteCompression(c.writeCompression) + w, err := c.conn.NextWriter(c.messageType) + if err != nil { + return err + } + + if _, err = w.Write(msg); err != nil { + return err + } + + return w.Close() +} + +// Run the main loop +func (c *Conn) Run() { + c.wg.Add(2) + c.run() +} + +// Start main loop in a goroutine +func (c *Conn) Start() { + c.wg.Add(2) + go c.run() +} + +func (c *Conn) cloneEventHandlers() (handlers []SpeakerEventHandler) { + c.RLock() + handlers = append(handlers, c.eventHandlers...) + c.RUnlock() + + return +} + +// main loop to read and send messages +func (c *Conn) run() { + read := make(chan []byte, c.queueSize) + + flushChannel := func(c chan []byte, cb func(msg []byte) error) error { + for { + select { + case m := <-c: + if len(m) == 0 { + break + } + if err := cb(m); err != nil { + return err + } + default: + return nil + } + } + } + + // notify all the listeners that a message was received + handleReceivedMessage := func(m []byte) error { + for _, l := range c.cloneEventHandlers() { + l.OnMessage(c.wsSpeaker, RawMessage(m)) + } + return nil + } + + var readWg sync.WaitGroup + readWg.Add(1) + + // goroutine to read messages from the socket and put them into a channel + go func() { + defer readWg.Done() + + for c.running.Load() == true { + _, m, err := c.conn.ReadMessage() + if err != nil { + if c.running.Load() != false { + c.quit <- true + } + break + } + read <- m + } + }() + + defer func() { + c.conn.Close() + c.State.Store(service.StoppedState) + + // handle all the pending received messages + readWg.Wait() + flushChannel(read, handleReceivedMessage) + close(read) + + for _, l := range c.cloneEventHandlers() { + l.OnDisconnected(c.wsSpeaker) + } + + c.wg.Done() + }() + + go func() { + defer c.wg.Done() + + for m := range read { + handleReceivedMessage(m) + } + }() + + for { + select { + case <-c.quit: + return + case m := <-c.send: + if err := c.write(m); err != nil { + c.logger.Errorf("Error while sending message to %+v: %s", c, err) + return + } + case <-c.flush: + if err := flushChannel(c.send, c.write); err != nil { + c.logger.Errorf("Error while flushing send queue for %+v: %s", c, err) + return + } + case <-c.pingTicker.C: + if err := c.sendPing(); err != nil { + c.logger.Errorf("Error while sending ping to %+v: %s", c, err) + + // stop the ticker and request a quit + c.pingTicker.Stop() + return + } + } + } +} + +// sendPing is used for remote connections by the server to send PingMessage +// to remote client. +func (c *Conn) sendPing() error { + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + return c.conn.WriteMessage(websocket.PingMessage, []byte{}) +} + +// AddEventHandler registers a new event handler +func (c *Conn) AddEventHandler(h SpeakerEventHandler) { + c.Lock() + c.eventHandlers = append(c.eventHandlers, h) + c.Unlock() +} + +// Connect default implementation doing nothing as for incoming connection it is not used. +func (c *Conn) Connect(context.Context) error { + return nil +} + +// Flush all the pending sent messages +func (c *Conn) Flush() { + c.flush <- struct{}{} +} + +// Stop disconnect the speaker +func (c *Conn) Stop() { + c.running.Store(false) + if c.State.CompareAndSwap(service.RunningState, service.StoppingState) { + c.quit <- true + } +} + +// StopAndWait disconnect the speaker and wait for the goroutine to end +func (c *Conn) StopAndWait() { + c.Stop() + c.wg.Wait() +} + +func newConn(host string, clientType service.Type, clientProtocol Protocol, url *url.URL, headers http.Header, opts ClientOpts) *Conn { + if headers == nil { + headers = http.Header{} + } + + if opts.QueueSize == 0 { + opts.QueueSize = defaultQueueSize + } + + port, _ := strconv.Atoi(url.Port()) + c := &Conn{ + ConnStatus: ConnStatus{ + Host: host, + ServiceType: clientType, + ClientProtocol: clientProtocol, + Addr: url.Hostname(), + Port: port, + State: new(ConnState), + URL: url, + Headers: headers, + ConnectTime: time.Now(), + }, + send: make(chan []byte, opts.QueueSize), + queueSize: opts.QueueSize, + flush: make(chan struct{}), + quit: make(chan bool, 2), + pingTicker: &time.Ticker{}, + writeCompression: opts.WriteCompression, + logger: opts.Logger, + } + + if clientProtocol == JSONProtocol { + c.messageType = websocket.TextMessage + } else { + c.messageType = websocket.BinaryMessage + } + + c.State.Store(service.StoppedState) + c.running.Store(true) + return c +} + +func (c *Client) scheme() string { + if c.TLSConfig != nil { + return "wss://" + } + return "ws://" +} + +// Connect to the server +func (c *Client) Connect(ctx context.Context) error { + var err error + endpoint := c.URL.String() + headers := http.Header{ + "X-Host-ID": {c.Host}, + "X-Client-Type": {c.ServiceType.String()}, + "X-Client-Protocol": {c.ClientProtocol.String()}, + "X-Websocket-Namespace": {WildcardNamespace}, + } + + for k, v := range c.Headers { + headers[k] = v + } + + c.Opts.Logger.Infof("Connecting to %s", endpoint) + + if c.AuthOpts != nil { + shttp.SetAuthHeaders(&headers, c.AuthOpts) + } + + d := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + d.TLSClientConfig = c.TLSConfig + + var resp *http.Response + c.conn, resp, err = d.DialContext(ctx, endpoint, headers) + if err != nil { + return fmt.Errorf("Unable to create a WebSocket connection %s : %s", endpoint, err) + } + + if c.RemoteHost = resp.Header.Get("X-Host-ID"); c.RemoteHost == "" { + c.conn.Close() + return errors.New("X-Host-ID header not provided") + } + + if c.RemoteServiceType = service.Type(resp.Header.Get("X-Service-Type")); c.RemoteServiceType == "" { + c.conn.Close() + return errors.New("X-Service-Type header not provided") + } + + c.conn.SetPingHandler(nil) + c.conn.EnableWriteCompression(c.writeCompression) + + c.State.Store(service.RunningState) + + c.Opts.Logger.Infof("Connected to %s", endpoint) + + // notify connected + for _, l := range c.cloneEventHandlers() { + if err := l.OnConnected(c); err != nil { + c.Stop() + return err + } + } + + return nil +} + +// Start connects to the server - and reconnect if necessary +func (c *Client) Start() { + go func() { + for c.running.Load() == true { + if err := c.Connect(context.Background()); err == nil { + c.Run() + if c.running.Load() == true { + c.wg.Wait() + } + } else { + c.Opts.Logger.Error(err) + } + time.Sleep(1 * time.Second) + } + }() +} + +// NewClient returns a Client with a new connection. +func NewClient(host string, clientType service.Type, url *url.URL, opts ClientOpts) *Client { + if opts.Logger == nil { + opts.Logger = logging.GetLogger() + } + + wsconn := newConn(host, clientType, opts.Protocol, url, opts.Headers, opts) + c := &Client{ + Conn: wsconn, + AuthOpts: opts.AuthOpts, + TLSConfig: opts.TLSConfig, + Opts: opts, + } + + wsconn.wsSpeaker = c + return c +}