write tests for serializers; add controllers.streamers + tests
This commit is contained in:
		@ -1,3 +1,4 @@
 | 
			
		||||
import itertools
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
from datetime import datetime, time
 | 
			
		||||
@ -74,21 +75,53 @@ def station():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def station_ports(station):
 | 
			
		||||
    items = [
 | 
			
		||||
        models.Port(
 | 
			
		||||
            station=station,
 | 
			
		||||
            direction=models.Port.DIRECTION_INPUT,
 | 
			
		||||
            type=models.Port.TYPE_HTTP,
 | 
			
		||||
def stations(station):
 | 
			
		||||
    objs = [
 | 
			
		||||
        models.Station(
 | 
			
		||||
            name=f"test-{i}",
 | 
			
		||||
            slug=f"test-{i}",
 | 
			
		||||
            path=working_dir,
 | 
			
		||||
            default=(i == 0),
 | 
			
		||||
            active=True,
 | 
			
		||||
        ),
 | 
			
		||||
        models.Port(
 | 
			
		||||
            station=station,
 | 
			
		||||
            direction=models.Port.DIRECTION_OUTPUT,
 | 
			
		||||
            type=models.Port.TYPE_FILE,
 | 
			
		||||
            active=True,
 | 
			
		||||
        ),
 | 
			
		||||
        )
 | 
			
		||||
        for i in range(0, 3)
 | 
			
		||||
    ]
 | 
			
		||||
    models.Station.objects.bulk_create(objs)
 | 
			
		||||
    return [station] + objs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def station_ports(station):
 | 
			
		||||
    return _stations_ports(station)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def stations_ports(stations):
 | 
			
		||||
    return _stations_ports(*stations)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _stations_ports(*stations):
 | 
			
		||||
    items = list(
 | 
			
		||||
        itertools.chain(
 | 
			
		||||
            *[
 | 
			
		||||
                (
 | 
			
		||||
                    models.Port(
 | 
			
		||||
                        station=station,
 | 
			
		||||
                        direction=models.Port.DIRECTION_INPUT,
 | 
			
		||||
                        type=models.Port.TYPE_HTTP,
 | 
			
		||||
                        active=True,
 | 
			
		||||
                    ),
 | 
			
		||||
                    models.Port(
 | 
			
		||||
                        station=station,
 | 
			
		||||
                        direction=models.Port.DIRECTION_OUTPUT,
 | 
			
		||||
                        type=models.Port.TYPE_FILE,
 | 
			
		||||
                        active=True,
 | 
			
		||||
                    ),
 | 
			
		||||
                )
 | 
			
		||||
                for station in stations
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    models.Port.objects.bulk_create(items)
 | 
			
		||||
    return items
 | 
			
		||||
 | 
			
		||||
@ -180,3 +213,50 @@ def metadata_string(metadata_data):
 | 
			
		||||
        "\n".join(f"{key}={value}" for key, value in metadata_data.items())
 | 
			
		||||
        + "\nEND"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# -- streamers
 | 
			
		||||
class FakeStreamer(controllers.Streamer):
 | 
			
		||||
    calls = {}
 | 
			
		||||
 | 
			
		||||
    def fetch(self):
 | 
			
		||||
        self.calls["fetch"] = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeSource(controllers.Source):
 | 
			
		||||
    def __init__(self, id, *args, **kwargs):
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.args = args
 | 
			
		||||
        self.kwargs = kwargs
 | 
			
		||||
        self.calls = {}
 | 
			
		||||
 | 
			
		||||
    def fetch(self):
 | 
			
		||||
        self.calls["sync"] = True
 | 
			
		||||
 | 
			
		||||
    def sync(self):
 | 
			
		||||
        self.calls["sync"] = True
 | 
			
		||||
 | 
			
		||||
    def push(self, path):
 | 
			
		||||
        self.calls["push"] = path
 | 
			
		||||
        return path
 | 
			
		||||
 | 
			
		||||
    def skip(self):
 | 
			
		||||
        self.calls["skip"] = True
 | 
			
		||||
 | 
			
		||||
    def restart(self):
 | 
			
		||||
        self.calls["restart"] = True
 | 
			
		||||
 | 
			
		||||
    def seek(self, c):
 | 
			
		||||
        self.calls["seek"] = c
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def streamers(stations, stations_ports):
 | 
			
		||||
    streamers = controllers.Streamers(streamer_class=FakeStreamer)
 | 
			
		||||
    # avoid unecessary db calls
 | 
			
		||||
    streamers.streamers = {
 | 
			
		||||
        station.pk: FakeStreamer(station) for station in stations
 | 
			
		||||
    }
 | 
			
		||||
    for streamer in streamers.values():
 | 
			
		||||
        streamer.sources = [FakeSource(i) for i in range(0, 3)]
 | 
			
		||||
    return streamers
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										37
									
								
								aircox_streamer/tests/test_controllers_streamers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								aircox_streamer/tests/test_controllers_streamers.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,37 @@
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
 | 
			
		||||
from django.utils import timezone as tz
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestStreamers:
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def test___init__(self, streamers):
 | 
			
		||||
        assert isinstance(streamers.timeout, timedelta)
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def test_reset(self, streamers, stations):
 | 
			
		||||
        streamers.reset()
 | 
			
		||||
        assert all(
 | 
			
		||||
            streamers.streamers[station.pk] == station for station in stations
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def test_fetch(self, streamers):
 | 
			
		||||
        streamers.next_date = tz.now() - tz.timedelta(seconds=30)
 | 
			
		||||
        streamers.streamers = None
 | 
			
		||||
 | 
			
		||||
        now = tz.now()
 | 
			
		||||
        streamers.fetch()
 | 
			
		||||
 | 
			
		||||
        assert all(streamer.calls.get("fetch") for streamer in streamers)
 | 
			
		||||
        assert streamers.next_date > now
 | 
			
		||||
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def test_fetch_timeout_not_reached(self, streamers):
 | 
			
		||||
        next_date = tz.now() + tz.timedelta(seconds=30)
 | 
			
		||||
        streamers.next_date = next_date
 | 
			
		||||
 | 
			
		||||
        streamers.fetch()
 | 
			
		||||
        assert all(not streamer.calls.get("fetch") for streamer in streamers)
 | 
			
		||||
        assert streamers.next_date == next_date
 | 
			
		||||
							
								
								
									
										185
									
								
								aircox_streamer/tests/test_viewsets.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								aircox_streamer/tests/test_viewsets.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,185 @@
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from django.http import Http404
 | 
			
		||||
 | 
			
		||||
from rest_framework.exceptions import ValidationError
 | 
			
		||||
from aircox_streamer.viewsets import (
 | 
			
		||||
    ControllerViewSet,
 | 
			
		||||
    SourceViewSet,
 | 
			
		||||
    StreamerViewSet,
 | 
			
		||||
    QueueSourceViewSet,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeSerializer:
 | 
			
		||||
    def __init__(self, instance, *args, **kwargs):
 | 
			
		||||
        self.instance = instance
 | 
			
		||||
        self.data = {"instance": self.instance}
 | 
			
		||||
        self.args = args
 | 
			
		||||
        self.kwargs = kwargs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeRequest:
 | 
			
		||||
    def __init__(self, **kwargs):
 | 
			
		||||
        self.__dict__.update(**kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def controller_viewset(streamers, station):
 | 
			
		||||
    return ControllerViewSet(
 | 
			
		||||
        streamers=streamers,
 | 
			
		||||
        streamer=streamers[station.pk],
 | 
			
		||||
        serializer_class=FakeSerializer,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def streamer_viewset(streamers, station):
 | 
			
		||||
    return StreamerViewSet(
 | 
			
		||||
        streamers=streamers,
 | 
			
		||||
        streamer=streamers[station.pk],
 | 
			
		||||
        serializer_class=FakeSerializer,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def source_viewset(streamers, station):
 | 
			
		||||
    return SourceViewSet(
 | 
			
		||||
        streamers=streamers,
 | 
			
		||||
        streamer=streamers[station.pk],
 | 
			
		||||
        serializer_class=FakeSerializer,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def queue_source_viewset(streamers, station):
 | 
			
		||||
    return QueueSourceViewSet(
 | 
			
		||||
        streamers=streamers,
 | 
			
		||||
        streamer=streamers[station.pk],
 | 
			
		||||
        serializer_class=FakeSerializer,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestControllerViewSet:
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_get_streamer(self, controller_viewset, stations):
 | 
			
		||||
        station = stations[0]
 | 
			
		||||
        streamer = controller_viewset.get_streamer(station.pk)
 | 
			
		||||
        assert streamer.station.pk == station.pk
 | 
			
		||||
        assert streamer.calls.get("fetch")
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_get_streamer_station_not_found(self, controller_viewset):
 | 
			
		||||
        controller_viewset.streamers.streamers = {}
 | 
			
		||||
        with pytest.raises(Http404):
 | 
			
		||||
            controller_viewset.get_streamer(1)
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_get_serializer(self, controller_viewset):
 | 
			
		||||
        controller_viewset.object = {"object": "value"}
 | 
			
		||||
        serializer = controller_viewset.get_serializer(test=True)
 | 
			
		||||
        assert serializer.kwargs["test"]
 | 
			
		||||
        assert serializer.instance == controller_viewset.object
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_serialize(self, controller_viewset):
 | 
			
		||||
        instance = {}
 | 
			
		||||
        data = controller_viewset.serialize(instance, test=True)
 | 
			
		||||
        assert data == {"instance": instance}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestStreamerViewSet:
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_retrieve(self, streamer_viewset):
 | 
			
		||||
        streamer_viewset.streamer = {"streamer": "test"}
 | 
			
		||||
        resp = streamer_viewset.retrieve(None, None)
 | 
			
		||||
        assert resp.data == {"instance": streamer_viewset.streamer}
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_list(self, streamer_viewset):
 | 
			
		||||
        streamers = {"a": 1, "b": 2}
 | 
			
		||||
        streamer_viewset.streamers.streamers = streamers
 | 
			
		||||
        resp = streamer_viewset.list(None)
 | 
			
		||||
        assert set(resp.data["results"]["instance"]) == set(streamers.values())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSourceViewSet:
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_get_sources(self, source_viewset, streamers):
 | 
			
		||||
        source_viewset.streamer.sources.append(45)
 | 
			
		||||
        sources = source_viewset.get_sources()
 | 
			
		||||
        assert 45 not in set(sources)
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_get_source(self, source_viewset):
 | 
			
		||||
        source = source_viewset.get_source(1)
 | 
			
		||||
        assert source.id == 1
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_get_source_not_found(self, source_viewset):
 | 
			
		||||
        with pytest.raises(Http404):
 | 
			
		||||
            source_viewset.get_source(1000)
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_retrieve(self, source_viewset, station):
 | 
			
		||||
        resp = source_viewset.retrieve(None, 0)
 | 
			
		||||
        source = source_viewset.streamers[station.pk].sources[0]
 | 
			
		||||
        # this is FakeSerializer being used which provides us the proof
 | 
			
		||||
        assert resp.data["instance"] == source
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_list(self, source_viewset, station):
 | 
			
		||||
        sources = source_viewset.streamers[station.pk].sources
 | 
			
		||||
        resp = source_viewset.list(None)
 | 
			
		||||
        assert list(resp.data["results"]["instance"]) == sources
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test__run(self, source_viewset):
 | 
			
		||||
        calls = {}
 | 
			
		||||
 | 
			
		||||
        def action(x):
 | 
			
		||||
            return calls.setdefault("action", True)
 | 
			
		||||
 | 
			
		||||
        source_viewset._run(0, action)
 | 
			
		||||
        assert calls.get("action")
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_all_api_source_actions(self, source_viewset, station):
 | 
			
		||||
        source = source_viewset.streamers[station.pk].sources[0]
 | 
			
		||||
        request = FakeRequest(POST={"seek": 1})
 | 
			
		||||
        source_viewset.get_source = lambda x: source
 | 
			
		||||
 | 
			
		||||
        for action in ("sync", "skip", "restart", "seek"):
 | 
			
		||||
            func = getattr(source_viewset, action)
 | 
			
		||||
            func(request, 1)
 | 
			
		||||
            assert source.calls.get(action)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestQueueSourceViewSet:
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_get_sound_queryset(self, queue_source_viewset, station, sounds):
 | 
			
		||||
        ids = {sound.pk for sound in sounds}
 | 
			
		||||
        request = FakeRequest(station=station)
 | 
			
		||||
        query = queue_source_viewset.get_sound_queryset(request)
 | 
			
		||||
        assert set(query.values_list("pk", flat=True)) == ids
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_push(self, queue_source_viewset, station, sounds):
 | 
			
		||||
        calls = {}
 | 
			
		||||
        sound = sounds[0]
 | 
			
		||||
        request = FakeRequest(station=station, data={"sound_id": sound.pk})
 | 
			
		||||
        queue_source_viewset._run = lambda pk, func: calls.setdefault(
 | 
			
		||||
            "_run", (pk, func)
 | 
			
		||||
        )
 | 
			
		||||
        result = queue_source_viewset.push(request, 13)
 | 
			
		||||
        assert "_run" in calls
 | 
			
		||||
        assert result[0] == 13
 | 
			
		||||
        assert callable(result[1])
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.django_db
 | 
			
		||||
    def test_push_missing_sound_in_request_post(
 | 
			
		||||
        self, queue_source_viewset, station
 | 
			
		||||
    ):
 | 
			
		||||
        request = FakeRequest(station=station, data={})
 | 
			
		||||
        with pytest.raises(ValidationError):
 | 
			
		||||
            queue_source_viewset.push(request, 0)
 | 
			
		||||
		Reference in New Issue
	
	Block a user