aircox/aircox_streamer/tests/test_viewsets.py
2023-10-11 10:53:42 +02:00

182 lines
5.7 KiB
Python

import pytest
from django.http import Http404
from rest_framework.exceptions import ValidationError
from aircox_streamer.viewsets import (
ControllerViewSet,
SourceViewSet,
StreamerViewSet,
QueueSourceViewSet,
)
class FakeSerializer:
    """Serializer test double that simply remembers its constructor input.

    ``data`` is ``{"instance": instance}``, which lets a test see exactly
    which object a viewset passed to its serializer.
    """

    def __init__(self, instance, *args, **kwargs):
        self.instance = instance
        self.args = args
        self.kwargs = kwargs
        self.data = {"instance": instance}
class FakeRequest:
    """Request stand-in: each keyword argument is exposed as an attribute."""

    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
@pytest.fixture
def controller_viewset(streamers, station):
    """ControllerViewSet bound to the test station's streamer.

    ``streamers`` (presumably a conftest fixture) is indexed by station pk;
    serialization goes through FakeSerializer.
    """
    station_streamer = streamers[station.pk]
    return ControllerViewSet(
        streamers=streamers,
        streamer=station_streamer,
        serializer_class=FakeSerializer,
    )
@pytest.fixture
def streamer_viewset(streamers, station):
    """StreamerViewSet bound to the test station's streamer, using FakeSerializer."""
    current = streamers[station.pk]
    return StreamerViewSet(
        streamers=streamers,
        streamer=current,
        serializer_class=FakeSerializer,
    )
@pytest.fixture
def source_viewset(streamers, station):
    """SourceViewSet bound to the test station's streamer, using FakeSerializer."""
    active_streamer = streamers[station.pk]
    return SourceViewSet(
        streamers=streamers,
        streamer=active_streamer,
        serializer_class=FakeSerializer,
    )
@pytest.fixture
def queue_source_viewset(streamers, station):
    """QueueSourceViewSet bound to the test station's streamer, using FakeSerializer."""
    selected = streamers[station.pk]
    return QueueSourceViewSet(
        streamers=streamers,
        streamer=selected,
        serializer_class=FakeSerializer,
    )
class TestControllerViewSet:
    """Tests for the shared ControllerViewSet helpers."""

    @pytest.mark.django_db
    def test_get_streamer(self, controller_viewset, stations):
        target = stations[0]
        found = controller_viewset.get_streamer(target.pk)
        assert found.station.pk == target.pk
        # The fake streamer presumably records invoked methods in ``calls``
        # (see conftest); a "fetch" entry is expected after lookup.
        assert found.calls.get("fetch")

    @pytest.mark.django_db
    def test_get_streamer_station_not_found(self, controller_viewset):
        # Empty the streamer registry so station pk 1 cannot be resolved.
        controller_viewset.streamers.streamers = {}
        with pytest.raises(Http404):
            controller_viewset.get_streamer(1)

    @pytest.mark.django_db
    def test_get_serializer(self, controller_viewset):
        controller_viewset.object = {"object": "value"}
        result = controller_viewset.get_serializer(test=True)
        # Extra keyword arguments are forwarded to the serializer class, and
        # ``object`` is what gets handed over as the instance.
        assert result.kwargs["test"]
        assert result.instance == controller_viewset.object

    @pytest.mark.django_db
    def test_serialize(self, controller_viewset):
        payload = {}
        serialized = controller_viewset.serialize(payload, test=True)
        assert serialized == {"instance": payload}
class TestStreamerViewSet:
    """Tests for StreamerViewSet retrieve/list."""

    @pytest.mark.django_db
    def test_retrieve(self, streamer_viewset):
        streamer_viewset.streamer = {"streamer": "test"}
        response = streamer_viewset.retrieve(None, None)
        # Read the attribute after the call, exactly as the original did.
        expected = {"instance": streamer_viewset.streamer}
        assert response.data == expected

    @pytest.mark.django_db
    def test_list(self, streamer_viewset):
        registry = {"a": 1, "b": 2}
        streamer_viewset.streamers.streamers = registry
        response = streamer_viewset.list(None)
        # The listed streamers come back under results -> instance, which is
        # where FakeSerializer places whatever it was given.
        listed = set(response.data["results"]["instance"])
        assert listed == set(registry.values())
