re-added websockify

This commit is contained in:
olevole
2023-02-26 23:47:46 +03:00
parent 28f28772df
commit d85b40fc35
16 changed files with 3188 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from websockify.websocket import *
from websockify.websocketproxy import *

View File

@@ -0,0 +1,4 @@
import websockify
if __name__ == '__main__':
websockify.websocketproxy.websockify_init()

View File

@@ -0,0 +1,102 @@
class BasePlugin():
def __init__(self, src=None):
self.source = src
def authenticate(self, headers, target_host, target_port):
pass
class AuthenticationError(Exception):
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
self.code = response_code
self.headers = response_headers
self.msg = response_msg
if log_msg is None:
log_msg = response_msg
super().__init__('%s %s' % (self.code, log_msg))
class InvalidOriginError(AuthenticationError):
def __init__(self, expected, actual):
self.expected_origin = expected
self.actual_origin = actual
super().__init__(
response_msg='Invalid Origin',
log_msg="Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
class BasicHTTPAuth():
"""Verifies Basic Auth headers. Specify src as username:password"""
def __init__(self, src=None):
self.src = src
def authenticate(self, headers, target_host, target_port):
import base64
auth_header = headers.get('Authorization')
if auth_header:
if not auth_header.startswith('Basic '):
self.auth_error()
try:
user_pass_raw = base64.b64decode(auth_header[6:])
except TypeError:
self.auth_error()
try:
# http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication
user_pass_as_text = user_pass_raw.decode('ISO-8859-1')
except UnicodeDecodeError:
self.auth_error()
user_pass = user_pass_as_text.split(':', 1)
if len(user_pass) != 2:
self.auth_error()
if not self.validate_creds(*user_pass):
self.demand_auth()
else:
self.demand_auth()
def validate_creds(self, username, password):
if '%s:%s' % (username, password) == self.src:
return True
else:
return False
def auth_error(self):
raise AuthenticationError(response_code=403)
def demand_auth(self):
raise AuthenticationError(response_code=401,
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
class ExpectOrigin():
def __init__(self, src=None):
if src is None:
self.source = []
else:
self.source = src.split()
def authenticate(self, headers, target_host, target_port):
origin = headers.get('Origin', None)
if origin is None or origin not in self.source:
raise InvalidOriginError(expected=self.source, actual=origin)
class ClientCertCNAuth():
"""Verifies client by SSL certificate. Specify src as whitespace separated list of common names."""
def __init__(self, src=None):
if src is None:
self.source = []
else:
self.source = src.split()
def authenticate(self, headers, target_host, target_port):
if headers.get('SSL_CLIENT_S_DN_CN', None) not in self.source:
raise AuthenticationError(response_code=403)

View File

@@ -0,0 +1,118 @@
import logging.handlers as handlers, socket, os, time
class WebsockifySysLogHandler(handlers.SysLogHandler):
"""
A handler class that sends proper Syslog-formatted messages,
as defined by RFC 5424.
"""
_legacy_head_fmt = '<{pri}>{ident}[{pid}]: '
_rfc5424_head_fmt = '<{pri}>1 {timestamp} {hostname} {ident} {pid} - - '
_head_fmt = _rfc5424_head_fmt
_legacy = False
_timestamp_fmt = '%Y-%m-%dT%H:%M:%SZ'
_max_hostname = 255
_max_ident = 24 #safer for old daemons
_send_length = False
_tail = '\n'
ident = None
def __init__(self, address=('localhost', handlers.SYSLOG_UDP_PORT),
facility=handlers.SysLogHandler.LOG_USER,
socktype=None, ident=None, legacy=False):
"""
Initialize a handler.
If address is specified as a string, a UNIX socket is used. To log to a
local syslogd, "WebsockifySysLogHandler(address="/dev/log")" can be
used. If facility is not specified, LOG_USER is used. If socktype is
specified as socket.SOCK_DGRAM or socket.SOCK_STREAM, that specific
socket type will be used. For Unix sockets, you can also specify a
socktype of None, in which case socket.SOCK_DGRAM will be used, falling
back to socket.SOCK_STREAM. If ident is specified, this string will be
used as the application name in all messages sent. Set legacy to True
to use the old version of the protocol.
"""
self.ident = ident
if legacy:
self._legacy = True
self._head_fmt = self._legacy_head_fmt
super().__init__(address, facility, socktype)
def emit(self, record):
"""
Emit a record.
The record is formatted, and then sent to the syslog server. If
exception information is present, it is NOT sent to the server.
"""
try:
# Gather info.
text = self.format(record).replace(self._tail, ' ')
if not text: # nothing to log
return
pri = self.encodePriority(self.facility,
self.mapPriority(record.levelname))
timestamp = time.strftime(self._timestamp_fmt, time.gmtime());
hostname = socket.gethostname()[:self._max_hostname]
if self.ident:
ident = self.ident[:self._max_ident]
else:
ident = ''
pid = os.getpid() # shouldn't need truncation
# Format the header.
head = {
'pri': pri,
'timestamp': timestamp,
'hostname': hostname,
'ident': ident,
'pid': pid,
}
msg = self._head_fmt.format(**head).encode('ascii', 'ignore')
# Encode text as plain ASCII if possible, else use UTF-8 with BOM.
try:
msg += text.encode('ascii')
except UnicodeEncodeError:
msg += text.encode('utf-8-sig')
# Add length or tail character, if necessary.
if self.socktype != socket.SOCK_DGRAM:
if self._send_length:
msg = ('%d ' % len(msg)).encode('ascii') + msg
else:
msg += self._tail.encode('ascii')
# Send the message.
if self.unixsocket:
try:
self.socket.send(msg)
except socket.error:
self._connect_unixsocket(self.address)
self.socket.send(msg)
else:
if self.socktype == socket.SOCK_DGRAM:
self.socket.sendto(msg, self.address)
else:
self.socket.sendall(msg)
except (KeyboardInterrupt, SystemExit):
raise
except:
self.handleError(record)

View File

@@ -0,0 +1,316 @@
import logging
import os
import sys
import time
import re
import json
logger = logging.getLogger(__name__)
class BasePlugin():
def __init__(self, src):
self.source = src
def lookup(self, token):
return None
class ReadOnlyTokenFile(BasePlugin):
# source is a token file with lines like
# token: host:port
# or a directory of such files
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._targets = None
def _load_targets(self):
if os.path.isdir(self.source):
cfg_files = [os.path.join(self.source, f) for
f in os.listdir(self.source)]
else:
cfg_files = [self.source]
self._targets = {}
index = 1
for f in cfg_files:
for line in [l.strip() for l in open(f).readlines()]:
if line and not line.startswith('#'):
try:
tok, target = re.split(':\s', line)
self._targets[tok] = target.strip().rsplit(':', 1)
except ValueError:
logger.error("Syntax error in %s on line %d" % (self.source, index))
index += 1
def lookup(self, token):
if self._targets is None:
self._load_targets()
if token in self._targets:
return self._targets[token]
else:
return None
# the above one is probably more efficient, but this one is
# more backwards compatible (although in most cases
# ReadOnlyTokenFile should suffice)
class TokenFile(ReadOnlyTokenFile):
# source is a token file with lines like
# token: host:port
# or a directory of such files
def lookup(self, token):
self._load_targets()
return super().lookup(token)
class BaseTokenAPI(BasePlugin):
# source is a url with a '%s' in it where the token
# should go
# we import things on demand so that other plugins
# in this file can be used w/o unnecessary dependencies
def process_result(self, resp):
host, port = resp.text.split(':')
port = port.encode('ascii','ignore')
return [ host, port ]
def lookup(self, token):
import requests
resp = requests.get(self.source % token)
if resp.ok:
return self.process_result(resp)
else:
return None
class JSONTokenApi(BaseTokenAPI):
# source is a url with a '%s' in it where the token
# should go
def process_result(self, resp):
resp_json = resp.json()
return (resp_json['host'], resp_json['port'])
class JWTTokenApi(BasePlugin):
# source is a JWT-token, with hostname and port included
# Both JWS as JWE tokens are accepted. With regards to JWE tokens, the key is re-used for both validation and decryption.
def lookup(self, token):
try:
from jwcrypto import jwt, jwk
import json
key = jwk.JWK()
try:
with open(self.source, 'rb') as key_file:
key_data = key_file.read()
except Exception as e:
logger.error("Error loading key file: %s" % str(e))
return None
try:
key.import_from_pem(key_data)
except:
try:
key.import_key(k=key_data.decode('utf-8'),kty='oct')
except:
logger.error('Failed to correctly parse key data!')
return None
try:
token = jwt.JWT(key=key, jwt=token)
parsed_header = json.loads(token.header)
if 'enc' in parsed_header:
# Token is encrypted, so we need to decrypt by passing the claims to a new instance
token = jwt.JWT(key=key, jwt=token.claims)
parsed = json.loads(token.claims)
if 'nbf' in parsed:
# Not Before is present, so we need to check it
if time.time() < parsed['nbf']:
logger.warning('Token can not be used yet!')
return None
if 'exp' in parsed:
# Expiration time is present, so we need to check it
if time.time() > parsed['exp']:
logger.warning('Token has expired!')
return None
return (parsed['host'], parsed['port'])
except Exception as e:
logger.error("Failed to parse token: %s" % str(e))
return None
except ImportError:
logger.error("package jwcrypto not found, are you sure you've installed it correctly?")
return None
class TokenRedis(BasePlugin):
"""Token plugin based on the Redis in-memory data store.
The token source is in the format:
host[:port[:db[:password]]]
where port, db and password are optional. If port or db are left empty
they will take its default value, ie. 6379 and 0 respectively.
If your redis server is using the default port (6379) then you can use:
my-redis-host
In case you need to authenticate with the redis server and you are using
the default database and port you can use:
my-redis-host:::verysecretpass
In the more general case you will use:
my-redis-host:6380:1:verysecretpass
The TokenRedis plugin expects the format of the target in one of these two
formats:
- JSON
{"host": "target-host:target-port"}
- Plain text
target-host:target-port
Prepare data with:
redis-cli set my-token '{"host": "127.0.0.1:5000"}'
Verify with:
redis-cli --raw get my-token
Spawn a test "server" using netcat
nc -l 5000 -v
Note: This Token Plugin depends on the 'redis' module, so you have
to install it before using this plugin:
pip install redis
"""
def __init__(self, src):
try:
import redis
except ImportError:
logger.error("Unable to load redis module")
sys.exit()
# Default values
self._port = 6379
self._db = 0
self._password = None
try:
fields = src.split(":")
if len(fields) == 1:
self._server = fields[0]
elif len(fields) == 2:
self._server, self._port = fields
if not self._port:
self._port = 6379
elif len(fields) == 3:
self._server, self._port, self._db = fields
if not self._port:
self._port = 6379
if not self._db:
self._db = 0
elif len(fields) == 4:
self._server, self._port, self._db, self._password = fields
if not self._port:
self._port = 6379
if not self._db:
self._db = 0
if not self._password:
self._password = None
else:
raise ValueError
self._port = int(self._port)
self._db = int(self._db)
logger.info("TokenRedis backend initilized (%s:%s)" %
(self._server, self._port))
except ValueError:
logger.error("The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>]]]" %
src)
sys.exit()
def lookup(self, token):
try:
import redis
except ImportError:
logger.error("package redis not found, are you sure you've installed them correctly?")
sys.exit()
logger.info("resolving token '%s'" % token)
client = redis.Redis(host=self._server, port=self._port,
db=self._db, password=self._password)
stuff = client.get(token)
if stuff is None:
return None
else:
responseStr = stuff.decode("utf-8").strip()
logger.debug("response from redis : %s" % responseStr)
if responseStr.startswith("{"):
try:
combo = json.loads(responseStr)
host, port = combo["host"].split(":")
except ValueError:
logger.error("Unable to decode JSON token: %s" %
responseStr)
return None
except KeyError:
logger.error("Unable to find 'host' key in JSON token: %s" %
responseStr)
return None
elif re.match(r'\S+:\S+', responseStr):
host, port = responseStr.split(":")
else:
logger.error("Unable to parse token: %s" % responseStr)
return None
logger.debug("host: %s, port: %s" % (host, port))
return [host, port]
class UnixDomainSocketDirectory(BasePlugin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._dir_path = os.path.abspath(self.source)
def lookup(self, token):
try:
import stat
if not os.path.isdir(self._dir_path):
return None
uds_path = os.path.abspath(os.path.join(self._dir_path, token))
if not uds_path.startswith(self._dir_path):
return None
if not os.path.exists(uds_path):
return None
if not stat.S_ISSOCK(os.stat(uds_path).st_mode):
return None
return [ 'unix_socket', uds_path ]
except Exception as e:
logger.error("Error finding unix domain socket: %s" % str(e))
return None

View File

@@ -0,0 +1,874 @@
#!/usr/bin/env python
'''
Python WebSocket library
Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
Supports following protocol versions:
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
- http://tools.ietf.org/html/rfc6455
'''
import sys
import array
import email
import errno
import random
import socket
import ssl
import struct
from base64 import b64encode
from hashlib import sha1
from urllib.parse import urlparse
try:
import numpy
except ImportError:
import warnings
warnings.warn("no 'numpy' module, HyBi protocol will be slower")
numpy = None
class WebSocketWantReadError(ssl.SSLWantReadError):
pass
class WebSocketWantWriteError(ssl.SSLWantWriteError):
pass
class WebSocket(object):
"""WebSocket protocol socket like class.
This provides access to the WebSocket protocol by behaving much
like a real socket would. It shares many similarities with
ssl.SSLSocket.
The WebSocket protocols requires extra data to be sent and received
compared to the application level data. This means that a socket
that is ready to be read may not hold enough data to decode any
application data, and a socket that is ready to be written to may
not have enough space for an entire WebSocket frame. This is
handled by the exceptions WebSocketWantReadError and
WebSocketWantWriteError. When these are raised the caller must wait
for the socket to become ready again and call the relevant function
again.
A connection is established by using either connect() or accept(),
depending on if a client or server session is desired. See the
respective functions for details.
The following methods are passed on to the underlying socket:
- fileno
- getpeername, getsockname
- getsockopt, setsockopt
- gettimeout, settimeout
- setblocking
"""
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
def __init__(self):
"""Creates an unconnected WebSocket"""
self._state = "new"
self._partial_msg = b''
self._recv_buffer = b''
self._recv_queue = []
self._send_buffer = b''
self._previous_sendmsg = None
self._sent_close = False
self._received_close = False
self.close_code = None
self.close_reason = None
self.socket = None
def __getattr__(self, name):
# These methods are just redirected to the underlying socket
if name in ["fileno",
"getpeername", "getsockname",
"getsockopt", "setsockopt",
"gettimeout", "settimeout",
"setblocking"]:
assert self.socket is not None
return getattr(self.socket, name)
else:
raise AttributeError("%s instance has no attribute '%s'" %
(self.__class__.__name__, name))
def connect(self, uri, origin=None, protocols=[]):
"""Establishes a new connection to a WebSocket server.
This method connects to the host specified by uri and
negotiates a WebSocket connection. origin should be specified
in accordance with RFC 6454 if known. A list of valid
sub-protocols can be specified in the protocols argument.
The data will be sent in the clear if the "ws" scheme is used,
and encrypted if the "wss" scheme is used.
Both WebSocketWantReadError and WebSocketWantWriteError can be
raised whilst negotiating the connection. Repeated calls to
connect() must retain the same arguments.
"""
self.client = True;
uri = urlparse(uri)
port = uri.port
if uri.scheme in ("ws", "http"):
if not port:
port = 80
elif uri.scheme in ("wss", "https"):
if not port:
port = 443
else:
raise Exception("Unknown scheme '%s'" % uri.scheme)
# This is a state machine in order to handle
# WantRead/WantWrite events
if self._state == "new":
self.socket = socket.create_connection((uri.hostname, port))
if uri.scheme in ("wss", "https"):
self.socket = ssl.wrap_socket(self.socket)
self._state = "ssl_handshake"
else:
self._state = "headers"
if self._state == "ssl_handshake":
self.socket.do_handshake()
self._state = "headers"
if self._state == "headers":
self._key = ''
for i in range(16):
self._key += chr(random.randrange(256))
self._key = b64encode(self._key.encode("latin-1")).decode("ascii")
path = uri.path
if not path:
path = "/"
self.send_request("GET", path)
self.send_header("Host", uri.hostname)
self.send_header("Upgrade", "websocket")
self.send_header("Connection", "upgrade")
self.send_header("Sec-WebSocket-Key", self._key)
self.send_header("Sec-WebSocket-Version", 13)
if origin is not None:
self.send_header("Origin", origin)
if len(protocols) > 0:
self.send_header("Sec-WebSocket-Protocol", ", ".join(protocols))
self.end_headers()
self._state = "send_headers"
if self._state == "send_headers":
self._flush()
self._state = "response"
if self._state == "response":
if not self._recv():
raise Exception("Socket closed unexpectedly")
if self._recv_buffer.find(b'\r\n\r\n') == -1:
raise WebSocketWantReadError
(request, self._recv_buffer) = self._recv_buffer.split(b'\r\n', 1)
request = request.decode("latin-1")
words = request.split()
if (len(words) < 2) or (words[0] != "HTTP/1.1"):
raise Exception("Invalid response")
if words[1] != "101":
raise Exception("WebSocket request denied: %s" % " ".join(words[1:]))
(headers, self._recv_buffer) = self._recv_buffer.split(b'\r\n\r\n', 1)
headers = headers.decode('latin-1') + '\r\n'
headers = email.message_from_string(headers)
if headers.get("Upgrade", "").lower() != "websocket":
print(type(headers))
raise Exception("Missing or incorrect upgrade header")
accept = headers.get('Sec-WebSocket-Accept')
if accept is None:
raise Exception("Missing Sec-WebSocket-Accept header");
expected = sha1((self._key + self.GUID).encode("ascii")).digest()
expected = b64encode(expected).decode("ascii")
del self._key
if accept != expected:
raise Exception("Invalid Sec-WebSocket-Accept header");
self.protocol = headers.get('Sec-WebSocket-Protocol')
if len(protocols) == 0:
if self.protocol is not None:
raise Exception("Unexpected Sec-WebSocket-Protocol header")
else:
if self.protocol not in protocols:
raise Exception("Invalid protocol chosen by server")
self._state = "done"
return
raise Exception("WebSocket is in an invalid state")
def accept(self, socket, headers):
"""Establishes a new WebSocket session with a client.
This method negotiates a WebSocket connection with an incoming
client. The caller must provide the client socket and the
headers from the HTTP request.
A server can identify that a client is requesting a WebSocket
connection by looking at the "Upgrade" header. It will include
the value "websocket" in such cases.
WebSocketWantWriteError can be raised if the response cannot be
sent right away. accept() must be called again once more space
is available using the same arguments.
"""
# This is a state machine in order to handle
# WantRead/WantWrite events
if self._state == "new":
self.client = False
self.socket = socket
if headers.get("upgrade", "").lower() != "websocket":
raise Exception("Missing or incorrect upgrade header")
ver = headers.get('Sec-WebSocket-Version')
if ver is None:
raise Exception("Missing Sec-WebSocket-Version header");
# HyBi-07 report version 7
# HyBi-08 - HyBi-12 report version 8
# HyBi-13 reports version 13
if ver in ['7', '8', '13']:
self.version = "hybi-%02d" % int(ver)
else:
raise Exception("Unsupported protocol version %s" % ver)
key = headers.get('Sec-WebSocket-Key')
if key is None:
raise Exception("Missing Sec-WebSocket-Key header");
# Generate the hash value for the accept header
accept = sha1((key + self.GUID).encode("ascii")).digest()
accept = b64encode(accept).decode("ascii")
self.protocol = ''
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
if protocols:
self.protocol = self.select_subprotocol(protocols)
# We are required to choose one of the protocols
# presented by the client
if self.protocol not in protocols:
raise Exception('Invalid protocol selected')
self.send_response(101, "Switching Protocols")
self.send_header("Upgrade", "websocket")
self.send_header("Connection", "Upgrade")
self.send_header("Sec-WebSocket-Accept", accept)
if self.protocol:
self.send_header("Sec-WebSocket-Protocol", self.protocol)
self.end_headers()
self._state = "flush"
if self._state == "flush":
self._flush()
self._state = "done"
return
raise Exception("WebSocket is in an invalid state")
def select_subprotocol(self, protocols):
"""Returns which sub-protocol should be used.
This method does not select any sub-protocol by default and is
meant to be overridden by an implementation that wishes to make
use of sub-protocols. It will be called during handling of
accept().
"""
return ""
def handle_ping(self, data):
"""Called when a WebSocket ping message is received.
This will be called whilst processing recv()/recvmsg(). The
default implementation sends a pong reply back."""
self.pong(data)
def handle_pong(self, data):
"""Called when a WebSocket pong message is received.
This will be called whilst processing recv()/recvmsg(). The
default implementation does nothing."""
pass
def recv(self):
"""Read data from the WebSocket.
This will return any available data on the socket (which may
be the empty string if the peer sent an empty message or
messages). If the socket is closed then None will be
returned. The reason for the close is found in the
'close_code' and 'close_reason' properties.
Unlike recvmsg() this method may return data from more than one
WebSocket message. It is however not guaranteed to return all
buffered data. Callers should continue calling recv() whilst
pending() returns True.
Both WebSocketWantReadError and WebSocketWantWriteError can be
raised when calling recv().
"""
return self.recvmsg()
def recvmsg(self):
"""Read a single message from the WebSocket.
This will return a single WebSocket message from the socket
(which will be the empty string if the peer sent an empty
message). If the socket is closed then None will be
returned. The reason for the close is found in the
'close_code' and 'close_reason' properties.
Unlike recv() this method will not return data from more than
one WebSocket message. Callers should continue calling
recvmsg() whilst pending() returns True.
Both WebSocketWantReadError and WebSocketWantWriteError can be
raised when calling recvmsg().
"""
# May have been called to flush out a close
if self._received_close:
self._flush()
return None
# Anything already queued?
if self.pending():
return self._recvmsg()
# Note: If self._recvmsg() raised WebSocketWantReadError,
# we cannot proceed to self._recv() here as we may
# have already called it once as part of the caller's
# "while websock.pending():" loop
# Nope, let's try to read a bit
if not self._recv_frames():
return None
# Anything queued now?
return self._recvmsg()
def pending(self):
"""Check if any WebSocket data is pending.
This method will return True as long as there are WebSocket
frames that have yet been processed. A single recv() from the
underlying socket may return multiple WebSocket frames and it
is therefore important that a caller continues calling recv()
or recvmsg() as long as pending() returns True.
Note that this function merely tells if there are raw WebSocket
frames pending. Those frames may not contain any application
data.
"""
return len(self._recv_queue) > 0
def send(self, bytes):
"""Write data to the WebSocket
This will queue the given data and attempt to send it to the
peer. Unlike sendmsg() this method might coalesce the data with
data from other calls, or split it over multiple messages.
WebSocketWantWriteError can be raised if there is insufficient
space in the underlying socket. send() must be called again
once more space is available using the same arguments.
"""
if len(bytes) == 0:
return 0
return self.sendmsg(bytes)
def sendmsg(self, msg):
"""Write a single message to the WebSocket
This will queue the given message and attempt to send it to the
peer. Unlike send() this method will preserve the data as a
single WebSocket message.
WebSocketWantWriteError can be raised if there is insufficient
space in the underlying socket. sendmsg() must be called again
once more space is available using the same arguments.
"""
if not isinstance(msg, bytes):
raise TypeError
if self._sent_close:
return 0
if self._previous_sendmsg is not None:
if self._previous_sendmsg != msg:
raise ValueError
self._flush()
self._previous_sendmsg = None
return len(msg)
try:
self._sendmsg(0x2, msg)
except WebSocketWantWriteError:
self._previous_sendmsg = msg
raise
return len(msg)
def send_response(self, code, message):
self._queue_str("HTTP/1.1 %d %s\r\n" % (code, message))
def send_header(self, keyword, value):
self._queue_str("%s: %s\r\n" % (keyword, value))
def end_headers(self):
self._queue_str("\r\n")
def send_request(self, type, path):
self._queue_str("%s %s HTTP/1.1\r\n" % (type.upper(), path))
def ping(self, data=b''):
"""Write a ping message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient
space in the underlying socket. ping() must be called again once
more space is available using the same arguments.
"""
if not isinstance(data, bytes):
raise TypeError
if self._previous_sendmsg is not None:
if self._previous_sendmsg != data:
raise ValueError
self._flush()
self._previous_sendmsg = None
return
try:
self._sendmsg(0x9, data)
except WebSocketWantWriteError:
self._previous_sendmsg = data
raise
def pong(self, data=b''):
"""Write a pong message to the WebSocket
WebSocketWantWriteError can be raised if there is insufficient
space in the underlying socket. pong() must be called again once
more space is available using the same arguments.
"""
if not isinstance(data, bytes):
raise TypeError
if self._previous_sendmsg is not None:
if self._previous_sendmsg != data:
raise ValueError
self._flush()
self._previous_sendmsg = None
return
try:
self._sendmsg(0xA, data)
except WebSocketWantWriteError:
self._previous_sendmsg = data
raise
def shutdown(self, how, code=1000, reason=None):
"""Gracefully terminate the WebSocket connection.
This will start the process to terminate the WebSocket
connection. The caller must continue to calling recv() or
recvmsg() after this function in order to wait for the peer to
acknowledge the close. Calls to send() and sendmsg() will be
ignored.
WebSocketWantWriteError can be raised if there is insufficient
space in the underlying socket for the close message. shutdown()
must be called again once more space is available using the same
arguments.
The how argument is currently ignored.
"""
# Already closing?
if self._sent_close:
self._flush()
return
# Special code to indicate that we closed the connection
if not self._received_close:
self.close_code = 1000
self.close_reason = "Locally initiated close"
self._sent_close = True
msg = b''
if code is not None:
msg += struct.pack(">H", code)
if reason is not None:
msg += reason.encode("UTF-8")
self._sendmsg(0x8, msg)
def close(self, code=1000, reason=None):
"""Terminate the WebSocket connection immediately.
This will close the WebSocket connection directly after sending
a close message to the peer.
WebSocketWantWriteError can be raised if there is insufficient
space in the underlying socket for the close message. close()
must be called again once more space is available using the same
arguments.
"""
self.shutdown(socket.SHUT_RDWR, code, reason)
self._close()
def _recv(self):
# Fetches more data from the socket to the buffer
assert self.socket is not None
while True:
try:
data = self.socket.recv(4096)
except OSError as exc:
if exc.errno == errno.EWOULDBLOCK:
raise WebSocketWantReadError
raise
if len(data) == 0:
return False
self._recv_buffer += data
# Support for SSLSocket like objects
if hasattr(self.socket, "pending"):
if not self.socket.pending():
break
else:
break
return True
def _recv_frames(self):
# Fetches more data and decodes the frames
if not self._recv():
if self.close_code is None:
self.close_code = 1006
self.close_reason = "Connection closed abnormally"
self._sent_close = self._received_close = True
self._close()
return False
while True:
frame = self._decode_hybi(self._recv_buffer)
if frame is None:
break
self._recv_buffer = self._recv_buffer[frame['length']:]
self._recv_queue.append(frame)
return True
def _recvmsg(self):
# Process pending frames and returns any application data
while self._recv_queue:
frame = self._recv_queue.pop(0)
if not self.client and not frame['masked']:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked")
continue
if self.client and frame['masked']:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked")
continue
if frame["opcode"] == 0x0:
if not self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected continuation frame")
continue
self._partial_msg += frame["payload"]
if frame["fin"]:
msg = self._partial_msg
self._partial_msg = b''
return msg
elif frame["opcode"] == 0x1:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported")
elif frame["opcode"] == 0x2:
if self._partial_msg:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame")
continue
if frame["fin"]:
return frame["payload"]
else:
self._partial_msg = frame["payload"]
elif frame["opcode"] == 0x8:
if self._received_close:
continue
self._received_close = True
if self._sent_close:
self._close()
return None
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close")
continue
code = None
reason = None
if len(frame["payload"]) >= 2:
code = struct.unpack(">H", frame["payload"][:2])[0]
if len(frame["payload"]) > 2:
reason = frame["payload"][2:]
try:
reason = reason.decode("UTF-8")
except UnicodeDecodeError:
self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Invalid UTF-8 in close")
continue
if code is None:
self.close_code = code = 1005
self.close_reason = "No close status code specified by peer"
else:
self.close_code = code
if reason is not None:
self.close_reason = reason
self.shutdown(None, code, reason)
return None
elif frame["opcode"] == 0x9:
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping")
continue
self.handle_ping(frame["payload"])
elif frame["opcode"] == 0xA:
if not frame["fin"]:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong")
continue
self.handle_pong(frame["payload"])
else:
self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Unknown opcode 0x%02x" % frame["opcode"])
raise WebSocketWantReadError
def _flush(self):
# Writes pending data to the socket
if not self._send_buffer:
return
assert self.socket is not None
try:
sent = self.socket.send(self._send_buffer)
except OSError as exc:
if exc.errno == errno.EWOULDBLOCK:
raise WebSocketWantWriteError
raise
self._send_buffer = self._send_buffer[sent:]
if self._send_buffer:
raise WebSocketWantWriteError
# We had a pending close and we've flushed the buffer,
# time to end things
if self._received_close and self._sent_close:
self._close()
def _send(self, data):
# Queues data and attempts to send it
self._send_buffer += data
self._flush()
def _queue_str(self, string):
# Queue some data to be sent later.
# Only used by the connecting methods.
self._send_buffer += string.encode("latin-1")
def _sendmsg(self, opcode, msg):
# Sends a standard data message
if self.client:
mask = b''
for i in range(4):
mask += random.randrange(256)
frame = self._encode_hybi(opcode, msg, mask)
else:
frame = self._encode_hybi(opcode, msg)
return self._send(frame)
def _close(self):
# Close the underlying socket
self.socket.close()
self.socket = None
def _mask(self, buf, mask):
# Mask a frame
return self._unmask(buf, mask)
def _unmask(self, buf, mask):
# Unmask a frame
if numpy:
plen = len(buf)
pstart = 0
pend = plen
b = c = b''
if plen >= 4:
dtype=numpy.dtype('<u4')
if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>')
mask = numpy.frombuffer(mask, dtype, count=1)
data = numpy.frombuffer(buf, dtype, count=int(plen / 4))
#b = numpy.bitwise_xor(data, mask).data
b = numpy.bitwise_xor(data, mask).tobytes()
if plen % 4:
dtype=numpy.dtype('B')
if sys.byteorder == 'big':
dtype = dtype.newbyteorder('>')
mask = numpy.frombuffer(mask, dtype, count=(plen % 4))
data = numpy.frombuffer(buf, dtype,
offset=plen - (plen % 4), count=(plen % 4))
c = numpy.bitwise_xor(data, mask).tobytes()
return b + c
else:
# Slower fallback
data = array.array('B')
data.frombytes(buf)
for i in range(len(data)):
data[i] ^= mask[i % 4]
return data.tobytes()
def _encode_hybi(self, opcode, buf, mask_key=None, fin=True):
""" Encode a HyBi style WebSocket frame.
Optional opcode:
0x0 - continuation
0x1 - text frame
0x2 - binary frame
0x8 - connection close
0x9 - ping
0xA - pong
"""
b1 = opcode & 0x0f
if fin:
b1 |= 0x80
mask_bit = 0
if mask_key is not None:
mask_bit = 0x80
buf = self._mask(buf, mask_key)
payload_len = len(buf)
if payload_len <= 125:
header = struct.pack('>BB', b1, payload_len | mask_bit)
elif payload_len > 125 and payload_len < 65536:
header = struct.pack('>BBH', b1, 126 | mask_bit, payload_len)
elif payload_len >= 65536:
header = struct.pack('>BBQ', b1, 127 | mask_bit, payload_len)
if mask_key is not None:
return header + mask_key + buf
else:
return header + buf
def _decode_hybi(self, buf):
""" Decode HyBi style WebSocket packets.
Returns:
{'fin' : boolean,
'opcode' : number,
'masked' : boolean,
'length' : encoded_length,
'payload' : decoded_buffer}
"""
f = {'fin' : 0,
'opcode' : 0,
'masked' : False,
'length' : 0,
'payload' : None}
blen = len(buf)
hlen = 2
if blen < hlen:
return None
b1, b2 = struct.unpack(">BB", buf[:2])
f['opcode'] = b1 & 0x0f
f['fin'] = not not (b1 & 0x80)
f['masked'] = not not (b2 & 0x80)
if f['masked']:
hlen += 4
if blen < hlen:
return None
length = b2 & 0x7f
if length == 126:
hlen += 2
if blen < hlen:
return None
length, = struct.unpack('>H', buf[2:4])
elif length == 127:
hlen += 8
if blen < hlen:
return None
length, = struct.unpack('>Q', buf[2:10])
f['length'] = hlen + length
if blen < f['length']:
return None
if f['masked']:
# unmask payload
mask_key = buf[hlen-4:hlen]
f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key)
else:
f['payload'] = buf[hlen:(hlen+length)]
return f

View File

@@ -0,0 +1,800 @@
#!/usr/bin/env python
'''
A WebSocket to TCP socket proxy with support for "wss://" encryption.
Copyright 2011 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat
from socketserver import ThreadingMixIn
from http.server import HTTPServer
import select
from websockify import websockifyserver
from websockify import auth_plugins as auth
from urllib.parse import parse_qs, urlparse
class ProxyRequestHandler(websockifyserver.WebSockifyRequestHandler):
buffer_size = 65536
traffic_legend = """
Traffic Legend:
} - Client receive
}. - Client receive partial
{ - Target receive
> - Target send
>. - Target send partial
< - Client send
<. - Client send partial
"""
def send_auth_error(self, ex):
self.send_response(ex.code, ex.msg)
self.send_header('Content-Type', 'text/html')
for name, val in ex.headers.items():
self.send_header(name, val)
self.end_headers()
def validate_connection(self):
if not self.server.token_plugin:
return
host, port = self.get_target(self.server.token_plugin)
if host == 'unix_socket':
self.server.unix_target = port
else:
self.server.target_host = host
self.server.target_port = port
def auth_connection(self):
if not self.server.auth_plugin:
return
try:
# get client certificate data
client_cert_data = self.request.getpeercert()
# extract subject information
client_cert_subject = client_cert_data['subject']
# flatten data structure
client_cert_subject = dict([x[0] for x in client_cert_subject])
# add common name to headers (apache +StdEnvVars style)
self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName']
except (TypeError, AttributeError, KeyError):
# not a SSL connection or client presented no certificate with valid data
pass
try:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
except auth.AuthenticationError:
ex = sys.exc_info()[1]
self.send_auth_error(ex)
raise
def new_websocket_client(self):
"""
Called after a new WebSocket connection has been established.
"""
# Checking for a token is done in validate_connection()
# Connect to the target
if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
elif self.server.unix_target:
msg = "connecting to unix socket: %s" % self.server.unix_target
else:
msg = "connecting to: %s:%s" % (
self.server.target_host, self.server.target_port)
if self.server.ssl_target:
msg += " (using SSL)"
self.log_message(msg)
try:
tsock = websockifyserver.WebSockifyServer.socket(self.server.target_host,
self.server.target_port,
connect=True,
use_ssl=self.server.ssl_target,
unix_socket=self.server.unix_target)
except Exception as e:
self.log_message("Failed to connect to %s:%s: %s",
self.server.target_host, self.server.target_port, e)
raise self.CClose(1011, "Failed to connect to downstream server")
# Option unavailable when listening to unix socket
if not self.server.unix_listen:
self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
if not self.server.wrap_cmd and not self.server.unix_target:
tsock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
self.print_traffic(self.traffic_legend)
# Start proxying
try:
self.do_proxy(tsock)
finally:
if tsock:
tsock.shutdown(socket.SHUT_RDWR)
tsock.close()
if self.verbose:
self.log_message("%s:%s: Closed target",
self.server.target_host, self.server.target_port)
def get_target(self, target_plugin):
"""
Gets a token from either the path or the host,
depending on --host-token, and looks up a target
for that token using the token plugin. Used by
validate_connection() to set target_host and target_port.
"""
# The files in targets contain the lines
# in the form of token: host:port
if self.host_token:
# Use hostname as token
token = self.headers.get('Host')
# Remove port from hostname, as it'll always be the one where
# websockify listens (unless something between the client and
# websockify is redirecting traffic, but that's beside the point)
if token:
token = token.partition(':')[0]
else:
# Extract the token parameter from url
args = parse_qs(urlparse(self.path)[4]) # 4 is the query from url
if 'token' in args and len(args['token']):
token = args['token'][0].rstrip('\n')
else:
token = None
if token is None:
raise self.server.EClose("Token not present")
result_pair = target_plugin.lookup(token)
if result_pair is not None:
return result_pair
else:
raise self.server.EClose("Token '%s' not found" % token)
def do_proxy(self, target):
"""
Proxy client WebSocket to normal target socket.
"""
cqueue = []
c_pend = 0
tqueue = []
rlist = [self.request, target]
if self.server.heartbeat:
now = time.time()
self.heartbeat = now + self.server.heartbeat
else:
self.heartbeat = None
while True:
wlist = []
if self.heartbeat is not None:
now = time.time()
if now > self.heartbeat:
self.heartbeat = now + self.server.heartbeat
self.send_ping()
if tqueue: wlist.append(target)
if cqueue or c_pend: wlist.append(self.request)
try:
ins, outs, excepts = select.select(rlist, wlist, [], 1)
except (select.error, OSError):
exc = sys.exc_info()[1]
if hasattr(exc, 'errno'):
err = exc.errno
else:
err = exc[0]
if err != errno.EINTR:
raise
else:
continue
if excepts: raise Exception("Socket exception")
if self.request in outs:
# Send queued target data to the client
c_pend = self.send_frames(cqueue)
cqueue = []
if self.request in ins:
# Receive client data, decode it, and queue for target
bufs, closed = self.recv_frames()
tqueue.extend(bufs)
if closed:
while (len(tqueue) != 0):
# Send queued client data to the target
dat = tqueue.pop(0)
sent = target.send(dat)
if sent == len(dat):
self.print_traffic(">")
else:
# requeue the remaining data
tqueue.insert(0, dat[sent:])
self.print_traffic(".>")
# TODO: What about blocking on client socket?
if self.verbose:
self.log_message("%s:%s: Client closed connection",
self.server.target_host, self.server.target_port)
raise self.CClose(closed['code'], closed['reason'])
if target in outs:
# Send queued client data to the target
dat = tqueue.pop(0)
sent = target.send(dat)
if sent == len(dat):
self.print_traffic(">")
else:
# requeue the remaining data
tqueue.insert(0, dat[sent:])
self.print_traffic(".>")
if target in ins:
# Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size)
if len(buf) == 0:
# Target socket closed, flushing queues and closing client-side websocket
# Send queued target data to the client
if len(cqueue) != 0:
c_pend = True
while(c_pend):
c_pend = self.send_frames(cqueue)
cqueue = []
if self.verbose:
self.log_message("%s:%s: Target closed connection",
self.server.target_host, self.server.target_port)
raise self.CClose(1000, "Target closed")
cqueue.append(buf)
self.print_traffic("{")
class WebSocketProxy(websockifyserver.WebSockifyServer):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
socket server target.
"""
buffer_size = 65536
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.heartbeat = kwargs.pop('heartbeat', None)
self.token_plugin = kwargs.pop('token_plugin', None)
self.host_token = kwargs.pop('host_token', None)
self.auth_plugin = kwargs.pop('auth_plugin', None)
# Last 3 timestamps command was run
self.wrap_times = [0, 0, 0]
if self.wrap_cmd:
wsdir = os.path.dirname(sys.argv[0])
rebinder_path = [os.path.join(wsdir, "..", "lib"),
os.path.join(wsdir, "..", "lib", "websockify"),
os.path.join(wsdir, ".."),
wsdir]
self.rebinder = None
for rdir in rebinder_path:
rpath = os.path.join(rdir, "rebind.so")
if os.path.exists(rpath):
self.rebinder = rpath
break
if not self.rebinder:
raise Exception("rebind.so not found, perhaps you need to run make")
self.rebinder = os.path.abspath(self.rebinder)
self.target_host = "127.0.0.1" # Loopback
# Find a free high port
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('', 0))
self.target_port = sock.getsockname()[1]
sock.close()
# Insert rebinder at the head of the (possibly empty) LD_PRELOAD pathlist
ld_preloads = filter(None, [ self.rebinder, os.environ.get("LD_PRELOAD", None) ])
os.environ.update({
"LD_PRELOAD": os.pathsep.join(ld_preloads),
"REBIND_OLD_PORT": str(kwargs['listen_port']),
"REBIND_NEW_PORT": str(self.target_port)})
super().__init__(RequestHandlerClass, *args, **kwargs)
def run_wrap_cmd(self):
self.msg("Starting '%s'", " ".join(self.wrap_cmd))
self.wrap_times.append(time.time())
self.wrap_times.pop(0)
self.cmd = subprocess.Popen(
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
self.spawn_message = True
def started(self):
"""
Called after Websockets server startup (i.e. after daemonize)
"""
# Need to call wrapped command after daemonization so we can
# know when the wrapped command exits
if self.wrap_cmd:
dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port)
elif self.unix_target:
dst_string = self.unix_target
else:
dst_string = "%s:%s" % (self.target_host, self.target_port)
if self.listen_fd != None:
src_string = "inetd"
else:
src_string = "%s:%s" % (self.listen_host, self.listen_port)
if self.token_plugin:
msg = " - proxying from %s to targets generated by %s" % (
src_string, type(self.token_plugin).__name__)
else:
msg = " - proxying from %s to %s" % (
src_string, dst_string)
if self.ssl_target:
msg += " (using SSL)"
self.msg("%s", msg)
if self.wrap_cmd:
self.run_wrap_cmd()
def poll(self):
# If we are wrapping a command, check it's status
if self.wrap_cmd and self.cmd:
ret = self.cmd.poll()
if ret != None:
self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret)
self.cmd = None
if self.wrap_cmd and self.cmd == None:
# Response to wrapped command being gone
if self.wrap_mode == "ignore":
pass
elif self.wrap_mode == "exit":
sys.exit(ret)
elif self.wrap_mode == "respawn":
now = time.time()
avg = sum(self.wrap_times)/len(self.wrap_times)
if (now - avg) < 10:
# 3 times in the last 10 seconds
if self.spawn_message:
self.warn("Command respawning too fast")
self.spawn_message = False
else:
self.run_wrap_cmd()
def _subprocess_setup():
# Python installs a SIGPIPE handler by default. This is usually not what
# non-Python successfulbprocesses expect.
signal.signal(signal.SIGPIPE, signal.SIG_DFL)
SSL_OPTIONS = {
'default': ssl.OP_ALL,
'tlsv1_1': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1,
'tlsv1_2': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1,
'tlsv1_3': ssl.PROTOCOL_SSLv23 | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 |
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2,
}
def select_ssl_version(version):
"""Returns SSL options for the most secure TSL version available on this
Python version"""
if version in SSL_OPTIONS:
return SSL_OPTIONS[version]
else:
# It so happens that version names sorted lexicographically form a list
# from the least to the most secure
keys = list(SSL_OPTIONS.keys())
keys.sort()
fallback = keys[-1]
logger = logging.getLogger(WebSocketProxy.log_prefix)
logger.warn("TLS version %s unsupported. Falling back to %s",
version, fallback)
return SSL_OPTIONS[fallback]
def websockify_init():
# Setup basic logging to stderr.
stderr_handler = logging.StreamHandler()
stderr_handler.setLevel(logging.DEBUG)
log_formatter = logging.Formatter("%(message)s")
stderr_handler.setFormatter(log_formatter)
root = logging.getLogger()
root.addHandler(stderr_handler)
root.setLevel(logging.INFO)
# Setup optparse.
usage = "\n %prog [options]"
usage += " [source_addr:]source_port [target_addr:target_port]"
usage += "\n %prog [options]"
usage += " --token-plugin=CLASS [source_addr:]source_port"
usage += "\n %prog [options]"
usage += " --unix-target=FILE [source_addr:]source_port"
usage += "\n %prog [options]"
usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
parser = optparse.OptionParser(usage=usage)
parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages")
parser.add_option("--traffic", action="store_true",
help="per frame traffic")
parser.add_option("--record",
help="record sessions to FILE.[session_number]", metavar="FILE")
parser.add_option("--daemon", "-D",
dest="daemon", action="store_true",
help="become a daemon (background process)")
parser.add_option("--run-once", action="store_true",
help="handle a single WebSocket connection and exit")
parser.add_option("--timeout", type=int, default=0,
help="after TIMEOUT seconds exit when not connected")
parser.add_option("--idle-timeout", type=int, default=0,
help="server exits after TIMEOUT seconds if there are no "
"active connections")
parser.add_option("--cert", default="self.pem",
help="SSL certificate file")
parser.add_option("--key", default=None,
help="SSL key file (if separate from cert)")
parser.add_option("--key-password", default=None,
help="SSL key password")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted client connections")
parser.add_option("--ssl-target", action="store_true",
help="connect to SSL target as SSL client")
parser.add_option("--verify-client", action="store_true",
help="require encrypted client to present a valid certificate "
"(needs Python 2.7.9 or newer or Python 3.4 or newer)")
parser.add_option("--cafile", metavar="FILE",
help="file of concatenated certificates of authorities trusted "
"for validating clients (only effective with --verify-client). "
"If omitted, system default list of CAs is used.")
parser.add_option("--ssl-version", type="choice", default="default",
choices=["default", "tlsv1_1", "tlsv1_2", "tlsv1_3"], action="store",
help="minimum TLS version to use (default, tlsv1_1, tlsv1_2, tlsv1_3)")
parser.add_option("--ssl-ciphers", action="store",
help="list of ciphers allowed for connection. For a list of "
"supported ciphers run `openssl ciphers`")
parser.add_option("--unix-listen",
help="listen to unix socket", metavar="FILE", default=None)
parser.add_option("--unix-listen-mode", default=None,
help="specify mode for unix socket (defaults to 0600)")
parser.add_option("--unix-target",
help="connect to unix socket target", metavar="FILE")
parser.add_option("--inetd",
help="inetd mode, receive listening socket from stdin", action="store_true")
parser.add_option("--web", default=None, metavar="DIR",
help="run webserver on same port. Serve files from DIR.")
parser.add_option("--web-auth", action="store_true",
help="require authentication to access webserver.")
parser.add_option("--wrap-mode", default="exit", metavar="MODE",
choices=["exit", "ignore", "respawn"],
help="action to take when the wrapped program exits "
"or daemonizes: exit (default), ignore, respawn")
parser.add_option("--prefer-ipv6", "-6",
action="store_true", dest="source_is_ipv6",
help="prefer IPv6 when resolving source_addr")
parser.add_option("--libserver", action="store_true",
help="use Python library SocketServer engine")
parser.add_option("--target-config", metavar="FILE",
dest="target_cfg",
help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form "
"(DEPRECATED: use `--token-plugin TokenFile --token-source "
" path/to/token/file` instead)")
parser.add_option("--token-plugin", default=None, metavar="CLASS",
help="use a Python class, usually one from websockify.token_plugins, "
"such as TokenFile, to process tokens into host:port pairs")
parser.add_option("--token-source", default=None, metavar="ARG",
help="an argument to be passed to the token plugin "
"on instantiation")
parser.add_option("--host-token", action="store_true",
help="use the host HTTP header as token instead of the "
"token URL query parameter")
parser.add_option("--auth-plugin", default=None, metavar="CLASS",
help="use a Python class, usually one from websockify.auth_plugins, "
"such as BasicHTTPAuth, to determine if a connection is allowed")
parser.add_option("--auth-source", default=None, metavar="ARG",
help="an argument to be passed to the auth plugin "
"on instantiation")
parser.add_option("--heartbeat", type=int, default=0, metavar="INTERVAL",
help="send a ping to the client every INTERVAL seconds")
parser.add_option("--log-file", metavar="FILE",
dest="log_file",
help="File where logs will be saved")
parser.add_option("--syslog", default=None, metavar="SERVER",
help="Log to syslog server. SERVER can be local socket, "
"such as /dev/log, or a UDP host:port pair.")
parser.add_option("--legacy-syslog", action="store_true",
help="Use the old syslog protocol instead of RFC 5424. "
"Use this if the messages produced by websockify seem abnormal.")
parser.add_option("--file-only", action="store_true",
help="use this to disable directory listings in web server.")
(opts, args) = parser.parse_args()
# Validate options.
if opts.token_source and not opts.token_plugin:
parser.error("You must use --token-plugin to use --token-source")
if opts.host_token and not opts.token_plugin:
parser.error("You must use --token-plugin to use --host-token")
if opts.auth_source and not opts.auth_plugin:
parser.error("You must use --auth-plugin to use --auth-source")
if opts.web_auth and not opts.auth_plugin:
parser.error("You must use --auth-plugin to use --web-auth")
if opts.web_auth and not opts.web:
parser.error("You must use --web to use --web-auth")
if opts.legacy_syslog and not opts.syslog:
parser.error("You must use --syslog to use --legacy-syslog")
opts.ssl_options = select_ssl_version(opts.ssl_version)
del opts.ssl_version
if opts.log_file:
# Setup logging to user-specified file.
opts.log_file = os.path.abspath(opts.log_file)
log_file_handler = logging.FileHandler(opts.log_file)
log_file_handler.setLevel(logging.DEBUG)
log_file_handler.setFormatter(log_formatter)
root = logging.getLogger()
root.addHandler(log_file_handler)
del opts.log_file
if opts.syslog:
# Determine how to connect to syslog...
if opts.syslog.count(':'):
# User supplied a host:port pair.
syslog_host, syslog_port = opts.syslog.rsplit(':', 1)
try:
syslog_port = int(syslog_port)
except ValueError:
parser.error("Error parsing syslog port")
syslog_dest = (syslog_host, syslog_port)
else:
# User supplied a local socket file.
syslog_dest = os.path.abspath(opts.syslog)
from websockify.sysloghandler import WebsockifySysLogHandler
# Determine syslog facility.
if opts.daemon:
syslog_facility = WebsockifySysLogHandler.LOG_DAEMON
else:
syslog_facility = WebsockifySysLogHandler.LOG_USER
# Start logging to syslog.
syslog_handler = WebsockifySysLogHandler(address=syslog_dest,
facility=syslog_facility,
ident='websockify',
legacy=opts.legacy_syslog)
syslog_handler.setLevel(logging.DEBUG)
syslog_handler.setFormatter(log_formatter)
root = logging.getLogger()
root.addHandler(syslog_handler)
del opts.syslog
del opts.legacy_syslog
if opts.verbose:
root = logging.getLogger()
root.setLevel(logging.DEBUG)
# Transform to absolute path as daemon may chdir
if opts.target_cfg:
opts.target_cfg = os.path.abspath(opts.target_cfg)
if opts.target_cfg:
opts.token_plugin = 'TokenFile'
opts.token_source = opts.target_cfg
del opts.target_cfg
if sys.argv.count('--'):
opts.wrap_cmd = args[1:]
else:
opts.wrap_cmd = None
if not websockifyserver.ssl and opts.ssl_target:
parser.error("SSL target requested and Python SSL module not loaded.");
if opts.ssl_only and not os.path.exists(opts.cert):
parser.error("SSL only and %s not found" % opts.cert)
if opts.inetd:
opts.listen_fd = sys.stdin.fileno()
elif opts.unix_listen:
if opts.unix_listen_mode:
try:
# Parse octal notation (like 750)
opts.unix_listen_mode = int(opts.unix_listen_mode, 8)
except ValueError:
parser.error("Error parsing listen unix socket mode")
else:
# Default to 0600 (Owner Read/Write)
opts.unix_listen_mode = stat.S_IREAD | stat.S_IWRITE
else:
if len(args) < 1:
parser.error("Too few arguments")
arg = args.pop(0)
# Parse host:port and convert ports to numbers
if arg.count(':') > 0:
opts.listen_host, opts.listen_port = arg.rsplit(':', 1)
opts.listen_host = opts.listen_host.strip('[]')
else:
opts.listen_host, opts.listen_port = '', arg
try:
opts.listen_port = int(opts.listen_port)
except ValueError:
parser.error("Error parsing listen port")
del opts.inetd
if opts.wrap_cmd or opts.unix_target or opts.token_plugin:
opts.target_host = None
opts.target_port = None
else:
if len(args) < 1:
parser.error("Too few arguments")
arg = args.pop(0)
if arg.count(':') > 0:
opts.target_host, opts.target_port = arg.rsplit(':', 1)
opts.target_host = opts.target_host.strip('[]')
else:
parser.error("Error parsing target")
try:
opts.target_port = int(opts.target_port)
except ValueError:
parser.error("Error parsing target port")
if len(args) > 0 and opts.wrap_cmd == None:
parser.error("Too many arguments")
if opts.token_plugin is not None:
if '.' not in opts.token_plugin:
opts.token_plugin = (
'websockify.token_plugins.%s' % opts.token_plugin)
token_plugin_module, token_plugin_cls = opts.token_plugin.rsplit('.', 1)
__import__(token_plugin_module)
token_plugin_cls = getattr(sys.modules[token_plugin_module], token_plugin_cls)
opts.token_plugin = token_plugin_cls(opts.token_source)
del opts.token_source
if opts.auth_plugin is not None:
if '.' not in opts.auth_plugin:
opts.auth_plugin = 'websockify.auth_plugins.%s' % opts.auth_plugin
auth_plugin_module, auth_plugin_cls = opts.auth_plugin.rsplit('.', 1)
__import__(auth_plugin_module)
auth_plugin_cls = getattr(sys.modules[auth_plugin_module], auth_plugin_cls)
opts.auth_plugin = auth_plugin_cls(opts.auth_source)
del opts.auth_source
# Create and start the WebSockets proxy
libserver = opts.libserver
del opts.libserver
if libserver:
# Use standard Python SocketServer framework
server = LibProxyServer(**opts.__dict__)
server.serve_forever()
else:
# Use internal service framework
server = WebSocketProxy(**opts.__dict__)
server.start_server()
class LibProxyServer(ThreadingMixIn, HTTPServer):
"""
Just like WebSocketProxy, but uses standard Python SocketServer
framework.
"""
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.token_plugin = kwargs.pop('token_plugin', None)
self.auth_plugin = kwargs.pop('auth_plugin', None)
self.heartbeat = kwargs.pop('heartbeat', None)
self.token_plugin = None
self.auth_plugin = None
self.daemon = False
# Server configuration
listen_host = kwargs.pop('listen_host', '')
listen_port = kwargs.pop('listen_port', None)
web = kwargs.pop('web', '')
# Configuration affecting base request handler
self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False)
record = kwargs.pop('record', '')
if record:
self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False)
self.handler_id = 0
for arg in kwargs.keys():
print("warning: option %s ignored when using --libserver" % arg)
if web:
os.chdir(web)
super().__init__((listen_host, listen_port), RequestHandlerClass)
def process_request(self, request, client_address):
"""Override process_request to implement a counter"""
self.handler_id += 1
super().process_request(request, client_address)
if __name__ == '__main__':
websockify_init()

View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python
'''
Python WebSocket server base
Copyright 2011 Joel Martin
Copyright 2016-2018 Pierre Ossman
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
'''
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError
class HttpWebSocket(WebSocket):
"""Class to glue websocket and http request functionality together"""
def __init__(self, request_handler):
super().__init__()
self.request_handler = request_handler
def send_response(self, code, message=None):
self.request_handler.send_response(code, message)
def send_header(self, keyword, value):
self.request_handler.send_header(keyword, value)
def end_headers(self):
self.request_handler.end_headers()
class WebSocketRequestHandlerMixIn:
"""WebSocket request handler mix-in class
This class modifies and existing request handler to handle
WebSocket requests. The request handler will continue to function
as before, except that WebSocket requests are intercepted and the
methods handle_upgrade() and handle_websocket() are called. The
standard do_GET() will be called for normal requests.
The class instance SocketClass can be overridden with the class to
use for the WebSocket connection.
"""
SocketClass = HttpWebSocket
def handle_one_request(self):
"""Extended request handler
This is where WebSocketRequestHandler redirects requests to the
new methods. Any sub-classes must call this method in order for
the calls to function.
"""
self._real_do_GET = self.do_GET
self.do_GET = self._websocket_do_GET
try:
super().handle_one_request()
finally:
self.do_GET = self._real_do_GET
def _websocket_do_GET(self):
# Checks if it is a websocket request and redirects
self.do_GET = self._real_do_GET
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
self.handle_upgrade()
else:
self.do_GET()
def handle_upgrade(self):
"""Initial handler for a WebSocket request
This method is called when a WebSocket is requested. By default
it will create a WebSocket object and perform the negotiation.
The WebSocket object will then replace the request object and
handle_websocket() will be called.
"""
websocket = self.SocketClass(self)
try:
websocket.accept(self.request, self.headers)
except Exception:
exc = sys.exc_info()[1]
self.send_error(400, str(exc))
return
self.request = websocket
# Other requests cannot follow Websocket data
self.close_connection = True
self.handle_websocket()
def handle_websocket(self):
"""Handle a WebSocket connection.
This is called when the WebSocket is ready to be used. A
sub-class should perform the necessary communication here and
return once done.
"""
pass
# Convenient ready made classes
class WebSocketRequestHandler(WebSocketRequestHandlerMixIn,
BaseHTTPRequestHandler):
pass
class WebSocketServer(HTTPServer):
pass

View File

@@ -0,0 +1,862 @@
#!/usr/bin/env python
'''
Python WebSocket server base with support for "wss://" encryption.
Copyright 2011 Joel Martin
Copyright 2016 Pierre Ossman
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
import os, sys, time, errno, signal, socket, select, logging
import multiprocessing
from http.server import SimpleHTTPRequestHandler
# Degraded functionality if these imports are missing
for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'),
('resource', 'daemonizing is disabled')]:
try:
globals()[mod] = __import__(mod)
except ImportError:
globals()[mod] = None
print("WARNING: no '%s' module, %s" % (mod, msg))
if sys.platform == 'win32':
# make sockets pickle-able/inheritable
import multiprocessing.reduction
from websockify.websocket import WebSocketWantReadError, WebSocketWantWriteError
from websockify.websocketserver import WebSocketRequestHandlerMixIn
class CompatibleWebSocket(WebSocketRequestHandlerMixIn.SocketClass):
def select_subprotocol(self, protocols):
# Handle old websockify clients that still specify a sub-protocol
if 'binary' in protocols:
return 'binary'
else:
return ''
# HTTP handler with WebSocket upgrade support
class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHandler):
"""
WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler.
Must be sub-classed with new_websocket_client method definition.
The request handler can be configured by setting optional
attributes on the server object:
* only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled,
only websocket is allowed.
* verbose: If true, verbose logging is activated.
* daemon: Running as daemon, do not write to console etc
* record: Record raw frame data as JavaScript array into specified filename
* run_once: Handle a single request
* handler_id: A sequence number for this connection, appended to record filename
"""
server_version = "WebSockify"
protocol_version = "HTTP/1.1"
SocketClass = CompatibleWebSocket
# An exception while the WebSocket client was connected
class CClose(Exception):
pass
def __init__(self, req, addr, server):
# Retrieve a few configuration variables from the server
self.only_upgrade = getattr(server, "only_upgrade", False)
self.verbose = getattr(server, "verbose", False)
self.daemon = getattr(server, "daemon", False)
self.record = getattr(server, "record", False)
self.run_once = getattr(server, "run_once", False)
self.rec = None
self.handler_id = getattr(server, "handler_id", False)
self.file_only = getattr(server, "file_only", False)
self.traffic = getattr(server, "traffic", False)
self.web_auth = getattr(server, "web_auth", False)
self.host_token = getattr(server, "host_token", False)
self.logger = getattr(server, "logger", None)
if self.logger is None:
self.logger = WebSockifyServer.get_logger()
super().__init__(req, addr, server)
def log_message(self, format, *args):
self.logger.info("%s - - [%s] %s" % (self.client_address[0], self.log_date_time_string(), format % args))
#
# WebSocketRequestHandler logging/output functions
#
def print_traffic(self, token="."):
""" Show traffic flow mode. """
if self.traffic:
sys.stdout.write(token)
sys.stdout.flush()
def msg(self, msg, *args, **kwargs):
""" Output message with handler_id prefix. """
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs)
def vmsg(self, msg, *args, **kwargs):
""" Same as msg() but as debug. """
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs)
def warn(self, msg, *args, **kwargs):
""" Same as msg() but as warning. """
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs)
#
# Main WebSocketRequestHandler methods
#
def send_frames(self, bufs=None):
""" Encode and send WebSocket frames. Any frames already
queued will be sent first. If buf is not set then only queued
frames will be sent. Returns True if any frames could not be
fully sent, in which case the caller should call again when
the socket is ready. """
tdelta = int(time.time()*1000) - self.start_time
if bufs:
for buf in bufs:
if self.rec:
# Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
self.rec.write("'{{{0}{{{1}',\n".format(tdelta, bufstr))
self.send_parts.append(buf)
while self.send_parts:
# Send pending frames
try:
self.request.sendmsg(self.send_parts[0])
except WebSocketWantWriteError:
self.print_traffic("<.")
return True
self.send_parts.pop(0)
self.print_traffic("<")
return False
def recv_frames(self):
""" Receive and decode WebSocket frames.
Returns:
(bufs_list, closed_string)
"""
closed = False
bufs = []
tdelta = int(time.time()*1000) - self.start_time
while True:
try:
buf = self.request.recvmsg()
except WebSocketWantReadError:
self.print_traffic("}.")
break
if buf is None:
closed = {'code': self.request.close_code,
'reason': self.request.close_reason}
return bufs, closed
self.print_traffic("}")
if self.rec:
# Python 3 compatible conversion
bufstr = buf.decode('latin1').encode('unicode_escape').decode('ascii').replace("'", "\\'")
self.rec.write("'}}{0}}}{1}',\n".format(tdelta, bufstr))
bufs.append(buf)
if not self.request.pending():
break
return bufs, closed
def send_close(self, code=1000, reason=''):
""" Send a WebSocket orderly close frame. """
self.request.shutdown(socket.SHUT_RDWR, code, reason)
def send_pong(self, data=''.encode('ascii')):
""" Send a WebSocket pong frame. """
self.request.pong(data)
def send_ping(self, data=''.encode('ascii')):
""" Send a WebSocket ping frame. """
self.request.ping(data)
def handle_upgrade(self):
# ensure connection is authorized, and determine the target
self.validate_connection()
self.auth_connection()
super().handle_upgrade()
def handle_websocket(self):
# Indicate to server that a Websocket upgrade was done
self.server.ws_connection = True
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.start_time = int(time.time()*1000)
# client_address is empty with, say, UNIX domain sockets
client_addr = ""
is_ssl = False
try:
client_addr = self.client_address[0]
is_ssl = self.client_address[2]
except IndexError:
pass
if is_ssl:
self.stype = "SSL/TLS (wss://)"
else:
self.stype = "Plain non-SSL (ws://)"
self.log_message("%s: %s WebSocket connection", client_addr,
self.stype)
if self.path != '/':
self.log_message("%s: Path: '%s'", client_addr, self.path)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.log_message("opening record file: %s", fname)
self.rec = open(fname, 'w+')
self.rec.write("var VNC_frame_data = [\n")
try:
self.new_websocket_client()
except self.CClose:
# Close the client
_, exc, _ = sys.exc_info()
self.send_close(exc.args[0], exc.args[1])
def do_GET(self):
if self.web_auth:
# ensure connection is authorized, this seems to apply to list_directory() as well
self.auth_connection()
if self.only_upgrade:
self.send_error(405)
else:
super().do_GET()
def list_directory(self, path):
if self.file_only:
self.send_error(404)
else:
return super().list_directory(path)
def new_websocket_client(self):
""" Do something with a WebSockets client connection. """
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
def validate_connection(self):
""" Ensure that the connection has a valid token, and set the target. """
pass
def auth_connection(self):
""" Ensure that the connection is authorized. """
pass
def do_HEAD(self):
if self.web_auth:
self.auth_connection()
if self.only_upgrade:
self.send_error(405)
else:
super().do_HEAD()
def finish(self):
if self.rec:
self.rec.write("'EOF'];\n")
self.rec.close()
super().finish()
def handle(self):
# When using run_once, we have a single process, so
# we cannot loop in BaseHTTPRequestHandler.handle; we
# must return and handle new connections
if self.run_once:
self.handle_one_request()
else:
super().handle()
def log_request(self, code='-', size='-'):
if self.verbose:
super().log_request(code, size)
class WebSockifyServer():
"""
WebSockets server class.
As an alternative, the standard library SocketServer can be used
"""
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""
log_prefix = "websocket"
# An exception before the WebSocket connection was established
class EClose(Exception):
pass
class Terminate(Exception):
pass
def __init__(self, RequestHandlerClass, listen_fd=None,
listen_host='', listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', key_password=None, ssl_only=None,
verify_client=False, cafile=None,
daemon=False, record='', web='', web_auth=False,
file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0,
unix_listen=None, unix_listen_mode=None):
# settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose
self.listen_fd = listen_fd
self.unix_listen = unix_listen
self.unix_listen_mode = unix_listen_mode
self.listen_host = listen_host
self.listen_port = listen_port
self.prefer_ipv6 = source_is_ipv6
self.ssl_only = ssl_only
self.ssl_ciphers = ssl_ciphers
self.ssl_options = ssl_options
self.verify_client = verify_client
self.daemon = daemon
self.run_once = run_once
self.timeout = timeout
self.idle_timeout = idle_timeout
self.traffic = traffic
self.file_only = file_only
self.web_auth = web_auth
self.launch_time = time.time()
self.ws_connection = False
self.handler_id = 1
self.terminating = False
self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt
self.tcp_keepidle = tcp_keepidle
self.tcp_keepintvl = tcp_keepintvl
# keyfile path must be None if not specified
self.key = None
self.key_password = key_password
# Make paths settings absolute
self.cert = os.path.abspath(cert)
self.web = self.record = self.cafile = ''
if key:
self.key = os.path.abspath(key)
if web:
self.web = os.path.abspath(web)
if record:
self.record = os.path.abspath(record)
if cafile:
self.cafile = os.path.abspath(cafile)
if self.web:
os.chdir(self.web)
self.only_upgrade = not self.web
# Sanity checks
if not ssl and self.ssl_only:
raise Exception("No 'ssl' module and SSL-only specified")
if self.daemon and not resource:
raise Exception("Module 'resource' required to daemonize")
# Show configuration
self.msg("WebSocket server settings:")
if self.listen_fd != None:
self.msg(" - Listen for inetd connections")
elif self.unix_listen != None:
self.msg(" - Listen on unix socket %s", self.unix_listen)
else:
self.msg(" - Listen on %s:%s",
self.listen_host, self.listen_port)
if self.web:
if self.file_only:
self.msg(" - Web server (no directory listings). Web root: %s", self.web)
else:
self.msg(" - Web server. Web root: %s", self.web)
if ssl:
if os.path.exists(self.cert):
self.msg(" - SSL/TLS support")
if self.ssl_only:
self.msg(" - Deny non-SSL/TLS connections")
else:
self.msg(" - No SSL/TLS support (no cert file)")
else:
self.msg(" - No SSL/TLS support (no 'ssl' module)")
if self.daemon:
self.msg(" - Backgrounding (daemon)")
if self.record:
self.msg(" - Recording to '%s.*'", self.record)
#
# WebSockifyServer static methods
#
@staticmethod
def get_logger():
return logging.getLogger("%s.%s" % (
WebSockifyServer.log_prefix,
WebSockifyServer.__class__.__name__))
@staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False,
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None,
tcp_keepidle=None, tcp_keepintvl=None):
""" Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket.
"""
flags = 0
if host == '':
host = None
if connect and not (port or unix_socket):
raise Exception("Connect mode requires a port")
if use_ssl and not ssl:
raise Exception("SSL socket requested but Python SSL module not loaded.");
if not connect and use_ssl:
raise Exception("SSL only supported in connect mode (for now)")
if not connect:
flags = flags | socket.AI_PASSIVE
if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
socket.IPPROTO_TCP, flags)
if not addrs:
raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0])
if prefer_ipv6:
addrs.reverse()
sock = socket.socket(addrs[0][0], addrs[0][1])
if tcp_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if tcp_keepcnt:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT,
tcp_keepcnt)
if tcp_keepidle:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
tcp_keepidle)
if tcp_keepintvl:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
tcp_keepintvl)
if connect:
sock.connect(addrs[0][4])
if use_ssl:
sock = ssl.wrap_socket(sock)
else:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(addrs[0][4])
sock.listen(100)
else:
if unix_socket_listen:
# Make sure the socket does not already exist
try:
os.unlink(unix_socket)
except FileNotFoundError:
pass
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
oldmask = os.umask(0o777 ^ unix_socket_mode)
try:
sock.bind(unix_socket)
finally:
os.umask(oldmask)
sock.listen(100)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)
return sock
@staticmethod
def daemonize(keepfd=None, chdir='/'):
if keepfd is None:
keepfd = []
os.umask(0)
if chdir:
os.chdir(chdir)
else:
os.chdir('/')
os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits
os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits
# Signal handling
signal.signal(signal.SIGTERM, signal.SIG_IGN)
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256
for fd in reversed(range(maxfd)):
try:
if fd not in keepfd:
os.close(fd)
except OSError:
_, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise
# Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
def do_handshake(self, sock, address):
"""
do_handshake does the following:
- Peek at the first few bytes from the socket.
- If the connection is an HTTPS/SSL/TLS connection then SSL
wrap the socket.
- Read from the (possibly wrapped) socket.
- If we have received a HTTP GET request and the webserver
functionality is enabled, answer it, close the socket and
return.
- Assume we have a WebSockets connection, parse the client
handshake data.
- Send a WebSockets handshake server response.
- Return the socket for this WebSocket client.
"""
ready = select.select([sock], [], [], 3)[0]
if not ready:
raise self.EClose("")
# Peek, but do not read the data so that we have a opportunity
# to SSL wrap the socket first
handshake = sock.recv(1024, socket.MSG_PEEK)
#self.msg("Handshake [%s]" % handshake)
if not handshake:
raise self.EClose("")
elif handshake[0] in (22, 128):
# SSL wrap the connection
if not ssl:
raise self.EClose("SSL connection but no 'ssl' module")
if not os.path.exists(self.cert):
raise self.EClose("SSL connection but '%s' not found"
% self.cert)
retsock = None
try:
# create new-style SSL wrapping for extended features
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
if self.ssl_ciphers is not None:
context.set_ciphers(self.ssl_ciphers)
context.options = self.ssl_options
context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password)
if self.verify_client:
context.verify_mode = ssl.CERT_REQUIRED
if self.cafile:
context.load_verify_locations(cafile=self.cafile)
else:
context.set_default_verify_paths()
retsock = context.wrap_socket(
sock,
server_side=True)
except ssl.SSLError:
_, x, _ = sys.exc_info()
if x.args[0] == ssl.SSL_ERROR_EOF:
if len(x.args) > 1:
raise self.EClose(x.args[1])
else:
raise self.EClose("Got SSL_ERROR_EOF")
else:
raise
elif self.ssl_only:
raise self.EClose("non-SSL connection received but disallowed")
else:
retsock = sock
# If the address is like (host, port), we are extending it
# with a flag indicating SSL. Not many other options
# available...
if len(address) == 2:
address = (address[0], address[1], (retsock != sock))
self.RequestHandlerClass(retsock, address, self)
# Return the WebSockets socket which may be SSL wrapped
return retsock
#
# WebSockifyServer logging/output functions
#
def msg(self, *args, **kwargs):
""" Output message as info """
self.logger.log(logging.INFO, *args, **kwargs)
def vmsg(self, *args, **kwargs):
""" Same as msg() but as debug. """
self.logger.log(logging.DEBUG, *args, **kwargs)
def warn(self, *args, **kwargs):
""" Same as msg() but as warning. """
self.logger.log(logging.WARN, *args, **kwargs)
#
# Events that can/should be overridden in sub-classes
#
def started(self):
""" Called after WebSockets startup """
self.vmsg("WebSockets server started")
def poll(self):
""" Run periodically while waiting for connections. """
#self.vmsg("Running poll()")
pass
def terminate(self):
if not self.terminating:
self.terminating = True
raise self.Terminate()
def multiprocessing_SIGCHLD(self, sig, stack):
# TODO: figure out a way to actually log this information without
# calling `log` in the signal handlers
multiprocessing.active_children()
def fallback_SIGCHLD(self, sig, stack):
# Reap zombies when using os.fork() (python 2.4)
# TODO: figure out a way to actually log this information without
# calling `log` in the signal handlers
try:
result = os.waitpid(-1, os.WNOHANG)
while result[0]:
self.vmsg("Reaped child process %s" % result[0])
result = os.waitpid(-1, os.WNOHANG)
except (OSError):
pass
def do_SIGINT(self, sig, stack):
# TODO: figure out a way to actually log this information without
# calling `log` in the signal handlers
self.terminate()
def do_SIGTERM(self, sig, stack):
# TODO: figure out a way to actually log this information without
# calling `log` in the signal handlers
self.terminate()
def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
# handler process
client = None
try:
try:
client = self.do_handshake(startsock, address)
except self.EClose:
_, exc, _ = sys.exc_info()
# Connection was not a WebSockets connection
if exc.args[0]:
self.msg("%s: %s" % (address[0], exc.args[0]))
except WebSockifyServer.Terminate:
raise
except Exception:
_, exc, _ = sys.exc_info()
self.msg("handler exception: %s" % str(exc))
self.vmsg("exception", exc_info=True)
finally:
if client and client != startsock:
# Close the SSL wrapped socket
# Original socket closed by caller
client.close()
def get_log_fd(self):
"""
Get file descriptors for the loggers.
They should not be closed when the process is forked.
"""
descriptors = []
for handler in self.logger.parent.handlers:
if isinstance(handler, logging.FileHandler):
descriptors.append(handler.stream.fileno())
return descriptors
def start_server(self):
"""
Daemonize if requested. Listen for for connections. Run
do_handshake() method for each connection. If the connection
is a WebSockets client then call new_websocket_client() method (which must
be overridden) for each new client connection.
"""
if self.listen_fd != None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
elif self.unix_listen != None:
lsock = self.socket(host=None,
unix_socket=self.unix_listen,
unix_socket_mode=self.unix_listen_mode,
unix_socket_listen=True)
else:
lsock = self.socket(self.listen_host, self.listen_port, False,
self.prefer_ipv6,
tcp_keepalive=self.tcp_keepalive,
tcp_keepcnt=self.tcp_keepcnt,
tcp_keepidle=self.tcp_keepidle,
tcp_keepintvl=self.tcp_keepintvl)
if self.daemon:
keepfd = self.get_log_fd()
keepfd.append(lsock.fileno())
self.daemonize(keepfd=keepfd, chdir=self.web)
self.started() # Some things need to happen after daemonizing
# Allow override of signals
original_signals = {
signal.SIGINT: signal.getsignal(signal.SIGINT),
signal.SIGTERM: signal.getsignal(signal.SIGTERM),
}
if getattr(signal, 'SIGCHLD', None) is not None:
original_signals[signal.SIGCHLD] = signal.getsignal(signal.SIGCHLD)
signal.signal(signal.SIGINT, self.do_SIGINT)
signal.signal(signal.SIGTERM, self.do_SIGTERM)
# make sure that _cleanup is called when children die
# by calling active_children on SIGCHLD
if getattr(signal, 'SIGCHLD', None) is not None:
signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD)
last_active_time = self.launch_time
try:
while True:
try:
try:
startsock = None
pid = err = 0
child_count = 0
# Collect zombie child processes
child_count = len(multiprocessing.active_children())
time_elapsed = time.time() - self.launch_time
if self.timeout and time_elapsed > self.timeout:
self.msg('listener exit due to --timeout %s'
% self.timeout)
break
if self.idle_timeout:
idle_time = 0
if child_count == 0:
idle_time = time.time() - last_active_time
else:
idle_time = 0
last_active_time = time.time()
if idle_time > self.idle_timeout and child_count == 0:
self.msg('listener exit due to --idle-timeout %s'
% self.idle_timeout)
break
try:
self.poll()
ready = select.select([lsock], [], [], 1)[0]
if lsock in ready:
startsock, address = lsock.accept()
# Unix Socket will not report address (empty string), but address[0] is logged a bunch
if self.unix_listen != None:
address = [ self.unix_listen ]
else:
continue
except self.Terminate:
raise
except Exception:
_, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'):
err = exc.errno
elif hasattr(exc, 'args'):
err = exc.args[0]
else:
err = exc[0]
if err == errno.EINTR:
self.vmsg("Ignoring interrupted syscall")
continue
else:
raise
if self.run_once:
# Run in same process if run_once
self.top_new_client(startsock, address)
if self.ws_connection :
self.msg('%s: exiting due to --run-once'
% address[0])
break
else:
self.vmsg('%s: new handler Process' % address[0])
p = multiprocessing.Process(
target=self.top_new_client,
args=(startsock, address))
p.start()
# child will not return
# parent process
self.handler_id += 1
except (self.Terminate, SystemExit, KeyboardInterrupt):
self.msg("In exit")
# terminate all child processes
if not self.run_once:
children = multiprocessing.active_children()
for child in children:
self.msg("Terminating child %s" % child.pid)
child.terminate()
break
except Exception:
exc = sys.exc_info()[1]
self.msg("handler exception: %s", str(exc))
self.vmsg("exception", exc_info=True)
finally:
if startsock:
startsock.close()
finally:
# Close listen port
self.vmsg("Closing socket listening at %s:%s",
self.listen_host, self.listen_port)
lsock.close()
# Restore signals
for sig, func in original_signals.items():
signal.signal(sig, func)