write tests for serializers; add controllers.streamers + tests
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
import os
|
||||
|
||||
from datetime import datetime, time
|
||||
@ -74,21 +75,53 @@ def station():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def station_ports(station):
|
||||
items = [
|
||||
models.Port(
|
||||
station=station,
|
||||
direction=models.Port.DIRECTION_INPUT,
|
||||
type=models.Port.TYPE_HTTP,
|
||||
def stations(station):
|
||||
objs = [
|
||||
models.Station(
|
||||
name=f"test-{i}",
|
||||
slug=f"test-{i}",
|
||||
path=working_dir,
|
||||
default=(i == 0),
|
||||
active=True,
|
||||
),
|
||||
models.Port(
|
||||
station=station,
|
||||
direction=models.Port.DIRECTION_OUTPUT,
|
||||
type=models.Port.TYPE_FILE,
|
||||
active=True,
|
||||
),
|
||||
)
|
||||
for i in range(0, 3)
|
||||
]
|
||||
models.Station.objects.bulk_create(objs)
|
||||
return [station] + objs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def station_ports(station):
|
||||
return _stations_ports(station)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stations_ports(stations):
|
||||
return _stations_ports(*stations)
|
||||
|
||||
|
||||
def _stations_ports(*stations):
|
||||
items = list(
|
||||
itertools.chain(
|
||||
*[
|
||||
(
|
||||
models.Port(
|
||||
station=station,
|
||||
direction=models.Port.DIRECTION_INPUT,
|
||||
type=models.Port.TYPE_HTTP,
|
||||
active=True,
|
||||
),
|
||||
models.Port(
|
||||
station=station,
|
||||
direction=models.Port.DIRECTION_OUTPUT,
|
||||
type=models.Port.TYPE_FILE,
|
||||
active=True,
|
||||
),
|
||||
)
|
||||
for station in stations
|
||||
]
|
||||
)
|
||||
)
|
||||
models.Port.objects.bulk_create(items)
|
||||
return items
|
||||
|
||||
@ -180,3 +213,50 @@ def metadata_string(metadata_data):
|
||||
"\n".join(f"{key}={value}" for key, value in metadata_data.items())
|
||||
+ "\nEND"
|
||||
)
|
||||
|
||||
|
||||
# -- streamers
|
||||
class FakeStreamer(controllers.Streamer):
|
||||
calls = {}
|
||||
|
||||
def fetch(self):
|
||||
self.calls["fetch"] = True
|
||||
|
||||
|
||||
class FakeSource(controllers.Source):
|
||||
def __init__(self, id, *args, **kwargs):
|
||||
self.id = id
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.calls = {}
|
||||
|
||||
def fetch(self):
|
||||
self.calls["sync"] = True
|
||||
|
||||
def sync(self):
|
||||
self.calls["sync"] = True
|
||||
|
||||
def push(self, path):
|
||||
self.calls["push"] = path
|
||||
return path
|
||||
|
||||
def skip(self):
|
||||
self.calls["skip"] = True
|
||||
|
||||
def restart(self):
|
||||
self.calls["restart"] = True
|
||||
|
||||
def seek(self, c):
|
||||
self.calls["seek"] = c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streamers(stations, stations_ports):
|
||||
streamers = controllers.Streamers(streamer_class=FakeStreamer)
|
||||
# avoid unecessary db calls
|
||||
streamers.streamers = {
|
||||
station.pk: FakeStreamer(station) for station in stations
|
||||
}
|
||||
for streamer in streamers.values():
|
||||
streamer.sources = [FakeSource(i) for i in range(0, 3)]
|
||||
return streamers
|
||||
|
37
aircox_streamer/tests/test_controllers_streamers.py
Normal file
37
aircox_streamer/tests/test_controllers_streamers.py
Normal file
@ -0,0 +1,37 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from django.utils import timezone as tz
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStreamers:
|
||||
@pytest.fixture
|
||||
def test___init__(self, streamers):
|
||||
assert isinstance(streamers.timeout, timedelta)
|
||||
|
||||
@pytest.fixture
|
||||
def test_reset(self, streamers, stations):
|
||||
streamers.reset()
|
||||
assert all(
|
||||
streamers.streamers[station.pk] == station for station in stations
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def test_fetch(self, streamers):
|
||||
streamers.next_date = tz.now() - tz.timedelta(seconds=30)
|
||||
streamers.streamers = None
|
||||
|
||||
now = tz.now()
|
||||
streamers.fetch()
|
||||
|
||||
assert all(streamer.calls.get("fetch") for streamer in streamers)
|
||||
assert streamers.next_date > now
|
||||
|
||||
@pytest.fixture
|
||||
def test_fetch_timeout_not_reached(self, streamers):
|
||||
next_date = tz.now() + tz.timedelta(seconds=30)
|
||||
streamers.next_date = next_date
|
||||
|
||||
streamers.fetch()
|
||||
assert all(not streamer.calls.get("fetch") for streamer in streamers)
|
||||
assert streamers.next_date == next_date
|
185
aircox_streamer/tests/test_viewsets.py
Normal file
185
aircox_streamer/tests/test_viewsets.py
Normal file
@ -0,0 +1,185 @@
|
||||
import pytest
|
||||
|
||||
from django.http import Http404
|
||||
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from aircox_streamer.viewsets import (
|
||||
ControllerViewSet,
|
||||
SourceViewSet,
|
||||
StreamerViewSet,
|
||||
QueueSourceViewSet,
|
||||
)
|
||||
|
||||
|
||||
class FakeSerializer:
|
||||
def __init__(self, instance, *args, **kwargs):
|
||||
self.instance = instance
|
||||
self.data = {"instance": self.instance}
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(**kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def controller_viewset(streamers, station):
|
||||
return ControllerViewSet(
|
||||
streamers=streamers,
|
||||
streamer=streamers[station.pk],
|
||||
serializer_class=FakeSerializer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streamer_viewset(streamers, station):
|
||||
return StreamerViewSet(
|
||||
streamers=streamers,
|
||||
streamer=streamers[station.pk],
|
||||
serializer_class=FakeSerializer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def source_viewset(streamers, station):
|
||||
return SourceViewSet(
|
||||
streamers=streamers,
|
||||
streamer=streamers[station.pk],
|
||||
serializer_class=FakeSerializer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_source_viewset(streamers, station):
|
||||
return QueueSourceViewSet(
|
||||
streamers=streamers,
|
||||
streamer=streamers[station.pk],
|
||||
serializer_class=FakeSerializer,
|
||||
)
|
||||
|
||||
|
||||
class TestControllerViewSet:
|
||||
@pytest.mark.django_db
|
||||
def test_get_streamer(self, controller_viewset, stations):
|
||||
station = stations[0]
|
||||
streamer = controller_viewset.get_streamer(station.pk)
|
||||
assert streamer.station.pk == station.pk
|
||||
assert streamer.calls.get("fetch")
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_get_streamer_station_not_found(self, controller_viewset):
|
||||
controller_viewset.streamers.streamers = {}
|
||||
with pytest.raises(Http404):
|
||||
controller_viewset.get_streamer(1)
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_get_serializer(self, controller_viewset):
|
||||
controller_viewset.object = {"object": "value"}
|
||||
serializer = controller_viewset.get_serializer(test=True)
|
||||
assert serializer.kwargs["test"]
|
||||
assert serializer.instance == controller_viewset.object
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_serialize(self, controller_viewset):
|
||||
instance = {}
|
||||
data = controller_viewset.serialize(instance, test=True)
|
||||
assert data == {"instance": instance}
|
||||
|
||||
|
||||
class TestStreamerViewSet:
|
||||
@pytest.mark.django_db
|
||||
def test_retrieve(self, streamer_viewset):
|
||||
streamer_viewset.streamer = {"streamer": "test"}
|
||||
resp = streamer_viewset.retrieve(None, None)
|
||||
assert resp.data == {"instance": streamer_viewset.streamer}
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_list(self, streamer_viewset):
|
||||
streamers = {"a": 1, "b": 2}
|
||||
streamer_viewset.streamers.streamers = streamers
|
||||
resp = streamer_viewset.list(None)
|
||||
assert set(resp.data["results"]["instance"]) == set(streamers.values())
|
||||
|
||||
|
||||
class TestSourceViewSet:
|
||||
@pytest.mark.django_db
|
||||
def test_get_sources(self, source_viewset, streamers):
|
||||
source_viewset.streamer.sources.append(45)
|
||||
sources = source_viewset.get_sources()
|
||||
assert 45 not in set(sources)
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_get_source(self, source_viewset):
|
||||
source = source_viewset.get_source(1)
|
||||
assert source.id == 1
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_get_source_not_found(self, source_viewset):
|
||||
with pytest.raises(Http404):
|
||||
source_viewset.get_source(1000)
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_retrieve(self, source_viewset, station):
|
||||
resp = source_viewset.retrieve(None, 0)
|
||||
source = source_viewset.streamers[station.pk].sources[0]
|
||||
# this is FakeSerializer being used which provides us the proof
|
||||
assert resp.data["instance"] == source
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_list(self, source_viewset, station):
|
||||
sources = source_viewset.streamers[station.pk].sources
|
||||
resp = source_viewset.list(None)
|
||||
assert list(resp.data["results"]["instance"]) == sources
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test__run(self, source_viewset):
|
||||
calls = {}
|
||||
|
||||
def action(x):
|
||||
return calls.setdefault("action", True)
|
||||
|
||||
source_viewset._run(0, action)
|
||||
assert calls.get("action")
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_all_api_source_actions(self, source_viewset, station):
|
||||
source = source_viewset.streamers[station.pk].sources[0]
|
||||
request = FakeRequest(POST={"seek": 1})
|
||||
source_viewset.get_source = lambda x: source
|
||||
|
||||
for action in ("sync", "skip", "restart", "seek"):
|
||||
func = getattr(source_viewset, action)
|
||||
func(request, 1)
|
||||
assert source.calls.get(action)
|
||||
|
||||
|
||||
class TestQueueSourceViewSet:
|
||||
@pytest.mark.django_db
|
||||
def test_get_sound_queryset(self, queue_source_viewset, station, sounds):
|
||||
ids = {sound.pk for sound in sounds}
|
||||
request = FakeRequest(station=station)
|
||||
query = queue_source_viewset.get_sound_queryset(request)
|
||||
assert set(query.values_list("pk", flat=True)) == ids
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_push(self, queue_source_viewset, station, sounds):
|
||||
calls = {}
|
||||
sound = sounds[0]
|
||||
request = FakeRequest(station=station, data={"sound_id": sound.pk})
|
||||
queue_source_viewset._run = lambda pk, func: calls.setdefault(
|
||||
"_run", (pk, func)
|
||||
)
|
||||
result = queue_source_viewset.push(request, 13)
|
||||
assert "_run" in calls
|
||||
assert result[0] == 13
|
||||
assert callable(result[1])
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_push_missing_sound_in_request_post(
|
||||
self, queue_source_viewset, station
|
||||
):
|
||||
request = FakeRequest(station=station, data={})
|
||||
with pytest.raises(ValidationError):
|
||||
queue_source_viewset.push(request, 0)
|
Reference in New Issue
Block a user