from __future__ import annotations

import logging
import typing

from zigpy.const import (  # noqa: F401
    SIG_ENDPOINTS,
    SIG_EP_INPUT,
    SIG_EP_OUTPUT,
    SIG_EP_PROFILE,
    SIG_EP_TYPE,
    SIG_MANUFACTURER,
    SIG_MODEL,
    SIG_MODELS_INFO,
    SIG_NODE_DESC,
    SIG_SKIP_CONFIG,
)
import zigpy.device
import zigpy.endpoint
from zigpy.quirks.registry import DeviceRegistry  # noqa: F401
import zigpy.types as t
from zigpy.types.basic import uint16_t
import zigpy.zcl
import zigpy.zcl.foundation as foundation

if typing.TYPE_CHECKING:
    from zigpy.application import ControllerApplication

_LOGGER = logging.getLogger(__name__)

# Module-wide quirk registry: CustomDevice subclasses that define a
# ``signature`` add themselves to it, and get_device()/get_quirk_list() fall
# back to it when no explicit registry is passed.
_DEVICE_REGISTRY = DeviceRegistry()
# Callbacks offered messages from uninitialized devices, in registration order.
_uninitialized_device_message_handlers: list[typing.Callable] = list()


def get_device(
    device: zigpy.device.Device, registry: DeviceRegistry | None = None
) -> zigpy.device.Device:
    """Return a CustomDevice for ``device`` if the registry has one available.

    Args:
        device: the device to look up.
        registry: registry to consult; the module-wide registry is used when
            this is None.
    """
    active_registry = _DEVICE_REGISTRY if registry is None else registry
    return active_registry.get_device(device)


def get_quirk_list(
    manufacturer: str, model: str, registry: DeviceRegistry | None = None
):
    """Return the quirks registered for ``manufacturer`` / ``model``.

    Args:
        manufacturer: manufacturer name used as the first-level registry key.
        model: model name used as the second-level registry key.
        registry: registry to consult; the module-wide registry is used when
            this is None.
    """
    source = _DEVICE_REGISTRY if registry is None else registry
    models_for_manufacturer = source.registry[manufacturer]
    return models_for_manufacturer[model]


def register_uninitialized_device_message_handler(handler: typing.Callable) -> None:
    """Register a handler for messages received by uninitialized devices.

    Each handler is passed the same parameters as
    zigpy.application.ControllerApplication.handle_message. Registering a
    handler that is already present is a no-op.
    """
    if handler in _uninitialized_device_message_handlers:
        return
    _uninitialized_device_message_handlers.append(handler)


