chube_ws.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import asyncio
  2. import json
  3. import os
  4. import pathlib
  5. import ssl
  6. import websockets
  7. from chube_enums import Message
  8. PORT = os.environ.get("CHUBE_WS_PORT") or 3821 # CHU
  9. HOST = os.environ.get("CHUBE_WS_HOST") or "localhost"
  10. ENABLE_WSS = os.environ.get("CHUBE_NO_WSS") != 1
  11. CERT_PATH = os.environ.get("CHUBE_CERT_PATH")
  12. KEY_PATH = os.environ.get("CHUBE_KEY_PATH")
  13. if ENABLE_WSS and (CERT_PATH is None or not os.path.isfile(CERT_PATH)):
  14. raise Exception("WSS is enabled but no valid certificate is provided. To disable WSS provide the CHUBE_NO_WSS=1 "
  15. "environment variable.\nProvided certificate path is {}".format(CERT_PATH))
  16. if ENABLE_WSS and (CERT_PATH is None or not os.path.isfile(KEY_PATH)):
  17. raise Exception("WSS is enabled but no valid key is provided. To disable WSS provide the CHUBE_NO_WSS=1 "
  18. "environment variable.\nProvided key path is {}".format(KEY_PATH))
  19. class Resolver:
  20. _registerDict: dict = {}
  21. def register(self, message: Message, handler):
  22. self._registerDict[message.value] = handler
  23. def unregister(self, message):
  24. return self._registerDict.pop(message.value)
  25. def resolve(self, data):
  26. message = json.loads(data)
  27. if not isinstance(message, dict):
  28. raise Exception("Received bytes is not a json object but a {}. {}".format(type(message), message))
  29. if "__message" not in message:
  30. raise Exception("Received message does not have required '__message' field. {}".format(message))
  31. message_type = message["__message"]
  32. if message_type not in self._registerDict:
  33. raise Exception("No handler for message type {}. {}".format(message_type, message))
  34. if "__body" not in message:
  35. return self._registerDict[message_type], None
  36. else:
  37. return self._registerDict[message_type], message["__body"]
  38. def make_handler(self, on_open=None, on_close=None):
  39. async def on_open_handler(websocket, path):
  40. if on_open is not None:
  41. await on_open(websocket, path)
  42. async def on_close_handler(websocket, path):
  43. if on_close is not None:
  44. await on_close(websocket, path)
  45. async def handler(websocket, path):
  46. await on_open_handler(websocket, path.lower())
  47. try:
  48. while True:
  49. message = await websocket.recv()
  50. processor, body = self.resolve(message)
  51. await processor(websocket, body, path.lower())
  52. except websockets.ConnectionClosed:
  53. await on_close_handler(websocket, path.lower())
  54. return handler
  55. def add_all(self, search_resolver: "Resolver"):
  56. for message, handler in search_resolver._registerDict.items():
  57. self._registerDict[message] = handler
  58. def make_message(message_type, body=None):
  59. return json.dumps({"__message": message_type.value, "__body": body})
  60. def make_message_from_json_string(message_type, raw_body: str):
  61. return "{{\"__message\": \"{}\", \"__body\": {}}}".format(message_type.value, raw_body)
  62. def start_server(resolver: Resolver, on_new_connection, on_connection_close):
  63. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  64. cert_pem = pathlib.Path(CERT_PATH)
  65. key_pem = pathlib.Path(KEY_PATH)
  66. ssl_context.load_cert_chain(cert_pem, key_pem)
  67. ws_server = websockets.serve(
  68. resolver.make_handler(on_open=on_new_connection, on_close=on_connection_close),
  69. HOST, PORT, ssl=ssl_context)
  70. asyncio.get_event_loop().run_until_complete(ws_server)
  71. asyncio.get_event_loop().run_forever()