commit 6bd35122cc096acff6cd4bd93cc06f2e5f8fd354 Author: Federico Ressi Date: Wed Oct 14 13:27:41 2020 +0200 Add typing annotation to selection tools Change-Id: I107fef91fc8d087d5470c2d8e141c7fe76930fb1 diff --git a/tobiko/common/_select.py b/tobiko/common/_select.py index 5697807..dce2088 100644 --- a/tobiko/common/_select.py +++ b/tobiko/common/_select.py @@ -13,10 +13,15 @@ # under the License. from __future__ import absolute_import +import typing # noqa + from tobiko import _exception -class Selection(list): +T = typing.TypeVar('T') + + +class Selection(list, typing.Generic[T]): def with_attributes(self, **attributes): return self.create( @@ -33,18 +38,18 @@ class Selection(list): return self.create(filter_by_items(self, exclude=True, **items)) @classmethod - def create(cls, objects): + def create(cls, objects: typing.Iterable[T]): return cls(objects) @property - def first(self): + def first(self) -> T: if self: return self[0] else: raise ObjectNotFound() @property - def unique(self): + def unique(self) -> T: if len(self) > 1: raise MultipleObjectsFound(list(self)) else: @@ -54,7 +59,7 @@ class Selection(list): return '{!s}({!r})'.format(type(self).__name__, list(self)) -def select(objects): +def select(objects: typing.Iterable[T]) -> Selection[T]: return Selection.create(objects) diff --git a/tobiko/openstack/tests/_nova.py b/tobiko/openstack/tests/_nova.py index 1485bae..78b91b7 100644 --- a/tobiko/openstack/tests/_nova.py +++ b/tobiko/openstack/tests/_nova.py @@ -58,14 +58,14 @@ def test_server_creation_and_shutoff(stack=TestServerCreationStack): def test_servers_creation(stack=TestServerCreationStack, number_of_servers=2) -> \ - typing.List[_nova.ServerStackFixture]: + tobiko.Selection[_nova.ServerStackFixture]: initial_servers_ids = {server.id for server in nova.list_servers()} pid = os.getpid() fixture_obj = tobiko.get_fixture_class(stack) # Get list of server stack instances - fixtures = tobiko.Selection( + fixtures: tobiko.Selection[_nova.ServerStackFixture] = tobiko.select( tobiko.get_fixture(fixture_obj, fixture_id=f'{pid}-{i}') for i in range(number_of_servers or 1)) diff --git a/tobiko/openstack/topology/_address.py b/tobiko/openstack/topology/_address.py index 3adeca0..7a81188 100644 --- a/tobiko/openstack/topology/_address.py +++ b/tobiko/openstack/topology/_address.py @@ -32,7 +32,7 @@ def list_addresses(obj, ip_version: typing.Optional[int] = None, port: typing.Union[int, str, None] = None, ssh_config: bool = False) -> \ - typing.List[netaddr.IPAddress]: + tobiko.Selection[netaddr.IPAddress]: if isinstance(obj, tobiko.Selection): addresses = obj elif isinstance(obj, netaddr.IPAddress): @@ -58,7 +58,7 @@ def list_host_addresses(host: str, ip_version: typing.Optional[int] = None, port: typing.Union[int, str, None] = None, ssh_config: bool = False) -> \ - typing.List[netaddr.IPAddress]: + tobiko.Selection[netaddr.IPAddress]: if not port: if ssh_config: @@ -66,7 +66,7 @@ def list_host_addresses(host: str, else: port = 0 - addresses = [] + addresses: tobiko.Selection[netaddr.IPAddress] = tobiko.Selection() hosts = [host] resolved = set() while hosts: diff --git a/tobiko/openstack/topology/_topology.py b/tobiko/openstack/topology/_topology.py index 639f0be..2a4c955 100644 --- a/tobiko/openstack/topology/_topology.py +++ b/tobiko/openstack/topology/_topology.py @@ -323,18 +323,19 @@ class OpenStackTopology(tobiko.SharedFixture): self._groups[group] = nodes = self.create_group() return nodes - def create_group(self) -> tobiko.Selection: + @staticmethod + def create_group() -> tobiko.Selection[OpenStackTopologyNode]: return tobiko.Selection() - def get_group(self, group) -> tobiko.Selection: + def get_group(self, group) -> tobiko.Selection[OpenStackTopologyNode]: try: return self._groups[group] except KeyError as ex: raise _exception.NoSuchOpenStackTopologyNodeGroup( group=group) from ex - def get_groups(self, groups) -> tobiko.Selection: - nodes = tobiko.Selection() + def get_groups(self, groups) -> tobiko.Selection[OpenStackTopologyNode]: + nodes: tobiko.Selection[OpenStackTopologyNode] = tobiko.Selection() for group in groups: nodes.extend(self.get_group(group)) return nodes diff --git a/tobiko/shell/ip.py b/tobiko/shell/ip.py index e54bb3e..2dd7df5 100644 --- a/tobiko/shell/ip.py +++ b/tobiko/shell/ip.py @@ -41,7 +41,7 @@ INETS = { def list_ip_addresses(ip_version: typing.Optional[int] = None, scope: str = None, **execute_params) -> \ - typing.List[netaddr.IPAddress]: + tobiko.Selection[netaddr.IPAddress]: inets = INETS.get(ip_version) if inets is None: error = "invalid IP version: {!r}".format(ip_version) @@ -49,7 +49,7 @@ def list_ip_addresses(ip_version: typing.Optional[int] = None, output = execute_ip(['-o', 'address', 'list'], **execute_params) - ips = tobiko.Selection() + ips: tobiko.Selection[netaddr.IPAddress] = tobiko.Selection() if output: for line in output.splitlines(): fields = line.strip().split()