class CustomDevice(zigpy.device.Device):
    """A device whose attributes and endpoints are overridden by a quirk.

    Subclasses set ``replacement`` to the values that should be installed on
    the custom device, keyed by the ``SIG_*`` constants. Any subclass that
    defines a non-None ``signature`` is added to the module-wide registry as
    soon as the class is created.
    """

    # Overrides applied in __init__/add_endpoint, keyed by SIG_* constants.
    replacement: dict[str, typing.Any] = {}
    # Subclasses with a non-None signature are registered in __init_subclass__.
    signature = None

    def __init_subclass__(cls, **kwargs: typing.Any) -> None:
        """Register subclasses that declare a ``signature``."""
        # Keep the hook cooperative so parent/mixin __init_subclass__ hooks
        # (and any class keyword arguments) are still honoured.
        super().__init_subclass__(**kwargs)
        if getattr(cls, "signature", None) is not None:
            _DEVICE_REGISTRY.add_to_registry(cls)

    def __init__(
        self,
        application: ControllerApplication,
        ieee: t.EUI64,
        nwk: t.NWK,
        replaces: zigpy.device.Device,
    ) -> None:
        """Build the custom device from the device it replaces.

        Args:
            application: controller application the device belongs to.
            ieee: IEEE address of the device.
            nwk: network address of the device.
            replaces: the original device; any value not overridden in
                ``replacement`` is copied from it.
        """
        super().__init__(application, ieee, nwk)

        def set_device_attr(attr: str) -> None:
            # Prefer the quirk's override, otherwise keep the original value.
            if attr in self.replacement:
                setattr(self, attr, self.replacement[attr])
            else:
                setattr(self, attr, getattr(replaces, attr))

        # These runtime values are always carried over, never overridden.
        for attr in ("lqi", "rssi", "last_seen", "relays"):
            setattr(self, attr, getattr(replaces, attr))

        for attr in (
            "status",
            SIG_NODE_DESC,
            SIG_MANUFACTURER,
            SIG_MODEL,
            SIG_SKIP_CONFIG,
        ):
            set_device_attr(attr)

        # Only the endpoint ids are needed here; add_endpoint reads the data.
        for endpoint_id in self.replacement.get(SIG_ENDPOINTS, {}):
            self.add_endpoint(endpoint_id, replace_device=replaces)

    def add_endpoint(
        self, endpoint_id: int, replace_device: zigpy.device.Device | None = None
    ) -> zigpy.endpoint.Endpoint:
        """Create endpoint ``endpoint_id``, using the quirk's replacement if any.

        Endpoints not listed in ``replacement[SIG_ENDPOINTS]`` are created by the
        base implementation. A listed entry is either replacement data (built
        as a :class:`CustomEndpoint`) or a tuple whose first item is the
        endpoint class and whose second item is the replacement data.

        Args:
            endpoint_id: id of the endpoint to create.
            replace_device: original device, used by the endpoint for values
                missing from the replacement data.

        Returns:
            The created endpoint, which is also stored in ``self.endpoints``.
        """
        replacement_endpoints = self.replacement.get(SIG_ENDPOINTS, {})
        if endpoint_id not in replacement_endpoints:
            return super().add_endpoint(endpoint_id)

        entry = replacement_endpoints[endpoint_id]
        if isinstance(entry, tuple):
            custom_ep_type = entry[0]
            replacement_data = entry[1]
        else:
            custom_ep_type = CustomEndpoint
            replacement_data = entry

        ep = custom_ep_type(self, endpoint_id, replacement_data, replace_device)
        self.endpoints[endpoint_id] = ep
        return ep


class CustomEndpoint(zigpy.endpoint.Endpoint):
    """Endpoint whose profile, device type and clusters come from a quirk.

    Profile and device type are taken from ``replacement_data`` when present,
    otherwise from endpoint ``endpoint_id`` of ``replace_device``. Input and
    output clusters are created only from ``replacement_data``; each entry is
    either a bare cluster id or a cluster class to instantiate.
    """

    def __init__(
        self,
        device: CustomDevice,
        endpoint_id: int,
        replacement_data: dict[str, typing.Any],
        replace_device: zigpy.device.Device,
    ) -> None:
        super().__init__(device, endpoint_id)

        for sig_key in (SIG_EP_PROFILE, SIG_EP_TYPE):
            if sig_key in replacement_data:
                value = replacement_data[sig_key]
            else:
                value = getattr(replace_device[endpoint_id], sig_key)
            setattr(self, sig_key, value)

        self.status = zigpy.endpoint.Status.ZDO_INIT

        # (signature key, is_server flag, method that attaches the cluster);
        # input clusters are processed before output clusters.
        cluster_specs = (
            (SIG_EP_INPUT, True, self.add_input_cluster),
            (SIG_EP_OUTPUT, False, self.add_output_cluster),
        )
        for sig_key, is_server, attach in cluster_specs:
            for spec in replacement_data.get(sig_key, []):
                if isinstance(spec, int):
                    # A bare id: let the endpoint create the default cluster.
                    attach(spec, None)
                else:
                    instance = spec(self, is_server=is_server)
                    attach(instance.cluster_id, instance)