class TestSourceViewSet:
    """Tests for SourceViewSet lookup, listing, and source actions."""

    @pytest.mark.django_db
    def test_get_sources(self, source_viewset):
        # An arbitrary non-source value appended to the streamer's source
        # list must not be returned by get_sources(). The ``streamers``
        # fixture previously requested here was unused; ``source_viewset``
        # already depends on it.
        source_viewset.streamer.sources.append(45)
        sources = source_viewset.get_sources()
        assert 45 not in set(sources)

    @pytest.mark.django_db
    def test_get_source(self, source_viewset):
        source = source_viewset.get_source(1)
        assert source.id == 1

    @pytest.mark.django_db
    def test_get_source_not_found(self, source_viewset):
        with pytest.raises(Http404):
            source_viewset.get_source(1000)

    @pytest.mark.django_db
    def test_retrieve(self, source_viewset, station):
        resp = source_viewset.retrieve(None, 0)
        source = source_viewset.streamers[station.pk].sources[0]
        # FakeSerializer stores its instance under "instance", so this shows
        # which source retrieve() handed to the serializer.
        assert resp.data["instance"] == source

    @pytest.mark.django_db
    def test_list(self, source_viewset, station):
        sources = source_viewset.streamers[station.pk].sources
        resp = source_viewset.list(None)
        assert list(resp.data["results"]["instance"]) == sources

    @pytest.mark.django_db
    def test__run(self, source_viewset):
        calls = {}

        def action(x):
            # Record that _run invoked the callback.
            return calls.setdefault("action", True)

        source_viewset._run(0, action)
        assert calls.get("action")

    @pytest.mark.django_db
    def test_all_api_source_actions(self, source_viewset, station):
        source = source_viewset.streamers[station.pk].sources[0]
        # "seek" reads its target from the request's POST data.
        request = FakeRequest(POST={"seek": 1})
        # Force every action to act on the same known source object.
        source_viewset.get_source = lambda x: source
        for action in ("sync", "skip", "restart", "seek"):
            func = getattr(source_viewset, action)
            func(request, 1)
            # Name the action so a failure says which endpoint broke.
            assert source.calls.get(action), (
                f"action {action!r} was not forwarded to the source"
            )
class TestQueueSourceViewSet:
    """Tests for QueueSourceViewSet sound lookup and push."""

    @pytest.mark.django_db
    def test_get_sound_queryset(self, queue_source_viewset, station, sounds):
        expected_pks = {snd.pk for snd in sounds}
        req = FakeRequest(station=station)
        queryset = queue_source_viewset.get_sound_queryset(req)
        assert set(queryset.values_list("pk", flat=True)) == expected_pks

    @pytest.mark.django_db
    def test_push(self, queue_source_viewset, station, sounds):
        recorded = {}
        first_sound = sounds[0]
        req = FakeRequest(station=station, data={"sound_id": first_sound.pk})

        def fake_run(pk, func):
            # Stub out _run: store and return its (pk, func) arguments.
            return recorded.setdefault("_run", (pk, func))

        queue_source_viewset._run = fake_run
        outcome = queue_source_viewset.push(req, 13)
        assert "_run" in recorded
        assert outcome[0] == 13
        assert callable(outcome[1])

    @pytest.mark.django_db
    def test_push_missing_sound_in_request_post(self, queue_source_viewset, station):
        # Without a "sound_id" in the request data, push must reject it.
        req = FakeRequest(station=station, data={})
        with pytest.raises(ValidationError):
            queue_source_viewset.push(req, 0)