#106: tests: aircox_streamer #110

Merged
thomas merged 8 commits from dev-1.0-106-test-aircox-streamer into develop-1.0 2023-06-18 15:00:09 +00:00
7 changed files with 360 additions and 93 deletions
Showing only changes of commit a7f39c3628 - Show all commits

View File

@ -3,13 +3,19 @@
from .metadata import Metadata, Request from .metadata import Metadata, Request
from .streamer import Streamer from .streamer import Streamer
from .streamers import Streamers
from .sources import Source, PlaylistSource, QueueSource from .sources import Source, PlaylistSource, QueueSource
streamers = Streamers()
"""Default controller used by views and viewsets."""
__all__ = ( __all__ = (
"Metadata", "Metadata",
"Request", "Request",
"Streamer", "Streamer",
"Streamers",
"Source", "Source",
"PlaylistSource", "PlaylistSource",
"QueueSource", "QueueSource",

View File

@ -1,3 +1,4 @@
import itertools
import os import os
from datetime import datetime, time from datetime import datetime, time
@ -74,21 +75,53 @@ def station():
@pytest.fixture @pytest.fixture
def station_ports(station): def stations(station):
items = [ objs = [
models.Port( models.Station(
station=station, name=f"test-{i}",
direction=models.Port.DIRECTION_INPUT, slug=f"test-{i}",
type=models.Port.TYPE_HTTP, path=working_dir,
default=(i == 0),
active=True, active=True,
), )
models.Port( for i in range(0, 3)
station=station,
direction=models.Port.DIRECTION_OUTPUT,
type=models.Port.TYPE_FILE,
active=True,
),
] ]
models.Station.objects.bulk_create(objs)
return [station] + objs
@pytest.fixture
def station_ports(station):
return _stations_ports(station)
@pytest.fixture
def stations_ports(stations):
return _stations_ports(*stations)
def _stations_ports(*stations):
items = list(
itertools.chain(
*[
(
models.Port(
station=station,
direction=models.Port.DIRECTION_INPUT,
type=models.Port.TYPE_HTTP,
active=True,
),
models.Port(
station=station,
direction=models.Port.DIRECTION_OUTPUT,
type=models.Port.TYPE_FILE,
active=True,
),
)
for station in stations
]
)
)
models.Port.objects.bulk_create(items) models.Port.objects.bulk_create(items)
return items return items
@ -180,3 +213,50 @@ def metadata_string(metadata_data):
"\n".join(f"{key}={value}" for key, value in metadata_data.items()) "\n".join(f"{key}={value}" for key, value in metadata_data.items())
+ "\nEND" + "\nEND"
) )
# -- streamers
class FakeStreamer(controllers.Streamer):
calls = {}
def fetch(self):
self.calls["fetch"] = True
class FakeSource(controllers.Source):
def __init__(self, id, *args, **kwargs):
self.id = id
self.args = args
self.kwargs = kwargs
self.calls = {}
def fetch(self):
self.calls["sync"] = True
def sync(self):
self.calls["sync"] = True
def push(self, path):
self.calls["push"] = path
return path
def skip(self):
self.calls["skip"] = True
def restart(self):
self.calls["restart"] = True
def seek(self, c):
self.calls["seek"] = c
@pytest.fixture
def streamers(stations, stations_ports):
streamers = controllers.Streamers(streamer_class=FakeStreamer)
# avoid unecessary db calls
streamers.streamers = {
station.pk: FakeStreamer(station) for station in stations
}
for streamer in streamers.values():
streamer.sources = [FakeSource(i) for i in range(0, 3)]
return streamers

View File

@ -0,0 +1,37 @@
from datetime import timedelta
from django.utils import timezone as tz
import pytest
class TestStreamers:
@pytest.fixture
def test___init__(self, streamers):
assert isinstance(streamers.timeout, timedelta)
@pytest.fixture
def test_reset(self, streamers, stations):
streamers.reset()
assert all(
streamers.streamers[station.pk] == station for station in stations
)
@pytest.fixture
def test_fetch(self, streamers):
streamers.next_date = tz.now() - tz.timedelta(seconds=30)
streamers.streamers = None
now = tz.now()
streamers.fetch()
assert all(streamer.calls.get("fetch") for streamer in streamers)
assert streamers.next_date > now
@pytest.fixture
def test_fetch_timeout_not_reached(self, streamers):
next_date = tz.now() + tz.timedelta(seconds=30)
streamers.next_date = next_date
streamers.fetch()
assert all(not streamer.calls.get("fetch") for streamer in streamers)
assert streamers.next_date == next_date

