"""Tests for the viewsets exposed by ``aircox_streamer.viewsets``.

The viewsets are instantiated directly with a fake serializer class, so the
tests can inspect exactly which object each viewset hands to its serializer.
"""
import pytest

from django.http import Http404
from rest_framework.exceptions import ValidationError

from aircox_streamer.viewsets import (
    ControllerViewSet,
    SourceViewSet,
    StreamerViewSet,
    QueueSourceViewSet,
)


class FakeSerializer:
    """Serializer stand-in that records its arguments.

    ``data`` is a dict holding the received instance under the
    ``"instance"`` key, which lets tests compare it directly.
    """

    def __init__(self, instance, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.instance = instance
        self.data = {"instance": instance}


class FakeRequest:
    """Request stand-in whose attributes are the given keyword arguments."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)


def _make_viewset(viewset_class, streamers, station):
    """Instantiate ``viewset_class`` bound to the streamer of ``station``."""
    return viewset_class(
        streamers=streamers,
        streamer=streamers[station.pk],
        serializer_class=FakeSerializer,
    )


@pytest.fixture
def controller_viewset(streamers, station):
    """Provide a ``ControllerViewSet`` using ``FakeSerializer``."""
    return _make_viewset(ControllerViewSet, streamers, station)


@pytest.fixture
def streamer_viewset(streamers, station):
    """Provide a ``StreamerViewSet`` using ``FakeSerializer``."""
    return _make_viewset(StreamerViewSet, streamers, station)


@pytest.fixture
def source_viewset(streamers, station):
    """Provide a ``SourceViewSet`` using ``FakeSerializer``."""
    return _make_viewset(SourceViewSet, streamers, station)


@pytest.fixture
def queue_source_viewset(streamers, station):
    """Provide a ``QueueSourceViewSet`` using ``FakeSerializer``."""
    return _make_viewset(QueueSourceViewSet, streamers, station)


@pytest.mark.django_db
class TestControllerViewSet:
    """Tests for ``ControllerViewSet``."""

    def test_get_streamer(self, controller_viewset, stations):
        """The streamer of the requested station is returned after a fetch."""
        expected_station = stations[0]
        found = controller_viewset.get_streamer(expected_station.pk)
        assert found.station.pk == expected_station.pk
        # A "fetch" entry must have been recorded in the streamer's calls.
        assert found.calls.get("fetch")

    def test_get_streamer_station_not_found(self, controller_viewset):
        """Http404 is raised when no streamer is registered."""
        controller_viewset.streamers.streamers = {}
        with pytest.raises(Http404):
            controller_viewset.get_streamer(1)

    def test_get_serializer(self, controller_viewset):
        """The serializer receives the viewset's object and extra kwargs."""
        controller_viewset.object = {"object": "value"}
        built = controller_viewset.get_serializer(test=True)
        assert built.kwargs["test"]
        assert built.instance == controller_viewset.object

    def test_serialize(self, controller_viewset):
        """``serialize`` returns the serializer's data for the instance."""
        target = {}
        serialized = controller_viewset.serialize(target, test=True)
        assert serialized == {"instance": target}


@pytest.mark.django_db
class TestStreamerViewSet:
    """Tests for ``StreamerViewSet``."""

    def test_retrieve(self, streamer_viewset):
        """``retrieve`` responds with the serialized current streamer."""
        streamer_viewset.streamer = {"streamer": "test"}
        response = streamer_viewset.retrieve(None, None)
        assert response.data == {"instance": streamer_viewset.streamer}

    def test_list(self, streamer_viewset):
        """``list`` responds with every registered streamer as results."""
        registered = {"a": 1, "b": 2}
        streamer_viewset.streamers.streamers = registered
        response = streamer_viewset.list(None)
        listed = response.data["results"]["instance"]
        assert set(listed) == set(registered.values())


@pytest.mark.django_db
class TestSourceViewSet:
    """Tests for ``SourceViewSet``."""

    def test_get_sources(self, source_viewset, streamers):
        """An integer appended to the sources is not returned."""
        source_viewset.streamer.sources.append(45)
        found = source_viewset.get_sources()
        assert 45 not in set(found)

    def test_get_source(self, source_viewset):
        """``get_source`` returns the source whose id was requested."""
        found = source_viewset.get_source(1)
        assert found.id == 1

    def test_get_source_not_found(self, source_viewset):
        """Http404 is raised for a source id that does not exist."""
        with pytest.raises(Http404):
            source_viewset.get_source(1000)

    def test_retrieve(self, source_viewset, station):
        """``retrieve`` responds with the serialized requested source."""
        response = source_viewset.retrieve(None, 0)
        expected = source_viewset.streamers[station.pk].sources[0]
        # FakeSerializer exposes the instance it was given untouched, so the
        # response data can be compared with the source object directly.
        assert response.data["instance"] == expected

    def test_list(self, source_viewset, station):
        """``list`` responds with the streamer's sources as results."""
        expected = source_viewset.streamers[station.pk].sources
        response = source_viewset.list(None)
        assert list(response.data["results"]["instance"]) == expected

    def test__run(self, source_viewset):
        """``_run`` invokes the action callable it is given."""
        recorded = {}

        def action(x):
            return recorded.setdefault("action", True)

        source_viewset._run(0, action)
        assert recorded.get("action")

    def test_all_api_source_actions(self, source_viewset, station):
        """Each API action records a call of the same name on the source."""
        target = source_viewset.streamers[station.pk].sources[0]
        request = FakeRequest(POST={"seek": 1})

        def fake_get_source(x):
            return target

        # Bypass the real lookup so every action runs against ``target``.
        source_viewset.get_source = fake_get_source

        for action_name in ("sync", "skip", "restart", "seek"):
            getattr(source_viewset, action_name)(request, 1)
            assert target.calls.get(action_name)


@pytest.mark.django_db
class TestQueueSourceViewSet:
    """Tests for ``QueueSourceViewSet``."""

    def test_get_sound_queryset(self, queue_source_viewset, station, sounds):
        """The sound queryset holds exactly the pks of the ``sounds`` fixture."""
        expected_pks = {sound.pk for sound in sounds}
        request = FakeRequest(station=station)
        queryset = queue_source_viewset.get_sound_queryset(request)
        assert set(queryset.values_list("pk", flat=True)) == expected_pks

    def test_push(self, queue_source_viewset, station, sounds):
        """``push`` delegates to ``_run`` with the source pk and a callable."""
        recorded = {}
        first_sound = sounds[0]
        request = FakeRequest(station=station, data={"sound_id": first_sound.pk})

        def fake_run(pk, func):
            return recorded.setdefault("_run", (pk, func))

        # Replace ``_run`` so its arguments can be inspected; ``push`` is
        # expected to return whatever ``_run`` returns.
        queue_source_viewset._run = fake_run

        outcome = queue_source_viewset.push(request, 13)
        assert "_run" in recorded
        assert outcome[0] == 13
        assert callable(outcome[1])

    def test_push_missing_sound_in_request_post(self, queue_source_viewset, station):
        """ValidationError is raised when the request data has no sound id."""
        request = FakeRequest(station=station, data={})
        with pytest.raises(ValidationError):
            queue_source_viewset.push(request, 0)