diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 10bfbe44d..1ad57d27a 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -63,13 +63,14 @@ func (l *Listener) Shutdown(ctx context.Context) error { } func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { + connRemoteAddr := remoteAddr(r) wsConn, err := websocket.Accept(w, r, nil) if err != nil { - log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err) + log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err) return } - rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) + rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr) if err != nil { err = wsConn.Close(websocket.StatusInternalError, "internal error") if err != nil { @@ -90,3 +91,10 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { conn := NewConn(wsConn, lAddr, rAddr) l.acceptFn(conn) } + +func remoteAddr(r *http.Request) string { + if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" { + return r.RemoteAddr + } + return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port")) +}