class CustomCluster(zigpy.zcl.Cluster):
    """Base class for quirk clusters.

    Adds two behaviours on top of :class:`zigpy.zcl.Cluster`:

    * ``_CONSTANT_ATTRIBUTES`` maps attribute ids to fixed values that are
      answered locally (by ``read_attributes_raw`` and ``get``) instead of
      being read from the device.
    * Commands and attribute foundation requests use the endpoint's
      manufacturer id when the caller did not pass ``manufacturer`` and the
      cluster id is in the manufacturer-specific range or the command or
      attribute is flagged as manufacturer specific.
    """

    # NOTE(review): presumably keeps this base class out of zigpy's cluster
    # registry — confirm against zigpy.zcl.Cluster.
    _skip_registry = True
    # Attribute id -> value answered locally instead of read from the device.
    _CONSTANT_ATTRIBUTES: dict[int, typing.Any] | None = None

    manufacturer_id_override: t.uint16_t | None = None

    @property
    def _is_manuf_specific(self) -> bool:
        """Return True if cluster_id is within manufacturer specific range."""
        return 0xFC00 <= self.cluster_id <= 0xFFFF

    def _has_manuf_attr(self, attrs_to_process: typing.Iterable) -> bool:
        """Return True if a manufacturer id is needed for these attribute ids.

        That is the case for every attribute of a manufacturer-specific
        cluster, or when any known attribute id is flagged as manufacturer
        specific. Unknown attribute ids are ignored.
        """
        if self._is_manuf_specific:
            return True

        return any(
            attr_id in self.attributes
            and self.attributes[attr_id].is_manufacturer_specific
            for attr_id in attrs_to_process
        )

    async def command(
        self,
        command_id: foundation.GeneralCommand | int | t.uint8_t,
        *args,
        manufacturer: int | t.uint16_t | None = None,
        expect_reply: bool = True,
        tsn: int | t.uint8_t | None = None,
        **kwargs: typing.Any,
    ) -> typing.Any:
        """Send server command ``command_id`` and return the request's result.

        If ``manufacturer`` is None and this cluster or the command is
        manufacturer specific, the endpoint's manufacturer id is used.
        """
        command = self.server_commands[command_id]

        if manufacturer is None and (
            self._is_manuf_specific or command.is_manufacturer_specific
        ):
            manufacturer = self.endpoint.manufacturer_id

        return await self.request(
            False,
            command.id,
            command.schema,
            *args,
            manufacturer=manufacturer,
            expect_reply=expect_reply,
            tsn=tsn,
            **kwargs,
        )

    async def client_command(
        self,
        command_id: foundation.GeneralCommand | int | t.uint8_t,
        *args,
        manufacturer: int | t.uint16_t | None = None,
        tsn: int | t.uint8_t | None = None,
        **kwargs: typing.Any,
    ):
        """Send client command ``command_id`` via ``reply``.

        If ``manufacturer`` is None and this cluster or the command is
        manufacturer specific, the endpoint's manufacturer id is used.
        """
        command = self.client_commands[command_id]

        if manufacturer is None and (
            self._is_manuf_specific or command.is_manufacturer_specific
        ):
            manufacturer = self.endpoint.manufacturer_id

        return await self.reply(
            False,
            command.id,
            command.schema,
            *args,
            manufacturer=manufacturer,
            tsn=tsn,
            **kwargs,
        )

    async def read_attributes_raw(
        self,
        attributes: list[uint16_t],
        manufacturer: uint16_t | None = None,
        **kwargs: typing.Any,
    ):
        """Read attributes, answering ``_CONSTANT_ATTRIBUTES`` locally.

        Constant attributes get a SUCCESS record holding their fixed value.
        Only the remaining attributes are read through the base class. The
        records for constant attributes come first in the result, followed
        by the records for the attributes read from the device.

        Args:
            attributes: attribute ids to read.
            manufacturer: manufacturer id forwarded to the base read.
            **kwargs: forwarded to the base ``read_attributes_raw``,
                consistent with the other foundation overrides in this class.

        Returns:
            A one-element list holding the list of ReadAttributeRecords.
        """
        if not self._CONSTANT_ATTRIBUTES:
            return await super().read_attributes_raw(
                attributes, manufacturer=manufacturer, **kwargs
            )

        # Split the request into locally answered and device-read attributes.
        succeeded = []
        attrs_to_read = []
        for attr in attributes:
            if attr not in self._CONSTANT_ATTRIBUTES:
                attrs_to_read.append(attr)
                continue
            record = foundation.ReadAttributeRecord(
                attr, foundation.Status.SUCCESS, foundation.TypeValue()
            )
            record.value.value = self._CONSTANT_ATTRIBUTES[attr]
            succeeded.append(record)

        if not attrs_to_read:
            return [succeeded]

        results = await super().read_attributes_raw(
            attrs_to_read, manufacturer=manufacturer, **kwargs
        )
        if isinstance(results[0], list):
            succeeded.extend(results[0])
        else:
            # A non-list first item is treated as one status for the whole
            # request (presumably a default response); report it for every
            # attribute that was sent to the device.
            succeeded.extend(
                foundation.ReadAttributeRecord(
                    attrid,
                    results[0],
                    foundation.TypeValue(),
                )
                for attrid in attrs_to_read
            )
        return [succeeded]

    async def _configure_reporting(  # type:ignore[override]
        self,
        config_records: list[foundation.AttributeReportingConfig],
        *args,
        manufacturer: int | t.uint16_t | None = None,
        **kwargs,
    ):
        """Configure reporting ZCL foundation command.

        Uses the endpoint's manufacturer id if none is given and any record
        targets a manufacturer-specific attribute.
        """
        if manufacturer is None and self._has_manuf_attr(
            [a.attrid for a in config_records]
        ):
            manufacturer = self.endpoint.manufacturer_id
        return await super()._configure_reporting(
            config_records,
            *args,
            manufacturer=manufacturer,
            **kwargs,
        )

    async def _read_attributes(  # type:ignore[override]
        self,
        attribute_ids: list[t.uint16_t],
        *args,
        manufacturer: int | t.uint16_t | None = None,
        **kwargs,
    ):
        """Read attributes ZCL foundation command.

        Uses the endpoint's manufacturer id if none is given and any attribute
        is manufacturer specific.
        """
        if manufacturer is None and self._has_manuf_attr(attribute_ids):
            manufacturer = self.endpoint.manufacturer_id
        return await super()._read_attributes(
            attribute_ids, *args, manufacturer=manufacturer, **kwargs
        )

    async def _write_attributes(  # type:ignore[override]
        self,
        attributes: list[foundation.Attribute],
        *args,
        manufacturer: int | t.uint16_t | None = None,
        **kwargs,
    ):
        """Write attribute ZCL foundation command.

        Uses the endpoint's manufacturer id if none is given and any attribute
        is manufacturer specific.
        """
        if manufacturer is None and self._has_manuf_attr(
            [a.attrid for a in attributes]
        ):
            manufacturer = self.endpoint.manufacturer_id
        return await super()._write_attributes(
            attributes, *args, manufacturer=manufacturer, **kwargs
        )

    async def _write_attributes_undivided(  # type:ignore[override]
        self,
        attributes: list[foundation.Attribute],
        *args,
        manufacturer: int | t.uint16_t | None = None,
        **kwargs,
    ):
        """Write attribute undivided ZCL foundation command.

        Uses the endpoint's manufacturer id if none is given and any attribute
        is manufacturer specific.
        """
        if manufacturer is None and self._has_manuf_attr(
            [a.attrid for a in attributes]
        ):
            manufacturer = self.endpoint.manufacturer_id
        return await super()._write_attributes_undivided(
            attributes, *args, manufacturer=manufacturer, **kwargs
        )

    def get(self, key: int | str, default: typing.Any | None = None) -> typing.Any:
        """Get cached attribute, preferring ``_CONSTANT_ATTRIBUTES``.

        Keys that ``find_attribute`` cannot resolve are delegated unchanged to
        the base implementation.
        """

        try:
            attr_def = self.find_attribute(key)
        except KeyError:
            return super().get(key, default)

        # Ensure we check the constant attributes dictionary first, since their values
        # will not be in the attribute cache but can be read immediately.
        if (
            self._CONSTANT_ATTRIBUTES is not None
            and attr_def.id in self._CONSTANT_ATTRIBUTES
        ):
            return self._CONSTANT_ATTRIBUTES[attr_def.id]

        return super().get(key, default)


def handle_message_from_uninitialized_sender(
    sender: zigpy.device.Device,
    profile: int,
    cluster: int,
    src_ep: int,
    dst_ep: int,
    message: bytes,
) -> None:
    """Offer a message from an uninitialized sender to the registered handlers.

    Handlers are tried in registration order. The first one that returns a
    truthy value consumes the message, and the remaining handlers are skipped.
    """
    for candidate in _uninitialized_device_message_handlers:
        consumed = candidate(sender, profile, cluster, src_ep, dst_ep, message)
        if consumed:
            return
