chube_ws.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import asyncio
  2. import json
  3. import logging
  4. import os
  5. import pathlib
  6. import ssl
  7. from websockets.asyncio.server import serve
  8. from websockets.exceptions import ConnectionClosed
  9. from chube_enums import Message
  # Logger used by the websocket layer.
  logger = logging.getLogger('chube')

  # Listening endpoint, overridable through the environment.
  # NOTE: a port read from the environment is a str; the fallback is an int.
  PORT = os.environ.get("CHUBE_WS_PORT") or 38210 # CHU
  HOST = os.environ.get("CHUBE_WS_HOST") or "localhost"

  # TLS (wss://) stays enabled unless CHUBE_NO_WSS is exactly '1'.
  ENABLE_WSS = os.environ.get("CHUBE_NO_WSS") != '1'
  CERT_PATH = os.environ.get("CHUBE_CERT_PATH")
  KEY_PATH = os.environ.get("CHUBE_KEY_PATH")

  # Fail fast at import time when TLS is requested but the certificate or the
  # private key is missing, instead of failing later while starting the server.
  if ENABLE_WSS and (CERT_PATH is None or not os.path.isfile(CERT_PATH)):
      raise Exception("WSS is enabled but no valid certificate is provided. To disable WSS provide the CHUBE_NO_WSS=1 "
                      "environment variable.\nProvided certificate path is {}".format(CERT_PATH))
  # KEY_PATH itself must be tested for None: os.path.isfile(None) raises a
  # TypeError, which would hide the configuration error message below.
  if ENABLE_WSS and (KEY_PATH is None or not os.path.isfile(KEY_PATH)):
      raise Exception("WSS is enabled but no valid key is provided. To disable WSS provide the CHUBE_NO_WSS=1 "
                      "environment variable.\nProvided key path is {}".format(KEY_PATH))
  class MessageResolveException(Exception):
      """Signals that a received websocket message could not be resolved to a handler."""
  class Resolver:
      """Registry mapping message type values to handlers, plus the websocket
      connection handler that dispatches incoming messages through it.

      Incoming messages are JSON objects with a required ``__message`` field
      (looked up in the registry) and an optional ``__body`` field.
      """

      # Maps ``message.value`` -> handler. Created per instance in ``__init__``.
      _registerDict: dict

      def __init__(self):
          # A dict created here belongs to this instance only; a mutable class
          # attribute would be shared by every Resolver and would make
          # ``add_all`` meaningless.
          self._registerDict = {}

      def register(self, message: "Message", handler):
          """Associate ``handler`` with ``message.value``, replacing any previous one."""
          self._registerDict[message.value] = handler

      def unregister(self, message):
          """Remove and return the handler stored for ``message.value``.

          Raises:
              KeyError: if no handler is registered for that value.
          """
          return self._registerDict.pop(message.value)

      def resolve(self, data):
          """Decode ``data`` as JSON and look up the handler for its message type.

          Returns:
              A ``(handler, body)`` tuple; ``body`` is the ``__body`` field, or
              ``None`` when the field is absent.

          Raises:
              MessageResolveException: the payload is not a JSON object, lacks
                  ``__message``, or names a type without a registered handler.
              json.JSONDecodeError: ``data`` is not valid JSON.
          """
          message = json.loads(data)
          if not isinstance(message, dict):
              raise MessageResolveException("Received bytes is not a json object but a {}. {}".format(type(message), message))
          if "__message" not in message:
              raise MessageResolveException("Received message does not have required '__message' field. {}".format(message))
          message_type = message["__message"]
          if message_type not in self._registerDict:
              raise MessageResolveException("No handler for message type {}. {}".format(message_type, message))
          # dict.get yields None for a missing "__body", matching the contract above.
          return self._registerDict[message_type], message.get("__body")

      def make_handler(self, on_open=None, on_close=None):
          """Build the per-connection coroutine for the websocket server.

          Args:
              on_open: optional ``async (websocket, path)`` callback awaited
                  before the first message is read.
              on_close: optional ``async (websocket, path)`` callback awaited
                  when the receive loop ends for any reason.

          Returns:
              An ``async handler(websocket)`` coroutine function. Callbacks and
              message handlers receive the lower-cased request path.
          """
          async def handler(websocket):
              path = websocket.request.path
              route = path.lower()
              if on_open is not None:
                  await on_open(websocket, route)
              try:
                  while True:
                      message = await websocket.recv()
                      logger.debug("%s WebsocketMessage {%s}", path, message)
                      processor, body = self.resolve(message)
                      await processor(websocket, body, route)
              except (MessageResolveException, json.JSONDecodeError) as e:
                  # An unusable message ends the session; it is logged rather
                  # than escaping the handler as an unhandled error.
                  logger.exception(e)
              except ConnectionClosed:
                  # Regular end of a session; cleanup happens in ``finally``.
                  pass
              finally:
                  # Always notify the owner, otherwise a session ended by a bad
                  # message or a failing processor would never be cleaned up.
                  if on_close is not None:
                      await on_close(websocket, route)
          return handler

      def add_all(self, search_resolver: "Resolver"):
          """Copy every registration of ``search_resolver`` into this resolver.

          Entries already present under the same message value are overwritten.
          """
          self._registerDict.update(search_resolver._registerDict)
  def make_message(message_type, body=None):
      """Serialize an outgoing message envelope to a JSON string.

      The envelope carries ``message_type.value`` under ``__message`` and
      ``body`` (``None`` by default) under ``__body``.
      """
      envelope = {"__message": message_type.value}
      envelope["__body"] = body
      return json.dumps(envelope)
  def make_message_from_json_string(message_type, raw_body: str):
      """Build a message envelope around a body that is already JSON text.

      Args:
          message_type: object whose ``value`` names the message type; it is
              emitted as a JSON string.
          raw_body: JSON text inserted verbatim as the ``__body`` value. It is
              not validated here.

      Returns:
          The envelope as a JSON string.
      """
      # json.dumps adds the surrounding quotes and escapes quotes, backslashes
      # and control characters, so an unusual type name cannot break the
      # envelope. ensure_ascii=False keeps non-ASCII names as they are.
      type_literal = json.dumps(str(message_type.value), ensure_ascii=False)
      return "{{\"__message\": {}, \"__body\": {}}}".format(type_literal, raw_body)
  async def start_server(resolver: Resolver, on_new_connection, on_connection_close):
      """Run the websocket server until it is stopped.

      Connections are served by ``resolver.make_handler`` with
      ``on_new_connection`` and ``on_connection_close`` as its open/close
      callbacks. When ``ENABLE_WSS`` is set, the server uses a TLS context
      loaded from ``CERT_PATH`` and ``KEY_PATH``; otherwise no TLS is used.
      """
      tls = None
      if ENABLE_WSS:
          tls = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
          tls.load_cert_chain(pathlib.Path(CERT_PATH), pathlib.Path(KEY_PATH))

      connection_handler = resolver.make_handler(
          on_open=on_new_connection,
          on_close=on_connection_close,
      )
      async with serve(connection_handler, HOST, PORT, ssl=tls) as ws_server:
          await ws_server.serve_forever()