chube_ws.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import asyncio
  2. import json
  3. import logging
  4. import os
  5. import pathlib
  6. import signal
  7. import ssl
  8. from websockets.asyncio.server import serve, ServerConnection
  9. from websockets.exceptions import ConnectionClosed
  10. from chube_enums import Message
  11. logger = logging.getLogger('chube')
  12. PORT = os.environ.get("CHUBE_WS_PORT") or 38210 # CHU
  13. HOST = os.environ.get("CHUBE_WS_HOST") or "localhost"
  14. ENABLE_WSS = os.environ.get("CHUBE_NO_WSS") != '1'
  15. CERT_PATH = os.environ.get("CHUBE_CERT_PATH")
  16. KEY_PATH = os.environ.get("CHUBE_KEY_PATH")
  17. if ENABLE_WSS and (CERT_PATH is None or not os.path.isfile(CERT_PATH)):
  18. raise Exception("WSS is enabled but no valid certificate is provided. To disable WSS provide the CHUBE_NO_WSS=1 "
  19. "environment variable.\nProvided certificate path is {}".format(CERT_PATH))
  20. if ENABLE_WSS and (CERT_PATH is None or not os.path.isfile(KEY_PATH)):
  21. raise Exception("WSS is enabled but no valid key is provided. To disable WSS provide the CHUBE_NO_WSS=1 "
  22. "environment variable.\nProvided key path is {}".format(KEY_PATH))
  23. class MessageResolveException(Exception):
  24. pass
  25. class Resolver:
  26. _registerDict: dict = {}
  27. def register(self, message: Message, handler):
  28. self._registerDict[message.value] = handler
  29. def unregister(self, message):
  30. return self._registerDict.pop(message.value)
  31. def resolve(self, data):
  32. message = json.loads(data)
  33. if not isinstance(message, dict):
  34. raise MessageResolveException("Received bytes is not a json object but a {}. {}".format(type(message), message))
  35. if "__message" not in message:
  36. raise MessageResolveException("Received message does not have required '__message' field. {}".format(message))
  37. message_type = message["__message"]
  38. if message_type not in self._registerDict:
  39. raise MessageResolveException("No handler for message type {}. {}".format(message_type, message))
  40. if "__body" not in message:
  41. return self._registerDict[message_type], None
  42. else:
  43. return self._registerDict[message_type], message["__body"]
  44. def make_handler(self, on_open=None, on_close=None):
  45. async def on_open_handler(websocket: ServerConnection, path):
  46. if on_open is not None:
  47. await on_open(websocket, path)
  48. async def on_close_handler(websocket: ServerConnection, path):
  49. if on_close is not None:
  50. await on_close(websocket, path)
  51. async def handler(websocket: ServerConnection):
  52. path = websocket.request.path
  53. await on_open_handler(websocket, path.lower())
  54. try:
  55. while True:
  56. message = await websocket.recv()
  57. logger.debug(f"{path} WebsocketMessage {{{message}}}")
  58. processor, body = self.resolve(message)
  59. await processor(websocket, body, path.lower())
  60. except MessageResolveException as e:
  61. logger.exception(e)
  62. except ConnectionClosed:
  63. await on_close_handler(websocket, path.lower())
  64. return handler
  65. def add_all(self, search_resolver: "Resolver"):
  66. for message, handler in search_resolver._registerDict.items():
  67. self._registerDict[message] = handler
  68. def make_message(message_type, body=None):
  69. return json.dumps({"__message": message_type.value, "__body": body})
  70. def make_message_from_json_string(message_type, raw_body: str):
  71. return "{{\"__message\": \"{}\", \"__body\": {}}}".format(message_type.value, raw_body)
  72. async def start_server(resolver: Resolver, on_new_connection, on_connection_close):
  73. loop = asyncio.get_running_loop()
  74. ssl_context = None
  75. if ENABLE_WSS:
  76. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  77. cert_pem = pathlib.Path(CERT_PATH)
  78. key_pem = pathlib.Path(KEY_PATH)
  79. ssl_context.load_cert_chain(cert_pem, key_pem)
  80. async with serve(
  81. resolver.make_handler(on_open=on_new_connection, on_close=on_connection_close),
  82. HOST, PORT, ssl=ssl_context) as server:
  83. loop.add_signal_handler(signal.SIGINT, lambda s: s.close(), server)
  84. await server.serve_forever()