View File

@ -0,0 +1,185 @@
import pytest
from django.http import Http404
from rest_framework.exceptions import ValidationError
from aircox_streamer.viewsets import (
ControllerViewSet,
SourceViewSet,
StreamerViewSet,
QueueSourceViewSet,
)
class FakeSerializer:
def __init__(self, instance, *args, **kwargs):
self.instance = instance
self.data = {"instance": self.instance}
self.args = args
self.kwargs = kwargs
class FakeRequest:
def __init__(self, **kwargs):
self.__dict__.update(**kwargs)
@pytest.fixture
def controller_viewset(streamers, station):
return ControllerViewSet(
streamers=streamers,
streamer=streamers[station.pk],
serializer_class=FakeSerializer,
)
@pytest.fixture
def streamer_viewset(streamers, station):
return StreamerViewSet(
streamers=streamers,
streamer=streamers[station.pk],
serializer_class=FakeSerializer,
)
@pytest.fixture
def source_viewset(streamers, station):
return SourceViewSet(
streamers=streamers,
streamer=streamers[station.pk],
serializer_class=FakeSerializer,
)
@pytest.fixture
def queue_source_viewset(streamers, station):
return QueueSourceViewSet(
streamers=streamers,
streamer=streamers[station.pk],
serializer_class=FakeSerializer,
)
class TestControllerViewSet:
@pytest.mark.django_db
def test_get_streamer(self, controller_viewset, stations):
station = stations[0]
streamer = controller_viewset.get_streamer(station.pk)
assert streamer.station.pk == station.pk
assert streamer.calls.get("fetch")
@pytest.mark.django_db
def test_get_streamer_station_not_found(self, controller_viewset):
controller_viewset.streamers.streamers = {}
with pytest.raises(Http404):
controller_viewset.get_streamer(1)
@pytest.mark.django_db
def test_get_serializer(self, controller_viewset):
controller_viewset.object = {"object": "value"}
serializer = controller_viewset.get_serializer(test=True)
assert serializer.kwargs["test"]
assert serializer.instance == controller_viewset.object
@pytest.mark.django_db
def test_serialize(self, controller_viewset):
instance = {}
data = controller_viewset.serialize(instance, test=True)
assert data == {"instance": instance}
class TestStreamerViewSet:
@pytest.mark.django_db
def test_retrieve(self, streamer_viewset):
streamer_viewset.streamer = {"streamer": "test"}
resp = streamer_viewset.retrieve(None, None)
assert resp.data == {"instance": streamer_viewset.streamer}
@pytest.mark.django_db
def test_list(self, streamer_viewset):
streamers = {"a": 1, "b": 2}
streamer_viewset.streamers.streamers = streamers
resp = streamer_viewset.list(None)
assert set(resp.data["results"]["instance"]) == set(streamers.values())
class TestSourceViewSet:
@pytest.mark.django_db
def test_get_sources(self, source_viewset, streamers):
source_viewset.streamer.sources.append(45)
sources = source_viewset.get_sources()
assert 45 not in set(sources)
@pytest.mark.django_db
def test_get_source(self, source_viewset):
source = source_viewset.get_source(1)
assert source.id == 1
@pytest.mark.django_db
def test_get_source_not_found(self, source_viewset):
with pytest.raises(Http404):
source_viewset.get_source(1000)
@pytest.mark.django_db
def test_retrieve(self, source_viewset, station):
resp = source_viewset.retrieve(None, 0)
source = source_viewset.streamers[station.pk].sources[0]
# this is FakeSerializer being used which provides us the proof
assert resp.data["instance"] == source
@pytest.mark.django_db
def test_list(self, source_viewset, station):
sources = source_viewset.streamers[station.pk].sources
resp = source_viewset.list(None)
assert list(resp.data["results"]["instance"]) == sources
@pytest.mark.django_db
def test__run(self, source_viewset):
calls = {}
def action(x):
return calls.setdefault("action", True)
source_viewset._run(0, action)
assert calls.get("action")
@pytest.mark.django_db
def test_all_api_source_actions(self, source_viewset, station):
source = source_viewset.streamers[station.pk].sources[0]
request = FakeRequest(POST={"seek": 1})
source_viewset.get_source = lambda x: source
for action in ("sync", "skip", "restart", "seek"):
func = getattr(source_viewset, action)
func(request, 1)
assert source.calls.get(action)
class TestQueueSourceViewSet:
@pytest.mark.django_db
def test_get_sound_queryset(self, queue_source_viewset, station, sounds):
ids = {sound.pk for sound in sounds}
request = FakeRequest(station=station)
query = queue_source_viewset.get_sound_queryset(request)
assert set(query.values_list("pk", flat=True)) == ids
@pytest.mark.django_db
def test_push(self, queue_source_viewset, station, sounds):
calls = {}
sound = sounds[0]
request = FakeRequest(station=station, data={"sound_id": sound.pk})
queue_source_viewset._run = lambda pk, func: calls.setdefault(
"_run", (pk, func)
)
result = queue_source_viewset.push(request, 13)
assert "_run" in calls
assert result[0] == 13
assert callable(result[1])
@pytest.mark.django_db
def test_push_missing_sound_in_request_post(
self, queue_source_viewset, station
):
request = FakeRequest(station=station, data={})
with pytest.raises(ValidationError):
queue_source_viewset.push(request, 0)

