Browse Source

Fix bugs introduced by invalid video id fix:

* Rooms break when reaching end of playlist
Soof 6 days ago
parent
commit
1a1800aa24
3 changed files with 31 additions and 21 deletions
  1. 10 7
      channel.py
  2. 14 10
      chube.py
  3. 7 4
      chube_ws.py

+ 10 - 7
channel.py

@@ -1,27 +1,27 @@
 from typing import Dict
 
-from websockets import WebSocketServerProtocol
-
+from websockets.asyncio.server import ServerConnection
+from websockets.exceptions import ConnectionClosed
 
 class Subscriber:
     player_enabled = False
-    ws: WebSocketServerProtocol
+    ws: ServerConnection
 
     def __init__(self, ws):
         self.ws = ws
 
 
 class Channel:
-    subscribers: Dict[WebSocketServerProtocol, Subscriber]
+    subscribers: Dict[ServerConnection, Subscriber]
 
     def __init__(self):
         self.subscribers = dict()
 
-    def subscribe(self, ws: WebSocketServerProtocol):
+    def subscribe(self, ws: ServerConnection):
         if ws not in self.subscribers:
             self.subscribers[ws] = (Subscriber(ws))
 
-    def unsubscribe(self, ws: WebSocketServerProtocol):
+    def unsubscribe(self, ws: ServerConnection):
         if ws in self.subscribers:
             self.subscribers.pop(ws)
 
@@ -30,7 +30,10 @@ class Channel:
 
     async def send(self, message):
         for sub in list(self.subscribers.values()):
-            await sub.ws.send(message)
+            try:
+                await sub.ws.send(message)
+            except ConnectionClosed:
+                self.unsubscribe(sub.ws)
             # else:
             #     print("closed ws still in channel")
             #     self.unsubscribe(sub.ws)

+ 14 - 10
chube.py

@@ -6,6 +6,7 @@ from typing import Optional, Iterator, Dict, List
 
 import sys
 from itertools import cycle
+from websockets.asyncio.server import ServerConnection
 
 import chube_youtube
 from channel import Channel, Subscriber
@@ -131,7 +132,10 @@ class Playback:
     def set_song(self, song):
         with self.lock:
             self._song = song
-            logger.debug("Playback %s: Set song to %d", self, song["id"])
+            if song is not None:
+                logger.debug("Playback %s: Set song to %d", self, song["id"])
+            else:
+                logger.debug("Playback %s: finished last song", self)
 
     def get_song(self):
         with self.lock:
@@ -177,7 +181,7 @@ class Room:
 rooms: Dict[str, Room] = dict()
 
 
-async def request_state_processor(ws, _, path):
+async def request_state_processor(ws: ServerConnection, _, path):
     room = rooms[path]
     state = {
         "lists": room.chueue.as_lists(),
@@ -290,10 +294,10 @@ async def obtain_control_processor(ws, data, path):
     await obtain_control(ws, room)
 
 
-async def release_control_processor(ws, data, path):
+async def release_control_processor(ws: ServerConnection, data, path):
     room = rooms[path]
     if len(room.channel.subscribers) > 1:
-        await release_control(ws, room)
+        await release_control(ws, False, room)
     else:
         pass
 
@@ -329,12 +333,12 @@ async def player_enabled_processor(ws, data, path):
             if room.get_controller() is None:
                 await obtain_control(ws, room)
     else:
-        await release_control(ws, room)
+        await release_control(ws, False, room)
 
 
 # TODO change OBTAIN_CONTROL en RELEASE_CONTROL to one message
 # TODO There is some potential concurrent bug here, when the controller loses/releases control right before a song end.
-async def obtain_control(ws, room: Room):
+async def obtain_control(ws: ServerConnection, room: Room):
     with room.controller_lock:
         controller = room.get_controller()
         if controller is None or controller.ws is not ws:
@@ -344,7 +348,7 @@ async def obtain_control(ws, room: Room):
                 await controller.ws.send(make_message(Message.RELEASE_CONTROL))
 
 
-async def release_control(ws, room: Room):
+async def release_control(ws: ServerConnection, ws_disconnected: bool, room: Room):
     with room.controller_lock:
         controller = room.get_controller()
         if controller is not None and controller.ws is ws:
@@ -352,8 +356,8 @@ async def release_control(ws, room: Room):
             room.set_controller(controller)
             if controller is not None:
                 await controller.ws.send(make_message(Message.OBTAIN_CONTROL))
-            # if ws.open:
-            await ws.send(make_message(Message.RELEASE_CONTROL))
+            if not ws_disconnected:
+                await ws.send(make_message(Message.RELEASE_CONTROL))
 
 
 async def on_connect(ws, path):
@@ -371,7 +375,7 @@ async def on_connect(ws, path):
 async def on_disconnect(ws, path):
     room = rooms[path]
     room.channel.unsubscribe(ws)
-    await release_control(ws, room)
+    await release_control(ws, True, room)
     print("Currently {} user{} {} using room {}".format(
         len(room.channel.subscribers),
         "s" if len(room.channel.subscribers) != 1 else "",

+ 7 - 4
chube_ws.py

@@ -3,9 +3,10 @@ import json
 import logging
 import os
 import pathlib
+import signal
 import ssl
 
-from websockets.asyncio.server import serve
+from websockets.asyncio.server import serve, ServerConnection
 from websockets.exceptions import ConnectionClosed
 
 from chube_enums import Message
@@ -60,15 +61,15 @@ class Resolver:
             return self._registerDict[message_type], message["__body"]
 
     def make_handler(self, on_open=None, on_close=None):
-        async def on_open_handler(websocket, path):
+        async def on_open_handler(websocket: ServerConnection, path):
             if on_open is not None:
                 await on_open(websocket, path)
 
-        async def on_close_handler(websocket, path):
+        async def on_close_handler(websocket: ServerConnection, path):
             if on_close is not None:
                 await on_close(websocket, path)
 
-        async def handler(websocket):
+        async def handler(websocket: ServerConnection):
             path = websocket.request.path
             await on_open_handler(websocket, path.lower())
             try:
@@ -98,6 +99,7 @@ def make_message_from_json_string(message_type, raw_body: str):
 
 
 async def start_server(resolver: Resolver, on_new_connection, on_connection_close):
+    loop = asyncio.get_running_loop()
     ssl_context = None
     if ENABLE_WSS:
         ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@@ -108,5 +110,6 @@ async def start_server(resolver: Resolver, on_new_connection, on_connection_clos
     async with serve(
             resolver.make_handler(on_open=on_new_connection, on_close=on_connection_close),
             HOST, PORT, ssl=ssl_context) as server:
+        loop.add_signal_handler(signal.SIGINT, lambda s: s.close(), server)
         await server.serve_forever()