diff --git a/aircox_streamer/controllers/__init__.py b/aircox_streamer/controllers/__init__.py index 6fd5578..8b0bd03 100644 --- a/aircox_streamer/controllers/__init__.py +++ b/aircox_streamer/controllers/__init__.py @@ -3,13 +3,19 @@ from .metadata import Metadata, Request from .streamer import Streamer +from .streamers import Streamers from .sources import Source, PlaylistSource, QueueSource +streamers = Streamers() +"""Default controller used by views and viewsets.""" + + __all__ = ( "Metadata", "Request", "Streamer", + "Streamers", "Source", "PlaylistSource", "QueueSource", diff --git a/aircox_streamer/tests/conftest.py b/aircox_streamer/tests/conftest.py index 49ba3c9..791309e 100644 --- a/aircox_streamer/tests/conftest.py +++ b/aircox_streamer/tests/conftest.py @@ -1,3 +1,4 @@ +import itertools import os from datetime import datetime, time @@ -74,21 +75,53 @@ def station(): @pytest.fixture -def station_ports(station): - items = [ - models.Port( - station=station, - direction=models.Port.DIRECTION_INPUT, - type=models.Port.TYPE_HTTP, +def stations(station): + objs = [ + models.Station( + name=f"test-{i}", + slug=f"test-{i}", + path=working_dir, + default=(i == 0), active=True, - ), - models.Port( - station=station, - direction=models.Port.DIRECTION_OUTPUT, - type=models.Port.TYPE_FILE, - active=True, - ), + ) + for i in range(0, 3) ] + models.Station.objects.bulk_create(objs) + return [station] + objs + + +@pytest.fixture +def station_ports(station): + return _stations_ports(station) + + +@pytest.fixture +def stations_ports(stations): + return _stations_ports(*stations) + + +def _stations_ports(*stations): + items = list( + itertools.chain( + *[ + ( + models.Port( + station=station, + direction=models.Port.DIRECTION_INPUT, + type=models.Port.TYPE_HTTP, + active=True, + ), + models.Port( + station=station, + direction=models.Port.DIRECTION_OUTPUT, + type=models.Port.TYPE_FILE, + active=True, + ), + ) + for station in stations + ] + ) + ) models.Port.objects.bulk_create(items) return items @@ -180,3 +213,50 @@ def metadata_string(metadata_data): "\n".join(f"{key}={value}" for key, value in metadata_data.items()) + "\nEND" ) + + +# -- streamers +class FakeStreamer(controllers.Streamer): + calls = {} + + def fetch(self): + self.calls["fetch"] = True + + +class FakeSource(controllers.Source): + def __init__(self, id, *args, **kwargs): + self.id = id + self.args = args + self.kwargs = kwargs + self.calls = {} + + def fetch(self): + self.calls["sync"] = True + + def sync(self): + self.calls["sync"] = True + + def push(self, path): + self.calls["push"] = path + return path + + def skip(self): + self.calls["skip"] = True + + def restart(self): + self.calls["restart"] = True + + def seek(self, c): + self.calls["seek"] = c + + +@pytest.fixture +def streamers(stations, stations_ports): + streamers = controllers.Streamers(streamer_class=FakeStreamer) + # avoid unecessary db calls + streamers.streamers = { + station.pk: FakeStreamer(station) for station in stations + } + for streamer in streamers.values(): + streamer.sources = [FakeSource(i) for i in range(0, 3)] + return streamers diff --git a/aircox_streamer/tests/test_controllers_streamers.py b/aircox_streamer/tests/test_controllers_streamers.py new file mode 100644 index 0000000..b488043 --- /dev/null +++ b/aircox_streamer/tests/test_controllers_streamers.py @@ -0,0 +1,37 @@ +from datetime import timedelta + +from django.utils import timezone as tz +import pytest + + +class TestStreamers: + @pytest.fixture + def test___init__(self, streamers): + assert isinstance(streamers.timeout, timedelta) + + @pytest.fixture + def test_reset(self, streamers, stations): + streamers.reset() + assert all( + streamers.streamers[station.pk] == station for station in stations + ) + + @pytest.fixture + def test_fetch(self, streamers): + streamers.next_date = tz.now() - tz.timedelta(seconds=30) + streamers.streamers = None + + now = tz.now() + streamers.fetch() + + assert all(streamer.calls.get("fetch") for streamer in streamers) + assert streamers.next_date > now + + @pytest.fixture + def test_fetch_timeout_not_reached(self, streamers): + next_date = tz.now() + tz.timedelta(seconds=30) + streamers.next_date = next_date + + streamers.fetch() + assert all(not streamer.calls.get("fetch") for streamer in streamers) + assert streamers.next_date == next_date diff --git a/aircox_streamer/tests/test_viewsets.py b/aircox_streamer/tests/test_viewsets.py new file mode 100644 index 0000000..5493aaa --- /dev/null +++ b/aircox_streamer/tests/test_viewsets.py @@ -0,0 +1,185 @@ +import pytest + +from django.http import Http404 + +from rest_framework.exceptions import ValidationError +from aircox_streamer.viewsets import ( + ControllerViewSet, + SourceViewSet, + StreamerViewSet, + QueueSourceViewSet, +) + + +class FakeSerializer: + def __init__(self, instance, *args, **kwargs): + self.instance = instance + self.data = {"instance": self.instance} + self.args = args + self.kwargs = kwargs + + +class FakeRequest: + def __init__(self, **kwargs): + self.__dict__.update(**kwargs) + + +@pytest.fixture +def controller_viewset(streamers, station): + return ControllerViewSet( + streamers=streamers, + streamer=streamers[station.pk], + serializer_class=FakeSerializer, + ) + + +@pytest.fixture +def streamer_viewset(streamers, station): + return StreamerViewSet( + streamers=streamers, + streamer=streamers[station.pk], + serializer_class=FakeSerializer, + ) + + +@pytest.fixture +def source_viewset(streamers, station): + return SourceViewSet( + streamers=streamers, + streamer=streamers[station.pk], + serializer_class=FakeSerializer, + ) + + +@pytest.fixture +def queue_source_viewset(streamers, station): + return QueueSourceViewSet( + streamers=streamers, + streamer=streamers[station.pk], + serializer_class=FakeSerializer, + ) + + +class TestControllerViewSet: + @pytest.mark.django_db + def test_get_streamer(self, controller_viewset, stations): + station = stations[0] + streamer = controller_viewset.get_streamer(station.pk) + assert streamer.station.pk == station.pk + assert streamer.calls.get("fetch") + + @pytest.mark.django_db + def test_get_streamer_station_not_found(self, controller_viewset): + controller_viewset.streamers.streamers = {} + with pytest.raises(Http404): + controller_viewset.get_streamer(1) + + @pytest.mark.django_db + def test_get_serializer(self, controller_viewset): + controller_viewset.object = {"object": "value"} + serializer = controller_viewset.get_serializer(test=True) + assert serializer.kwargs["test"] + assert serializer.instance == controller_viewset.object + + @pytest.mark.django_db + def test_serialize(self, controller_viewset): + instance = {} + data = controller_viewset.serialize(instance, test=True) + assert data == {"instance": instance} + + +class TestStreamerViewSet: + @pytest.mark.django_db + def test_retrieve(self, streamer_viewset): + streamer_viewset.streamer = {"streamer": "test"} + resp = streamer_viewset.retrieve(None, None) + assert resp.data == {"instance": streamer_viewset.streamer} + + @pytest.mark.django_db + def test_list(self, streamer_viewset): + streamers = {"a": 1, "b": 2} + streamer_viewset.streamers.streamers = streamers + resp = streamer_viewset.list(None) + assert set(resp.data["results"]["instance"]) == set(streamers.values()) + + +class TestSourceViewSet: + @pytest.mark.django_db + def test_get_sources(self, source_viewset, streamers): + source_viewset.streamer.sources.append(45) + sources = source_viewset.get_sources() + assert 45 not in set(sources) + + @pytest.mark.django_db + def test_get_source(self, source_viewset): + source = source_viewset.get_source(1) + assert source.id == 1 + + @pytest.mark.django_db + def test_get_source_not_found(self, source_viewset): + with pytest.raises(Http404): + source_viewset.get_source(1000) + + @pytest.mark.django_db + def test_retrieve(self, source_viewset, station): + resp = source_viewset.retrieve(None, 0) + source = source_viewset.streamers[station.pk].sources[0] + # this is FakeSerializer being used which provides us the proof + assert resp.data["instance"] == source + + @pytest.mark.django_db + def test_list(self, source_viewset, station): + sources = source_viewset.streamers[station.pk].sources + resp = source_viewset.list(None) + assert list(resp.data["results"]["instance"]) == sources + + @pytest.mark.django_db + def test__run(self, source_viewset): + calls = {} + + def action(x): + return calls.setdefault("action", True) + + source_viewset._run(0, action) + assert calls.get("action") + + @pytest.mark.django_db + def test_all_api_source_actions(self, source_viewset, station): + source = source_viewset.streamers[station.pk].sources[0] + request = FakeRequest(POST={"seek": 1}) + source_viewset.get_source = lambda x: source + + for action in ("sync", "skip", "restart", "seek"): + func = getattr(source_viewset, action) + func(request, 1) + assert source.calls.get(action) + + +class TestQueueSourceViewSet: + @pytest.mark.django_db + def test_get_sound_queryset(self, queue_source_viewset, station, sounds): + ids = {sound.pk for sound in sounds} + request = FakeRequest(station=station) + query = queue_source_viewset.get_sound_queryset(request) + assert set(query.values_list("pk", flat=True)) == ids + + @pytest.mark.django_db + def test_push(self, queue_source_viewset, station, sounds): + calls = {} + sound = sounds[0] + request = FakeRequest(station=station, data={"sound_id": sound.pk}) + queue_source_viewset._run = lambda pk, func: calls.setdefault( + "_run", (pk, func) + ) + result = queue_source_viewset.push(request, 13) + assert "_run" in calls + assert result[0] == 13 + assert callable(result[1]) + + @pytest.mark.django_db + def test_push_missing_sound_in_request_post( + self, queue_source_viewset, station + ): + request = FakeRequest(station=station, data={}) + with pytest.raises(ValidationError): + queue_source_viewset.push(request, 0) diff --git a/aircox_streamer/urls.py b/aircox_streamer/urls.py index 2a159b6..fdd86ec 100644 --- a/aircox_streamer/urls.py +++ b/aircox_streamer/urls.py @@ -4,11 +4,11 @@ from django.utils.translation import gettext_lazy as _ from aircox.viewsets import SoundViewSet from . import viewsets -from .views import StreamerAdminMixin +from .views import StreamerAdminView admin.site.route_view( "tools/streamer", - StreamerAdminMixin.as_view(), + StreamerAdminView.as_view(), "tools-streamer", label=_("Streamer Monitor"), ) diff --git a/aircox_streamer/views.py b/aircox_streamer/views.py index 03df665..43832e0 100644 --- a/aircox_streamer/views.py +++ b/aircox_streamer/views.py @@ -2,8 +2,17 @@ from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView from aircox.views.admin import AdminMixin +from .controllers import streamers -class StreamerAdminMixin(AdminMixin, TemplateView): +class StreamerAdminView(AdminMixin, TemplateView): template_name = "aircox_streamer/streamer.html" title = _("Streamer Monitor") + streamers = streamers + + def dispatch(self, *args, **kwargs): + # Note: this might raise concurrency racing problem with viewsets, + # since streamers.streamers is reset to a new dict. Still am i not + # sure, and needs analysis. + self.streamers.reset() + return super().dispatch(*args, **kwargs) diff --git a/aircox_streamer/viewsets.py b/aircox_streamer/viewsets.py index 1a11be9..bd63567 100644 --- a/aircox_streamer/viewsets.py +++ b/aircox_streamer/viewsets.py @@ -1,15 +1,15 @@ from django.http import Http404 from django.shortcuts import get_object_or_404 -from django.utils import timezone as tz from rest_framework import viewsets from rest_framework.decorators import action from rest_framework.exceptions import ValidationError from rest_framework.permissions import IsAdminUser from rest_framework.response import Response -from aircox.models import Sound, Station +from aircox.models import Sound from . import controllers + from .serializers import ( PlaylistSerializer, QueueSourceSerializer, @@ -19,8 +19,7 @@ from .serializers import ( ) __all__ = [ - "Streamers", - "BaseControllerAPIView", + "ControllerViewSet", "RequestViewSet", "StreamerViewSet", "SourceViewSet", @@ -29,94 +28,45 @@ __all__ = [ ] -class Streamers: - date = None - """Next update datetime.""" - streamers = None - """Stations by station id.""" - timeout = None - """Timedelta to next update.""" - - def __init__(self, timeout=None): - self.timeout = timeout or tz.timedelta(seconds=2) - - def load(self, force=False): - # FIXME: cf. TODO in aircox.controllers about model updates - stations = Station.objects.active() - if self.streamers is None or force: - self.streamers = { - station.pk: controllers.Streamer(station) - for station in stations - } - return - - streamers = self.streamers - self.streamers = { - station.pk: controllers.Streamer(station) - if station.pk in streamers - else streamers[station.pk] - for station in stations - } - - def fetch(self): - if self.streamers is None: - self.load() - - now = tz.now() - if self.date is not None and now < self.date: - return - - for streamer in self.streamers.values(): - streamer.fetch() - self.date = now + self.timeout - - def get(self, key, default=None): - return self.streamers.get(key, default) - - def values(self): - return self.streamers.values() - - def __getitem__(self, key): - return self.streamers[key] - - def __contains__(self, key): - return key in self.streamers - - -streamers = Streamers() - - -class BaseControllerAPIView(viewsets.ViewSet): +class ControllerViewSet(viewsets.ViewSet): permission_classes = (IsAdminUser,) serializer_class = None + streamers = controllers.streamers + """Streamers controller instance.""" streamer = None + """User's Streamer instance.""" object = None + """Object to serialize.""" - def get_streamer(self, request, station_pk=None, **kwargs): - streamers.fetch() - id = int(request.station.pk if station_pk is None else station_pk) - if id not in streamers: + def get_streamer(self, station_pk=None): + """Get user's streamer.""" + if station_pk is None: + station_pk = self.request.station.pk + self.streamers.fetch() + if station_pk not in self.streamers: raise Http404("station not found") - return streamers[id] + return self.streamers[station_pk] def get_serializer(self, **kwargs): + """Get serializer instance.""" return self.serializer_class(self.object, **kwargs) def serialize(self, obj, **kwargs): + """Serializer controller data.""" self.object = obj serializer = self.get_serializer(**kwargs) return serializer.data def dispatch(self, request, *args, station_pk=None, **kwargs): - self.streamer = self.get_streamer(request, station_pk, **kwargs) + self.streamer = self.get_streamer(station_pk) return super().dispatch(request, *args, **kwargs) -class RequestViewSet(BaseControllerAPIView): +class RequestViewSet(ControllerViewSet): serializer_class = RequestSerializer -class StreamerViewSet(BaseControllerAPIView): +class StreamerViewSet(ControllerViewSet): serializer_class = StreamerSerializer def retrieve(self, request, pk=None): @@ -124,7 +74,7 @@ class StreamerViewSet(BaseControllerAPIView): def list(self, request, pk=None): return Response( - {"results": self.serialize(streamers.values(), many=True)} + {"results": self.serialize(self.streamers.values(), many=True)} ) def dispatch(self, request, *args, pk=None, **kwargs): @@ -135,7 +85,7 @@ class StreamerViewSet(BaseControllerAPIView): return super().dispatch(request, *args, **kwargs) -class SourceViewSet(BaseControllerAPIView): +class SourceViewSet(ControllerViewSet): serializer_class = SourceSerializer model = controllers.Source @@ -151,8 +101,8 @@ class SourceViewSet(BaseControllerAPIView): return source def retrieve(self, request, pk=None): - self.object = self.get_source(pk) - return Response(self.serialize()) + source = self.get_source(pk) + return Response(self.serialize(source)) def list(self, request): return Response( @@ -192,8 +142,8 @@ class QueueSourceViewSet(SourceViewSet): serializer_class = QueueSourceSerializer model = controllers.QueueSource - def get_sound_queryset(self): - return Sound.objects.station(self.request.station).archive() + def get_sound_queryset(self, request): + return Sound.objects.station(request.station).archive() @action(detail=True, methods=["POST"]) def push(self, request, pk): @@ -201,7 +151,7 @@ class QueueSourceViewSet(SourceViewSet): raise ValidationError('missing "sound_id" POST data') sound = get_object_or_404( - self.get_sound_queryset(), pk=request.data["sound_id"] + self.get_sound_queryset(request), pk=request.data["sound_id"] ) return self._run( pk, lambda s: s.push(sound.file.path) if sound.file.path else None