View File

@ -4,11 +4,11 @@ from django.utils.translation import gettext_lazy as _
from aircox.viewsets import SoundViewSet from aircox.viewsets import SoundViewSet
from . import viewsets from . import viewsets
from .views import StreamerAdminMixin from .views import StreamerAdminView
admin.site.route_view( admin.site.route_view(
"tools/streamer", "tools/streamer",
StreamerAdminMixin.as_view(), StreamerAdminView.as_view(),
"tools-streamer", "tools-streamer",
label=_("Streamer Monitor"), label=_("Streamer Monitor"),
) )

View File

@ -2,8 +2,17 @@ from django.utils.translation import gettext_lazy as _
from django.views.generic import TemplateView from django.views.generic import TemplateView
from aircox.views.admin import AdminMixin from aircox.views.admin import AdminMixin
from .controllers import streamers
class StreamerAdminMixin(AdminMixin, TemplateView): class StreamerAdminView(AdminMixin, TemplateView):
template_name = "aircox_streamer/streamer.html" template_name = "aircox_streamer/streamer.html"
title = _("Streamer Monitor") title = _("Streamer Monitor")
streamers = streamers
def dispatch(self, *args, **kwargs):
# Note: this might raise concurrency racing problem with viewsets,
# since streamers.streamers is reset to a new dict. Still am i not
# sure, and needs analysis.
self.streamers.reset()
return super().dispatch(*args, **kwargs)

View File

@ -1,15 +1,15 @@
from django.http import Http404 from django.http import Http404
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils import timezone as tz
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAdminUser from rest_framework.permissions import IsAdminUser
from rest_framework.response import Response from rest_framework.response import Response
from aircox.models import Sound, Station from aircox.models import Sound
from . import controllers from . import controllers
from .serializers import ( from .serializers import (
PlaylistSerializer, PlaylistSerializer,
QueueSourceSerializer, QueueSourceSerializer,
@ -19,8 +19,7 @@ from .serializers import (
) )
__all__ = [ __all__ = [
"Streamers", "ControllerViewSet",
"BaseControllerAPIView",
"RequestViewSet", "RequestViewSet",
"StreamerViewSet", "StreamerViewSet",
"SourceViewSet", "SourceViewSet",
@ -29,94 +28,45 @@ __all__ = [
] ]
class Streamers: class ControllerViewSet(viewsets.ViewSet):
date = None
"""Next update datetime."""
streamers = None
"""Stations by station id."""
timeout = None
"""Timedelta to next update."""
def __init__(self, timeout=None):
self.timeout = timeout or tz.timedelta(seconds=2)
def load(self, force=False):
# FIXME: cf. TODO in aircox.controllers about model updates
stations = Station.objects.active()
if self.streamers is None or force:
self.streamers = {
station.pk: controllers.Streamer(station)
for station in stations
}
return
streamers = self.streamers
self.streamers = {
station.pk: controllers.Streamer(station)
if station.pk in streamers
else streamers[station.pk]
for station in stations
}
def fetch(self):
if self.streamers is None:
self.load()
now = tz.now()
if self.date is not None and now < self.date:
return
for streamer in self.streamers.values():
streamer.fetch()
self.date = now + self.timeout
def get(self, key, default=None):
return self.streamers.get(key, default)
def values(self):
return self.streamers.values()
def __getitem__(self, key):
return self.streamers[key]
def __contains__(self, key):
return key in self.streamers
streamers = Streamers()
class BaseControllerAPIView(viewsets.ViewSet):
permission_classes = (IsAdminUser,) permission_classes = (IsAdminUser,)
serializer_class = None serializer_class = None
streamers = controllers.streamers
"""Streamers controller instance."""
streamer = None streamer = None
"""User's Streamer instance."""
object = None object = None
"""Object to serialize."""
def get_streamer(self, request, station_pk=None, **kwargs): def get_streamer(self, station_pk=None):
streamers.fetch() """Get user's streamer."""
id = int(request.station.pk if station_pk is None else station_pk) if station_pk is None:
if id not in streamers: station_pk = self.request.station.pk
self.streamers.fetch()
if station_pk not in self.streamers:
raise Http404("station not found") raise Http404("station not found")
return streamers[id] return self.streamers[station_pk]
def get_serializer(self, **kwargs): def get_serializer(self, **kwargs):
"""Get serializer instance."""
return self.serializer_class(self.object, **kwargs) return self.serializer_class(self.object, **kwargs)
def serialize(self, obj, **kwargs): def serialize(self, obj, **kwargs):
"""Serializer controller data."""
self.object = obj self.object = obj
serializer = self.get_serializer(**kwargs) serializer = self.get_serializer(**kwargs)
return serializer.data return serializer.data
def dispatch(self, request, *args, station_pk=None, **kwargs): def dispatch(self, request, *args, station_pk=None, **kwargs):
self.streamer = self.get_streamer(request, station_pk, **kwargs) self.streamer = self.get_streamer(station_pk)
return super().dispatch(request, *args, **kwargs) return super().dispatch(request, *args, **kwargs)
class RequestViewSet(BaseControllerAPIView): class RequestViewSet(ControllerViewSet):
serializer_class = RequestSerializer serializer_class = RequestSerializer
class StreamerViewSet(BaseControllerAPIView): class StreamerViewSet(ControllerViewSet):
serializer_class = StreamerSerializer serializer_class = StreamerSerializer
def retrieve(self, request, pk=None): def retrieve(self, request, pk=None):
@ -124,7 +74,7 @@ class StreamerViewSet(BaseControllerAPIView):
def list(self, request, pk=None): def list(self, request, pk=None):
return Response( return Response(
{"results": self.serialize(streamers.values(), many=True)} {"results": self.serialize(self.streamers.values(), many=True)}
) )
def dispatch(self, request, *args, pk=None, **kwargs): def dispatch(self, request, *args, pk=None, **kwargs):
@ -135,7 +85,7 @@ class StreamerViewSet(BaseControllerAPIView):
return super().dispatch(request, *args, **kwargs) return super().dispatch(request, *args, **kwargs)
class SourceViewSet(BaseControllerAPIView): class SourceViewSet(ControllerViewSet):
serializer_class = SourceSerializer serializer_class = SourceSerializer
model = controllers.Source model = controllers.Source
@ -151,8 +101,8 @@ class SourceViewSet(BaseControllerAPIView):
return source return source
def retrieve(self, request, pk=None): def retrieve(self, request, pk=None):
self.object = self.get_source(pk) source = self.get_source(pk)
return Response(self.serialize()) return Response(self.serialize(source))
def list(self, request): def list(self, request):
return Response( return Response(
@ -192,8 +142,8 @@ class QueueSourceViewSet(SourceViewSet):
serializer_class = QueueSourceSerializer serializer_class = QueueSourceSerializer
model = controllers.QueueSource model = controllers.QueueSource
def get_sound_queryset(self): def get_sound_queryset(self, request):
return Sound.objects.station(self.request.station).archive() return Sound.objects.station(request.station).archive()
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
def push(self, request, pk): def push(self, request, pk):
@ -201,7 +151,7 @@ class QueueSourceViewSet(SourceViewSet):
raise ValidationError('missing "sound_id" POST data') raise ValidationError('missing "sound_id" POST data')
sound = get_object_or_404( sound = get_object_or_404(
self.get_sound_queryset(), pk=request.data["sound_id"] self.get_sound_queryset(request), pk=request.data["sound_id"]
) )
return self._run( return self._run(
pk, lambda s: s.push(sound.file.path) if sound.file.path else None pk, lambda s: s.push(sound.file.path) if sound.file.path else None