diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e8a93b8dc..e66ab34435 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ These changes are available on the `master` branch, but have not yet been releas ### Added +- Added `replace_item` to `DesignerView`, `Section`, `Container`, `ActionRow`, & + `MediaGallery` ([#3032](https://github.com/Pycord-Development/pycord/pull/3032)) - Added `.extension` attribute to emojis to get their file extension. ([#3055](https://github.com/Pycord-Development/pycord/pull/3055)) @@ -22,6 +24,8 @@ These changes are available on the `master` branch, but have not yet been releas ### Fixed +- Fixed core issues with modifying items in `Container` and `Section` + ([#3032](https://github.com/Pycord-Development/pycord/pull/3032)) - Fixed `RawMessageUpdateEvent.cached_message` being always `None` even when the message was cached. ([#3038](https://github.com/Pycord-Development/pycord/pull/3038)) - Fixed downloading animated emojis which were originally uploaded as WebP files by diff --git a/discord/client.py b/discord/client.py index 749caedd55..d842487d0a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -594,7 +594,7 @@ async def on_modal_error(self, error: Exception, interaction: Interaction) -> No The default modal error handler provided by the client. The default implementation prints the traceback to stderr. - This only fires for a modal if you did not define its :func:`~discord.ui.Modal.on_error`. + This only fires for a modal if you did not define its :func:`~discord.ui.BaseModal.on_error`. Parameters ---------- diff --git a/discord/colour.py b/discord/colour.py index 0ec77d5786..fc83c04915 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -118,6 +118,18 @@ def to_rgb(self) -> tuple[int, int, int]: """Returns an (r, g, b) tuple representing the colour.""" return self.r, self.g, self.b + @classmethod + def resolve_value(cls: type[CT], value: int | Colour | None) -> CT: + if value is None or isinstance(value, Colour): + return value + elif isinstance(value, int): + return cls(value=value) + else: + raise TypeError( + "Expected discord.Colour, int, or None but received" + f" {value.__class__.__name__} instead." + ) + @classmethod def from_rgb(cls: type[CT], r: int, g: int, b: int) -> CT: """Constructs a :class:`Colour` from an RGB tuple.""" diff --git a/discord/components.py b/discord/components.py index 9775c97f01..b0672c0296 100644 --- a/discord/components.py +++ b/discord/components.py @@ -134,7 +134,7 @@ def _raw_construct(cls: type[C], **kwargs) -> C: try: value = kwargs[slot] except KeyError: - pass + setattr(self, slot, None) else: setattr(self, slot, value) return self diff --git a/discord/embeds.py b/discord/embeds.py index 81424050e2..77f0c70370 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -522,15 +522,7 @@ def colour(self) -> Colour | None: @colour.setter def colour(self, value: int | Colour | None): # type: ignore - if value is None or isinstance(value, Colour): - self._colour = value - elif isinstance(value, int): - self._colour = Colour(value=value) - else: - raise TypeError( - "Expected discord.Colour, int, or None but received" - f" {value.__class__.__name__} instead." - ) + self._colour = Colour.resolve_value(value) color = colour diff --git a/discord/ext/pages/pagination.py b/discord/ext/pages/pagination.py index 08f881baa2..ed7ba65740 100644 --- a/discord/ext/pages/pagination.py +++ b/discord/ext/pages/pagination.py @@ -26,6 +26,8 @@ from typing import List +from typing_extensions import Self + import discord from discord.errors import DiscordException from discord.ext.bridge import BridgeContext @@ -911,6 +913,12 @@ def update_custom_view(self, custom_view: discord.ui.View): for item in custom_view.children: self.add_item(item) + def clear_items(self) -> Self: + # Necessary override due to behavior of Item.parent, see #3057 + self.children.clear() + self._View__weights.clear() + return self + def get_page_group_content(self, page_group: PageGroup) -> list[Page]: """Returns a converted list of `Page` objects for the given page group based on the content of its pages.""" return [self.get_page_content(page) for page in page_group.pages] diff --git a/discord/interactions.py b/discord/interactions.py index 47498c6946..04eaadad0f 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -88,7 +88,7 @@ from .types.interactions import InteractionCallbackResponse, InteractionData from .types.interactions import InteractionMetadata as InteractionMetadataPayload from .types.interactions import MessageInteraction as MessageInteractionPayload - from .ui.modal import Modal + from .ui.modal import BaseModal from .ui.view import BaseView InteractionChannel = Union[ @@ -168,7 +168,7 @@ class Interaction: The view that this interaction belongs to. .. versionadded:: 2.7 - modal: Optional[:class:`Modal`] + modal: Optional[:class:`BaseModal`] The modal that this interaction belongs to. .. versionadded:: 2.7 @@ -258,7 +258,7 @@ def _from_data(self, data: InteractionPayload): self.command: ApplicationCommand | None = None self.view: BaseView | None = None - self.modal: Modal | None = None + self.modal: BaseModal | None = None self.attachment_size_limit: int = data.get("attachment_size_limit") self.message: Message | None = None @@ -1343,14 +1343,14 @@ async def send_autocomplete_result( self._responded = True await self._process_callback_response(callback_response) - async def send_modal(self, modal: Modal) -> Interaction: + async def send_modal(self, modal: BaseModal) -> Interaction: """|coro| Responds to this interaction by sending a modal dialog. This cannot be used to respond to another modal dialog submission. Parameters ---------- - modal: :class:`discord.ui.Modal` + modal: :class:`discord.ui.BaseModal` The modal dialog to display to the user. Raises diff --git a/discord/state.py b/discord/state.py index c7a29b437e..8ec51204ab 100644 --- a/discord/state.py +++ b/discord/state.py @@ -70,7 +70,7 @@ from .stage_instance import StageInstance from .sticker import GuildSticker from .threads import Thread, ThreadMember -from .ui.modal import Modal, ModalStore +from .ui.modal import BaseModal, ModalStore from .ui.view import BaseView, ViewStore from .user import ClientUser, User @@ -413,7 +413,7 @@ def store_view(self, view: BaseView, message_id: int | None = None) -> None: def purge_message_view(self, message_id: int) -> None: self._view_store.remove_message_view(message_id) - def store_modal(self, modal: Modal, message_id: int) -> None: + def store_modal(self, modal: BaseModal, message_id: int) -> None: self._modal_store.add_modal(modal, message_id) def prevent_view_updates_for(self, message_id: int) -> BaseView | None: diff --git a/discord/ui/action_row.py b/discord/ui/action_row.py index 6737a1625b..632be4faf8 100644 --- a/discord/ui/action_row.py +++ b/discord/ui/action_row.py @@ -95,11 +95,7 @@ def __init__( self.children: list[ViewItem] = [] - self._underlying = ActionRowComponent._raw_construct( - type=ComponentType.action_row, - id=id, - children=[], - ) + self._underlying = self._generate_underlying(id=id) for func in self.__row_children_items__: item: ViewItem = func.__discord_ui_model_type__( @@ -111,14 +107,33 @@ def __init__( for i in items: self.add_item(i) + @property + def items(self) -> list[ViewItem]: + return self.children + + @items.setter + def items(self, value: list[ViewItem]) -> None: + self.children = value + def _add_component_from_item(self, item: ViewItem): - self._underlying.children.append(item._underlying) + self.underlying.children.append(item._generate_underlying()) def _set_components(self, items: list[ViewItem]): - self._underlying.children.clear() + self.underlying.children.clear() for item in items: self._add_component_from_item(item) + def _generate_underlying(self, id: int | None = None) -> ActionRowComponent: + super()._generate_underlying(ActionRowComponent) + row = ActionRowComponent._raw_construct( + type=ComponentType.action_row, + id=id or self.id, + children=[], + ) + for i in self.children: + row.children.append(i._generate_underlying()) + return row + def add_item(self, item: ViewItem) -> Self: """Adds an item to the action row. @@ -164,6 +179,36 @@ def remove_item(self, item: ViewItem | str | int) -> Self: item.parent = None return self + def replace_item( + self, original_item: ViewItem | str | int, new_item: ViewItem + ) -> Self: + """Directly replace an item in this row. + If an :class:`int` is provided, the item will be replaced by ``id``, otherwise by ``custom_id``. + + Parameters + ---------- + original_item: Union[:class:`ViewItem`, :class:`int`, :class:`str`] + The item, item ``id``, or item ``custom_id`` to replace in the row. + new_item: :class:`ViewItem` + The new item to insert into the row. + """ + + if not isinstance(new_item, (Select, Button)): + raise TypeError(f"expected Select or Button, not {new_item.__class__!r}") + + if isinstance(original_item, (str, int)): + original_item = self.get_item(original_item) + if not original_item: + raise ValueError(f"Could not find original_item in row.") + try: + i = self.children.index(original_item) + new_item.parent = self + self.children[i] = new_item + original_item.parent = None + except ValueError: + raise ValueError(f"Could not find original_item in row.") + return self + def get_item(self, id: str | int) -> ViewItem | None: """Get an item from this action row. Roughly equivalent to `utils.get(row.children, ...)`. If an ``int`` is provided, the item will be retrieved by ``id``, otherwise by ``custom_id``. @@ -296,7 +341,7 @@ def add_select( id: int | None = None, default_values: Sequence[SelectDefaultValue] | None = None, ) -> Self: - """Adds a :class:`Select` to the container. + """Adds a :class:`Select` to the action row. To append a pre-existing :class:`Select`, use the :meth:`add_item` method instead. @@ -358,7 +403,7 @@ def is_persistent(self) -> bool: return all(item.is_persistent() for item in self.children) def refresh_component(self, component: ActionRowComponent) -> None: - self._underlying = component + self.underlying = component for i, y in enumerate(component.components): x = self.children[i] x.refresh_component(y) @@ -396,14 +441,14 @@ def width(self): """Return the sum of the items' widths.""" t = 0 for item in self.children: - t += 1 if item._underlying.type is ComponentType.button else 5 + t += 1 if item.underlying.type is ComponentType.button else 5 return t def walk_items(self) -> Iterator[ViewItem]: yield from self.children def to_component_dict(self) -> ActionRowPayload: - self._set_components(self.children) + self._underlying = self._generate_underlying() return super().to_component_dict() @classmethod diff --git a/discord/ui/button.py b/discord/ui/button.py index 8f1df340dd..fadfdb5ef1 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -147,8 +147,8 @@ def __init__( f" {emoji.__class__}" ) - self._underlying = ButtonComponent._raw_construct( - type=ComponentType.button, + self.row = row + self._underlying = self._generate_underlying( custom_id=custom_id, url=url, disabled=disabled, @@ -158,16 +158,39 @@ def __init__( sku_id=sku_id, id=id, ) - self.row = row + + def _generate_underlying( + self, + style: ButtonStyle | None = None, + label: str | None = None, + disabled: bool = False, + custom_id: str | None = None, + url: str | None = None, + emoji: str | GuildEmoji | AppEmoji | PartialEmoji | None = None, + sku_id: int | None = None, + id: int | None = None, + ) -> ButtonComponent: + super()._generate_underlying(ButtonComponent) + return ButtonComponent._raw_construct( + type=ComponentType.button, + custom_id=custom_id or self.custom_id, + url=url or self.url, + disabled=disabled or self.disabled, + label=label or self.label, + style=style or self.style, + emoji=emoji or self.emoji, + sku_id=sku_id or self.sku_id, + id=id or self.id, + ) @property def style(self) -> ButtonStyle: """The style of the button.""" - return self._underlying.style + return self.underlying.style @style.setter def style(self, value: ButtonStyle): - self._underlying.style = value + self.underlying.style = value @property def custom_id(self) -> str | None: @@ -175,7 +198,7 @@ def custom_id(self) -> str | None: If this button is for a URL, it does not have a custom ID. """ - return self._underlying.custom_id + return self.underlying.custom_id @custom_id.setter def custom_id(self, value: str | None): @@ -183,53 +206,53 @@ def custom_id(self, value: str | None): raise TypeError("custom_id must be None or str") if value and len(value) > 100: raise ValueError("custom_id must be 100 characters or fewer") - self._underlying.custom_id = value + self.underlying.custom_id = value self._provided_custom_id = value is not None @property def url(self) -> str | None: """The URL this button sends you to.""" - return self._underlying.url + return self.underlying.url @url.setter def url(self, value: str | None): if value is not None and not isinstance(value, str): raise TypeError("url must be None or str") - self._underlying.url = value + self.underlying.url = value @property def disabled(self) -> bool: """Whether the button is disabled or not.""" - return self._underlying.disabled + return self.underlying.disabled @disabled.setter def disabled(self, value: bool): - self._underlying.disabled = bool(value) + self.underlying.disabled = bool(value) @property def label(self) -> str | None: """The label of the button, if available.""" - return self._underlying.label + return self.underlying.label @label.setter def label(self, value: str | None): if value and len(str(value)) > 80: raise ValueError("label must be 80 characters or fewer") - self._underlying.label = str(value) if value is not None else value + self.underlying.label = str(value) if value is not None else value @property def emoji(self) -> PartialEmoji | None: """The emoji of the button, if available.""" - return self._underlying.emoji + return self.underlying.emoji @emoji.setter def emoji(self, value: str | GuildEmoji | AppEmoji | PartialEmoji | None): # type: ignore if value is None: - self._underlying.emoji = None + self.underlying.emoji = None elif isinstance(value, str): - self._underlying.emoji = PartialEmoji.from_str(value) + self.underlying.emoji = PartialEmoji.from_str(value) elif isinstance(value, _EmojiTag): - self._underlying.emoji = value._to_partial() + self.underlying.emoji = value._to_partial() else: raise TypeError( "expected str, GuildEmoji, AppEmoji, or PartialEmoji, received" @@ -239,14 +262,14 @@ def emoji(self, value: str | GuildEmoji | AppEmoji | PartialEmoji | None): # ty @property def sku_id(self) -> int | None: """The ID of the SKU this button refers to.""" - return self._underlying.sku_id + return self.underlying.sku_id @sku_id.setter def sku_id(self, value: int | None): # type: ignore if value is None: - self._underlying.sku_id = None + self.underlying.sku_id = None elif isinstance(value, int): - self._underlying.sku_id = value + self.underlying.sku_id = value else: raise TypeError(f"expected int or None, received {value.__class__} instead") @@ -281,7 +304,7 @@ def is_persistent(self) -> bool: return super().is_persistent() def refresh_component(self, button: ButtonComponent) -> None: - self._underlying = button + self.underlying = button def button( diff --git a/discord/ui/container.py b/discord/ui/container.py index 4c738c8fdf..0ceb231bd9 100644 --- a/discord/ui/container.py +++ b/discord/ui/container.py @@ -27,9 +27,8 @@ from typing import TYPE_CHECKING, Iterator, TypeVar from ..colour import Colour -from ..components import ActionRow from ..components import Container as ContainerComponent -from ..components import _component_factory +from ..components import MediaGalleryItem, _component_factory from ..enums import ComponentType, SeparatorSpacingSize from ..utils import find, get from .action_row import ActionRow @@ -108,25 +107,40 @@ def __init__( self.items: list[ViewItem] = [] - self._underlying = ContainerComponent._raw_construct( - type=ComponentType.container, + self._underlying = self._generate_underlying( id=id, - components=[], - accent_color=None, + accent_color=colour or color, spoiler=spoiler, ) - self.color = colour or color for i in items: self.add_item(i) def _add_component_from_item(self, item: ViewItem): - self._underlying.components.append(item._underlying) + self.underlying.components.append(item._generate_underlying()) def _set_components(self, items: list[ViewItem]): - self._underlying.components.clear() + self.underlying.components.clear() for item in items: self._add_component_from_item(item) + def _generate_underlying( + self, + accent_color: int | Colour | None = None, + spoiler: bool = False, + id: int | None = None, + ) -> ContainerComponent: + super()._generate_underlying(ContainerComponent) + container = ContainerComponent._raw_construct( + type=ComponentType.container, + id=id or self.id, + components=[], + accent_color=Colour.resolve_value(accent_color or self.colour), + spoiler=spoiler or self.spoiler, + ) + for i in self.items: + container.components.append(i._generate_underlying()) + return container + def add_item(self, item: ViewItem) -> Self: """Adds an item to the container. @@ -176,6 +190,36 @@ def remove_item(self, item: ViewItem | str | int) -> Self: item.parent = None return self + def replace_item( + self, original_item: ViewItem | str | int, new_item: ViewItem + ) -> Self: + """Directly replace an item in this container. + If an :class:`int` is provided, the item will be replaced by ``id``, otherwise by ``custom_id``. + + Parameters + ---------- + original_item: Union[:class:`ViewItem`, :class:`int`, :class:`str`] + The item, item ``id``, or item ``custom_id`` to replace in the container. + new_item: :class:`ViewItem` + The new item to insert into the container. + """ + + if isinstance(original_item, (str, int)): + original_item = self.get_item(original_item) + if not original_item: + raise ValueError(f"Could not find original_item in container.") + try: + if original_item.parent is self: + i = self.items.index(original_item) + new_item.parent = self + self.items[i] = new_item + original_item.parent = None + else: + original_item.parent.replace_item(original_item, new_item) + except ValueError: + raise ValueError(f"Could not find original_item in container.") + return self + def get_item(self, id: str | int) -> ViewItem | None: """Get an item from this container. Roughly equivalent to `utils.get(container.items, ...)`. If an ``int`` is provided, the item will be retrieved by ``id``, otherwise by ``custom_id``. @@ -219,9 +263,9 @@ def add_row( The action row's ID. """ - a = ActionRow(*items, id=id) + row = ActionRow(*items, id=id) - return self.add_item(a) + return self.add_item(row) def add_section( self, @@ -267,7 +311,7 @@ def add_text(self, content: str, id: int | None = None) -> Self: def add_gallery( self, - *items: ViewItem, + *items: MediaGalleryItem, id: int | None = None, ) -> Self: """Adds a :class:`MediaGallery` to the container. @@ -335,27 +379,19 @@ def copy_text(self) -> str: @property def spoiler(self) -> bool: """Whether the container has the spoiler overlay. Defaults to ``False``.""" - return self._underlying.spoiler + return self.underlying.spoiler @spoiler.setter def spoiler(self, spoiler: bool) -> None: - self._underlying.spoiler = spoiler + self.underlying.spoiler = spoiler @property def colour(self) -> Colour | None: - return self._underlying.accent_color + return self.underlying.accent_color @colour.setter def colour(self, value: int | Colour | None): # type: ignore - if value is None or isinstance(value, Colour): - self._underlying.accent_color = value - elif isinstance(value, int): - self._underlying.accent_color = Colour(value=value) - else: - raise TypeError( - "Expected discord.Colour, int, or None but received" - f" {value.__class__.__name__} instead." - ) + self.underlying.accent_color = Colour.resolve_value(value) color = colour @@ -366,7 +402,7 @@ def is_persistent(self) -> bool: return all(item.is_persistent() for item in self.items) def refresh_component(self, component: ContainerComponent) -> None: - self._underlying = component + self.underlying = component i = 0 for y in component.components: x = self.items[i] @@ -413,7 +449,7 @@ def walk_items(self) -> Iterator[ViewItem]: yield item def to_component_dict(self) -> ContainerComponentPayload: - self._set_components(self.items) + self._underlying = self._generate_underlying() return super().to_component_dict() @classmethod diff --git a/discord/ui/core.py b/discord/ui/core.py index 21bb16871a..2c00309710 100644 --- a/discord/ui/core.py +++ b/discord/ui/core.py @@ -27,6 +27,7 @@ import asyncio import time from itertools import groupby +from operator import attrgetter from typing import TYPE_CHECKING, Any, Callable from ..utils import find, get @@ -115,30 +116,68 @@ def _dispatch_timeout(self): def to_components(self) -> list[dict[str, Any]]: return [item.to_component_dict() for item in self.children] - def get_item(self, custom_id: str | int) -> Item | None: - """Gets an item from this structure. Roughly equal to `utils.get(self.children, ...)`. + def get_item(self, custom_id: str | int | None = None, **attrs: Any) -> Item | None: + r"""Gets an item from this structure. Roughly equal to `utils.get(self.children, **attrs)`. If an :class:`int` is provided, the item will be retrieved by ``id``, otherwise by ``custom_id``. This method will also search nested items. + If ``attrs`` are provided, it will check them by logical AND as done in :func:`~utils.get`. + To have a nested attribute search (i.e. search by ``x.y``) then pass in ``x__y`` as the keyword argument. + + Examples + --------- + + Basic usage: + + .. code-block:: python3 + + container = my_view.get(1234) + + Attribute matching: + + .. code-block:: python3 + + button = my_view.get(label='Click me!', style=discord.ButtonStyle.danger) Parameters ---------- - custom_id: Union[:class:`str`, :class:`int`] + custom_id: Optional[Union[:class:`str`, :class:`int`]] The id of the item to get + \*\*attrs + Keyword arguments that denote attributes to search with. Returns ------- Optional[:class:`Item`] - The item with the matching ``custom_id`` or ``id`` if it exists. + The item with the matching ``custom_id``, ``id``, or ``attrs`` if it exists. """ - if not custom_id: + if not (custom_id or attrs): return None - attr = "id" if isinstance(custom_id, int) else "custom_id" - child = find(lambda i: getattr(i, attr, None) == custom_id, self.children) - if not child: + child = None + if custom_id: + attr = "id" if isinstance(custom_id, int) else "custom_id" + child = find(lambda i: getattr(i, attr, None) == custom_id, self.children) + if not child: + for i in self.children: + if hasattr(i, "get_item"): + if child := i.get_item(custom_id): + return child + elif attrs: + _all = all + attrget = attrgetter for i in self.children: + converted = [ + (attrget(attr.replace("__", ".")), value) + for attr, value in attrs.items() + ] + try: + if _all(pred(i) == value for pred, value in converted): + return i + except: + pass if hasattr(i, "get_item"): - if child := i.get_item(custom_id): + if child := i.get_item(custom_id, **attrs): return child + return child def add_item(self, item: Item) -> Self: diff --git a/discord/ui/file.py b/discord/ui/file.py index 32e1ac2f45..17e57c87b1 100644 --- a/discord/ui/file.py +++ b/discord/ui/file.py @@ -69,47 +69,69 @@ class File(ViewItem[V]): def __init__(self, url: str, *, spoiler: bool = False, id: int | None = None): super().__init__() - self.file = UnfurledMediaItem(url) + file = UnfurledMediaItem(url) - self._underlying = FileComponent._raw_construct( - type=ComponentType.file, + self._underlying = self._generate_underlying( id=id, - file=self.file, + file=file, spoiler=spoiler, ) + def _generate_underlying( + self, + file: UnfurledMediaItem | None = None, + spoiler: bool | None = None, + id: int | None = None, + ) -> FileComponent: + super()._generate_underlying(FileComponent) + return FileComponent._raw_construct( + type=ComponentType.file, + id=id or self.id, + file=file or self.file, + spoiler=spoiler if spoiler is not None else self.spoiler, + ) + + @property + def file(self) -> UnfurledMediaItem: + """The file's unerlying media item.""" + return self.underlying.file + + @file.setter + def url(self, value: UnfurledMediaItem) -> None: + self.underlying.file = value + @property def url(self) -> str: - """The URL of this file's media. This must be an ``attachment://`` URL that references a :class:`~discord.File`.""" - return self._underlying.file and self._underlying.file.url + """The URL of this file's underlying media. This must be an ``attachment://`` URL that references a :class:`~discord.File`.""" + return self.underlying.file and self.underlying.file.url @url.setter def url(self, value: str) -> None: - self._underlying.file.url = value + self.underlying.file.url = value @property def spoiler(self) -> bool: """Whether the file has the spoiler overlay. Defaults to ``False``.""" - return self._underlying.spoiler + return self.underlying.spoiler @spoiler.setter def spoiler(self, spoiler: bool) -> None: - self._underlying.spoiler = spoiler + self.underlying.spoiler = spoiler @property def name(self) -> str: """The name of this file, if provided by Discord.""" - return self._underlying.name + return self.underlying.name @property def size(self) -> int: """The size of this file in bytes, if provided by Discord.""" - return self._underlying.size + return self.underlying.size def refresh_component(self, component: FileComponent) -> None: - original = self._underlying.file + original = self.underlying.file component.file._static_url = original._static_url - self._underlying = component + self.underlying = component def to_component_dict(self) -> FileComponentPayload: return super().to_component_dict() diff --git a/discord/ui/file_upload.py b/discord/ui/file_upload.py index 1f4222b0e0..0ab76a5062 100644 --- a/discord/ui/file_upload.py +++ b/discord/ui/file_upload.py @@ -65,16 +65,15 @@ def __init__( if not isinstance(required, bool): raise TypeError(f"required must be bool not {required.__class__.__name__}") # type: ignore custom_id = os.urandom(16).hex() if custom_id is None else custom_id + self._attachments: list[Attachment] | None = None - self._underlying: FileUploadComponent = FileUploadComponent._raw_construct( - type=ComponentType.file_upload, + self._underlying: FileUploadComponent = self._generate_underlying( custom_id=custom_id, min_values=min_values, max_values=max_values, required=required, id=id, ) - self._attachments: list[Attachment] | None = None def __repr__(self) -> str: attrs = " ".join( @@ -82,19 +81,37 @@ def __repr__(self) -> str: ) return f"<{self.__class__.__name__} {attrs}>" + def _generate_underlying( + self, + custom_id: str | None = None, + min_values: int | None = None, + max_values: int | None = None, + required: bool = None, + id: int | None = None, + ) -> FileUploadComponent: + super()._generate_underlying(FileUploadComponent) + return FileUploadComponent._raw_construct( + type=ComponentType.file_upload, + custom_id=custom_id or self.custom_id, + min_values=min_values or self.min_values, + max_values=max_values or self.max_values, + required=required if required is not None else self.required, + id=id or self.id, + ) + @property def type(self) -> ComponentType: - return self._underlying.type + return self.underlying.type @property def id(self) -> int | None: """The ID of this component. If not provided by the user, it is set sequentially by Discord.""" - return self._underlying.id + return self.underlying.id @property def custom_id(self) -> str: """The custom id that gets received during an interaction.""" - return self._underlying.custom_id + return self.underlying.custom_id @custom_id.setter def custom_id(self, value: str): @@ -102,12 +119,12 @@ def custom_id(self, value: str): raise TypeError( f"custom_id must be None or str not {value.__class__.__name__}" ) - self._underlying.custom_id = value + self.underlying.custom_id = value @property def min_values(self) -> int | None: """The minimum number of files that must be uploaded. Defaults to 0.""" - return self._underlying.min_values + return self.underlying.min_values @min_values.setter def min_values(self, value: int | None): @@ -115,12 +132,12 @@ def min_values(self, value: int | None): raise TypeError(f"min_values must be None or int not {value.__class__.__name__}") # type: ignore if value and (value < 0 or value > 10): raise ValueError("min_values must be between 0 and 10") - self._underlying.min_values = value + self.underlying.min_values = value @property def max_values(self) -> int | None: """The maximum number of files that can be uploaded.""" - return self._underlying.max_values + return self.underlying.max_values @max_values.setter def max_values(self, value: int | None): @@ -128,18 +145,18 @@ def max_values(self, value: int | None): raise TypeError(f"max_values must be None or int not {value.__class__.__name__}") # type: ignore if value and (value < 1 or value > 10): raise ValueError("max_values must be between 1 and 10") - self._underlying.max_values = value + self.underlying.max_values = value @property def required(self) -> bool: """Whether the input file upload is required or not. Defaults to ``True``.""" - return self._underlying.required + return self.underlying.required @required.setter def required(self, value: bool): if not isinstance(value, bool): raise TypeError(f"required must be bool not {value.__class__.__name__}") # type: ignore - self._underlying.required = bool(value) + self.underlying.required = bool(value) @property def values(self) -> list[Attachment] | None: @@ -147,7 +164,7 @@ def values(self) -> list[Attachment] | None: return self._attachments def to_component_dict(self) -> FileUploadComponentPayload: - return self._underlying.to_dict() + return self.underlying.to_dict() def refresh_from_modal(self, interaction: Interaction, data: dict) -> None: values = data.get("values", []) @@ -158,3 +175,16 @@ def refresh_from_modal(self, interaction: Interaction, data: dict) -> None: ) for attachment_id in values ] + + @classmethod + def from_component( + cls: type[FileUpload], component: FileUploadComponent + ) -> FileUpload: + + return cls( + custom_id=component.custom_id, + min_values=component.min_values, + max_values=component.max_values, + required=component.required, + id=component.id, + ) diff --git a/discord/ui/input_text.py b/discord/ui/input_text.py index 9f8e35bffc..47470b3927 100644 --- a/discord/ui/input_text.py +++ b/discord/ui/input_text.py @@ -116,9 +116,11 @@ def __init__( f"expected custom_id to be str, not {custom_id.__class__.__name__}" ) custom_id = os.urandom(16).hex() if custom_id is None else custom_id + self._input_value = False + self.row = row + self._rendered_row: int | None = None - self._underlying = InputTextComponent._raw_construct( - type=ComponentType.input_text, + self._underlying = self._generate_underlying( style=style, custom_id=custom_id, label=label, @@ -129,9 +131,6 @@ def __init__( value=value, id=id, ) - self._input_value = False - self.row = row - self._rendered_row: int | None = None def __repr__(self) -> str: attrs = " ".join( @@ -139,10 +138,36 @@ def __repr__(self) -> str: ) return f"<{self.__class__.__name__} {attrs}>" + def _generate_underlying( + self, + style: InputTextStyle | None = None, + custom_id: str | None = None, + label: str | None = None, + placeholder: str | None = None, + min_length: int | None = None, + max_length: int | None = None, + required: bool | None = True, + value: str | None = None, + id: int | None = None, + ) -> InputTextComponent: + super()._generate_underlying(InputTextComponent) + return InputTextComponent._raw_construct( + type=ComponentType.input_text, + style=style or self.style, + custom_id=custom_id or self.custom_id, + label=label or self.label, + placeholder=placeholder or self.placeholder, + min_length=min_length or self.min_length, + max_length=max_length or self.max_length, + required=required or self.required, + value=value or self.value, + id=id or self.id, + ) + @property def style(self) -> InputTextStyle: """The style of the input text field.""" - return self._underlying.style + return self.underlying.style @style.setter def style(self, value: InputTextStyle): @@ -150,12 +175,12 @@ def style(self, value: InputTextStyle): raise TypeError( f"style must be of type InputTextStyle not {value.__class__.__name__}" ) - self._underlying.style = value + self.underlying.style = value @property def custom_id(self) -> str: """The ID of the input text field that gets received during an interaction.""" - return self._underlying.custom_id + return self.underlying.custom_id @custom_id.setter def custom_id(self, value: str): @@ -163,12 +188,12 @@ def custom_id(self, value: str): raise TypeError( f"custom_id must be None or str not {value.__class__.__name__}" ) - self._underlying.custom_id = value + self.underlying.custom_id = value @property def label(self) -> str: """The label of the input text field.""" - return self._underlying.label + return self.underlying.label @label.setter def label(self, value: str): @@ -176,12 +201,12 @@ def label(self, value: str): raise TypeError(f"label should be str not {value.__class__.__name__}") if len(value) > 45: raise ValueError("label must be 45 characters or fewer") - self._underlying.label = value + self.underlying.label = value @property def placeholder(self) -> str | None: """The placeholder text that is shown before anything is entered, if any.""" - return self._underlying.placeholder + return self.underlying.placeholder @placeholder.setter def placeholder(self, value: str | None): @@ -189,12 +214,12 @@ def placeholder(self, value: str | None): raise TypeError(f"placeholder must be None or str not {value.__class__.__name__}") # type: ignore if value and len(value) > 100: raise ValueError("placeholder must be 100 characters or fewer") - self._underlying.placeholder = value + self.underlying.placeholder = value @property def min_length(self) -> int | None: """The minimum number of characters that must be entered. Defaults to 0.""" - return self._underlying.min_length + return self.underlying.min_length @min_length.setter def min_length(self, value: int | None): @@ -202,12 +227,12 @@ def min_length(self, value: int | None): raise TypeError(f"min_length must be None or int not {value.__class__.__name__}") # type: ignore if value and (value < 0 or value) > 4000: raise ValueError("min_length must be between 0 and 4000") - self._underlying.min_length = value + self.underlying.min_length = value @property def max_length(self) -> int | None: """The maximum number of characters that can be entered.""" - return self._underlying.max_length + return self.underlying.max_length @max_length.setter def max_length(self, value: int | None): @@ -215,18 +240,18 @@ def max_length(self, value: int | None): raise TypeError(f"min_length must be None or int not {value.__class__.__name__}") # type: ignore if value and (value <= 0 or value > 4000): raise ValueError("max_length must be between 1 and 4000") - self._underlying.max_length = value + self.underlying.max_length = value @property def required(self) -> bool | None: """Whether the input text field is required or not. Defaults to ``True``.""" - return self._underlying.required + return self.underlying.required @required.setter def required(self, value: bool | None): if not isinstance(value, bool): raise TypeError(f"required must be bool not {value.__class__.__name__}") # type: ignore - self._underlying.required = bool(value) + self.underlying.required = bool(value) @property def value(self) -> str | None: @@ -234,7 +259,7 @@ def value(self) -> str | None: if self._input_value is not False: # only False on init, otherwise the value was either set or cleared return self._input_value # type: ignore - return self._underlying.value + return self.underlying.value @value.setter def value(self, value: str | None): @@ -242,7 +267,7 @@ def value(self, value: str | None): raise TypeError(f"value must be None or str not {value.__class__.__name__}") # type: ignore if value and len(str(value)) > 4000: raise ValueError("value must be 4000 characters or fewer") - self._underlying.value = value + self.underlying.value = value @property def width(self) -> int: @@ -259,5 +284,22 @@ def refresh_from_modal( ) -> None: return self.refresh_state(data) + @classmethod + def from_component( + cls: type[InputText], component: InputTextComponent + ) -> InputText: + + return cls( + style=component.style, + custom_id=component.custom_id, + label=component.label, + placeholder=component.placeholder, + min_length=component.min_length, + max_length=component.max_length, + required=component.required, + value=component.value, + id=component.id, + ) + TextInput = InputText diff --git a/discord/ui/item.py b/discord/ui/item.py index 6f21c7684c..a38148bb19 100644 --- a/discord/ui/item.py +++ b/discord/ui/item.py @@ -68,12 +68,12 @@ def __init__(self): self.parent: Item | ItemInterface | None = None def to_component_dict(self) -> dict[str, Any]: - if not self._underlying: + if not self.underlying: raise NotImplementedError - return self._underlying.to_dict() + return self.underlying.to_dict() def refresh_component(self, component: Component) -> None: - self._underlying = component + self.underlying = component def refresh_state(self, interaction: Interaction) -> None: return None @@ -82,11 +82,24 @@ def refresh_state(self, interaction: Interaction) -> None: def from_component(cls: type[I], component: Component) -> I: return cls() + @property + def underlying(self) -> Component: + return self._underlying + + @underlying.setter + def underlying(self, value: Component) -> None: + self._underlying = value + @property def type(self) -> ComponentType: - if not self._underlying: + if not self.underlying: raise NotImplementedError - return self._underlying.type + return self.underlying.type + + def _generate_underlying(self, cls: type[Component]) -> Component: + if not self._underlying: + self._underlying = cls._raw_construct() + return self._underlying def is_dispatchable(self) -> bool: return False @@ -117,13 +130,13 @@ def id(self) -> int | None: Optional[:class:`int`] The ID of this item, or ``None`` if the user didn't set one. """ - return self._underlying and self._underlying.id + return self.underlying and self.underlying.id @id.setter def id(self, value) -> None: - if not self._underlying: + if not self.underlying: return - self._underlying.id = value + self.underlying.id = value class ViewItem(Item[V]): diff --git a/discord/ui/label.py b/discord/ui/label.py index 1fba7a0b7a..919a2aa1e2 100644 --- a/discord/ui/label.py +++ b/discord/ui/label.py @@ -95,10 +95,8 @@ def __init__( self.item: ModalItem = None - self._underlying = LabelComponent._raw_construct( - type=ComponentType.label, + self._underlying = self._generate_underlying( id=id, - component=None, label=label, description=description, ) @@ -113,7 +111,26 @@ def modal(self, value): self.item.modal = value def _set_component_from_item(self, item: ModalItem): - self._underlying.component = item._underlying + self.underlying.component = item._generate_underlying() + + def _generate_underlying( + self, + label: str | None = None, + description: str | None = None, + id: int | None = None, + ) -> LabelComponent: + super()._generate_underlying(LabelComponent) + label = LabelComponent._raw_construct( + type=ComponentType.label, + id=id or self.id, + component=None, + label=label or self.label, + description=description or self.description, + ) + + if self.item: + label.component = self.item._generate_underlying() + return label def set_item(self, item: ModalItem) -> Self: """Set this label's item. @@ -372,20 +389,20 @@ def set_file_upload( @property def label(self) -> str: """The label text. Must be 45 characters or fewer.""" - return self._underlying.label + return self.underlying.label @label.setter def label(self, value: str) -> None: - self._underlying.label = value + self.underlying.label = value @property def description(self) -> str | None: """The description for this label. Must be 100 characters or fewer.""" - return self._underlying.description + return self.underlying.description @description.setter def description(self, value: str | None) -> None: - self._underlying.description = value + self.underlying.description = value def is_dispatchable(self) -> bool: return self.item.is_dispatchable() @@ -394,7 +411,7 @@ def is_persistent(self) -> bool: return self.item.is_persistent() def refresh_component(self, component: LabelComponent) -> None: - self._underlying = component + self.underlying = component self.item.refresh_component(component.component) def walk_items(self) -> Iterator[ModalItem]: @@ -415,8 +432,8 @@ def from_component(cls: type[L], component: LabelComponent) -> L: item = _component_to_item(component.component) return cls( - item, - id=component.id, label=component.label, + item=item, + id=component.id, description=component.description, ) diff --git a/discord/ui/media_gallery.py b/discord/ui/media_gallery.py index 838a837d36..a81873677e 100644 --- a/discord/ui/media_gallery.py +++ b/discord/ui/media_gallery.py @@ -65,13 +65,21 @@ class MediaGallery(ViewItem[V]): def __init__(self, *items: MediaGalleryItem, id: int | None = None): super().__init__() - self._underlying = MediaGalleryComponent._raw_construct( - type=ComponentType.media_gallery, id=id, items=[i for i in items] + self._underlying = self._generate_underlying(id=id, items=items) + + def _generate_underlying( + self, id: int | None = None, items: list[MediaGalleryItem] | None = None + ) -> MediaGalleryComponent: + super()._generate_underlying(MediaGalleryComponent) + return MediaGalleryComponent._raw_construct( + type=ComponentType.media_gallery, + id=id or self.id, + items=[i for i in items] if items else [i for i in self.items or []], ) @property def items(self): - return self._underlying.items + return self.underlying.items def append_item(self, item: MediaGalleryItem) -> Self: """Adds a :attr:`MediaGalleryItem` to the gallery. @@ -95,7 +103,7 @@ def append_item(self, item: MediaGalleryItem) -> Self: if not isinstance(item, MediaGalleryItem): raise TypeError(f"expected MediaGalleryItem not {item.__class__!r}") - self._underlying.items.append(item) + self.underlying.items.append(item) return self def add_item( @@ -129,7 +137,39 @@ def add_item( return self.append_item(item) + def remove_item(self, index: int) -> Self: + """Removes an item from the gallery. + + Parameters + ---------- + index: :class:`int` + The index of the item to remove from the gallery. + """ + + try: + self.items.pop(index) + except IndexError: + pass + return self + + def replace_item(self, index: int, new_item: MediaGalleryItem) -> Self: + """Directly replace an item in this gallery by index. + + Parameters + ---------- + original_item: :class:`int` + The index of the item to replace in this gallery. + new_item: :class:`MediaGalleryItem` + The new item to insert into the gallery. + """ + + if not isinstance(new_item, MediaGalleryItem): + raise TypeError(f"expected MediaGalleryItem not {new_item.__class__!r}") + self.items[index] = new_item + return self + def to_component_dict(self) -> MediaGalleryComponentPayload: + self.underlying = self._generate_underlying() return super().to_component_dict() @classmethod diff --git a/discord/ui/modal.py b/discord/ui/modal.py index 8620cec363..1051156f80 100644 --- a/discord/ui/modal.py +++ b/discord/ui/modal.py @@ -30,7 +30,7 @@ import time from functools import partial from itertools import groupby -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, Iterator, TypeVar from ..enums import ComponentType from ..utils import find @@ -238,8 +238,6 @@ async def on_error(self, error: Exception, interaction: Interaction) -> None: ---------- error: :class:`Exception` The exception that was raised. - modal: :class:`BaseModal` - The modal that failed the dispatch. interaction: :class:`~discord.Interaction` The interaction that led to the failure. """ @@ -251,6 +249,13 @@ async def on_timeout(self) -> None: A callback that is called when a modal's timeout elapses without being explicitly stopped. """ + def walk_children(self) -> Iterator[ModalItem]: + for item in self.children: + if hasattr(item, "walk_items"): + yield from item.walk_items() + else: + yield item + class Modal(BaseModal): """Represents a legacy UI modal for InputText components. diff --git a/discord/ui/section.py b/discord/ui/section.py index 069b600ea5..a2d62da642 100644 --- a/discord/ui/section.py +++ b/discord/ui/section.py @@ -93,11 +93,8 @@ def __init__( self.items: list[ViewItem] = [] self.accessory: ViewItem | None = None - self._underlying = SectionComponent._raw_construct( - type=ComponentType.section, + self._underlying = self._generate_underlying( id=id, - components=[], - accessory=None, ) for func in self.__section_accessory_item__: item: ViewItem = func.__discord_ui_model_type__( @@ -112,13 +109,27 @@ def __init__( self.add_item(i) def _add_component_from_item(self, item: ViewItem): - self._underlying.components.append(item._underlying) + self.underlying.components.append(item.underlying) def _set_components(self, items: list[ViewItem]): - self._underlying.components.clear() + self.underlying.components.clear() for item in items: self._add_component_from_item(item) + def _generate_underlying(self, id: int | None = None) -> SectionComponent: + super()._generate_underlying(SectionComponent) + section = SectionComponent._raw_construct( + type=ComponentType.section, + id=id or self.id, + components=[], + accessory=None, + ) + for i in self.items: + section.components.append(i._generate_underlying()) + if self.accessory: + section.accessory = self.accessory._generate_underlying() + return section + def add_item(self, item: ViewItem) -> Self: """Adds an item to the section. @@ -159,12 +170,48 @@ def remove_item(self, item: ViewItem | str | int) -> Self: if isinstance(item, (str, int)): item = self.get_item(item) try: - self.items.remove(item) + if item is self.accessory: + self.accessory = None + else: + self.items.remove(item) except ValueError: pass item.parent = None return self + def replace_item( + self, original_item: ViewItem | str | int, new_item: ViewItem + ) -> Self: + """Directly replace an item in this section. + If an :class:`int` is provided, the item will be replaced by ``id``, otherwise by ``custom_id``. + + Parameters + ---------- + original_item: Union[:class:`ViewItem`, :class:`int`, :class:`str`] + The item, item ``id``, or item ``custom_id`` to replace in the section. + new_item: :class:`ViewItem` + The new item to insert into the section. + """ + + if not isinstance(new_item, ViewItem): + raise TypeError(f"expected ViewItem not {new_item.__class__!r}") + + if isinstance(original_item, (str, int)): + original_item = self.get_item(original_item) + if not original_item: + raise ValueError(f"Could not find original_item in section.") + try: + if original_item is self.accessory: + self.accessory = new_item + else: + i = self.items.index(original_item) + self.items[i] = new_item + original_item.parent = None + new_item.parent = self + except ValueError: + raise ValueError(f"Could not find original_item in section.") + return self + def get_item(self, id: int | str) -> ViewItem | None: """Get an item from this section. Alias for `utils.get(section.walk_items(), ...)`. If an ``int`` is provided, it will be retrieved by ``id``, otherwise it will check the accessory's ``custom_id``. @@ -231,7 +278,7 @@ def set_accessory(self, item: ViewItem) -> Self: item.parent = self self.accessory = item - self._underlying.accessory = item._underlying + self.underlying.accessory = item._generate_underlying() return self def set_thumbnail( @@ -275,7 +322,7 @@ def is_persistent(self) -> bool: return self.accessory.is_persistent() def refresh_component(self, component: SectionComponent) -> None: - self._underlying = component + self.underlying = component for x, y in zip(self.items, component.components): x.refresh_component(y) if self.accessory and component.accessory: @@ -323,9 +370,7 @@ def walk_items(self) -> Iterator[ViewItem]: yield from r def to_component_dict(self) -> SectionComponentPayload: - self._set_components(self.items) - if self.accessory: - self.set_accessory(self.accessory) + self._underlying = self._generate_underlying() return super().to_component_dict() @classmethod diff --git a/discord/ui/select.py b/discord/ui/select.py index 9d8cb9c612..2c24cd18fd 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -276,7 +276,8 @@ def __init__( self._provided_custom_id = custom_id is not None custom_id = os.urandom(16).hex() if custom_id is None else custom_id - self._underlying: SelectMenu = SelectMenu._raw_construct( + self.row = row + self._underlying: SelectMenu = self._generate_underlying( custom_id=custom_id, type=select_type, placeholder=placeholder, @@ -289,7 +290,37 @@ def __init__( required=required, default_values=self._handle_default_values(default_values, select_type), ) - self.row = row + + def _generate_underlying( + self, + type: ComponentType | None = None, + custom_id: str | None = None, + placeholder: str | None = None, + min_values: int = None, + max_values: int = None, + options: list[SelectOption] | None = None, + channel_types: list[ChannelType] | None = None, + disabled: bool = None, + id: int | None = None, + required: bool | None = None, + default_values: Sequence[SelectDefaultValue | ST] | None = None, + ) -> SelectMenu: + super()._generate_underlying(SelectMenu) + return SelectMenu._raw_construct( + custom_id=custom_id or self.custom_id, + type=type or self.type, + placeholder=placeholder or self.placeholder, + min_values=min_values if min_values is not None else self.min_values, + max_values=max_values if max_values is not None else self.max_values, + disabled=disabled if disabled is not None else self.disabled, + options=options if options is not None else self.options, + channel_types=( + channel_types if channel_types is not None else self.channel_types + ), + id=id or self.id, + required=required if required is not None else self.required, + default_values=default_values or self.default_values or [], + ) def _handle_default_values( self, @@ -338,7 +369,7 @@ def _handle_default_values( @property def custom_id(self) -> str: """The ID of the select menu that gets received during an interaction.""" - return self._underlying.custom_id + return self.underlying.custom_id @custom_id.setter def custom_id(self, value: str): @@ -346,13 +377,13 @@ def custom_id(self, value: str): raise TypeError("custom_id must be None or str") if len(value) > 100: raise ValueError("custom_id must be 100 characters or fewer") - self._underlying.custom_id = value + self.underlying.custom_id = value self._provided_custom_id = value is not None @property def placeholder(self) -> str | None: """The placeholder text that is shown if nothing is selected, if any.""" - return self._underlying.placeholder + return self.underlying.placeholder @placeholder.setter def placeholder(self, value: str | None): @@ -361,74 +392,74 @@ def placeholder(self, value: str | None): if value and len(value) > 150: raise ValueError("placeholder must be 150 characters or fewer") - self._underlying.placeholder = value + self.underlying.placeholder = value @property def min_values(self) -> int: """The minimum number of items that must be chosen for this select menu.""" - return self._underlying.min_values + return self.underlying.min_values @min_values.setter def min_values(self, value: int): if value < 0 or value > 25: raise ValueError("min_values must be between 0 and 25") - self._underlying.min_values = int(value) + self.underlying.min_values = int(value) @property def max_values(self) -> int: """The maximum number of items that must be chosen for this select menu.""" - return self._underlying.max_values + return self.underlying.max_values @max_values.setter def max_values(self, value: int): if value < 1 or value > 25: raise ValueError("max_values must be between 1 and 25") - self._underlying.max_values = int(value) + self.underlying.max_values = int(value) @property def disabled(self) -> bool: """Whether the select is disabled or not.""" - return self._underlying.disabled + return self.underlying.disabled @property def required(self) -> bool: """Whether the select is required or not. Only applicable in modal selects.""" - return self._underlying.required + return self.underlying.required @required.setter def required(self, value: bool): - self._underlying.required = value + self.underlying.required = value @disabled.setter def disabled(self, value: bool): - self._underlying.disabled = bool(value) + self.underlying.disabled = bool(value) @property def channel_types(self) -> list[ChannelType]: """A list of channel types that can be selected in this menu.""" - return self._underlying.channel_types + return self.underlying.channel_types @channel_types.setter def channel_types(self, value: list[ChannelType]): - if self._underlying.type is not ComponentType.channel_select: + if self.underlying.type is not ComponentType.channel_select: raise InvalidArgument("channel_types can only be set on channel selects") - self._underlying.channel_types = value + self.underlying.channel_types = value @property def options(self) -> list[SelectOption]: """A list of options that can be selected in this menu.""" - return self._underlying.options + return self.underlying.options @options.setter def options(self, value: list[SelectOption]): - if self._underlying.type is not ComponentType.string_select: + if self.underlying.type is not ComponentType.string_select: raise InvalidArgument("options can only be set on string selects") if not isinstance(value, list): raise TypeError("options must be a list of SelectOption") if not all(isinstance(obj, SelectOption) for obj in value): raise TypeError("all list items must subclass SelectOption") - self._underlying.options = value + self.underlying.options = value @property def default_values(self) -> list[SelectDefaultValue]: @@ -437,14 +468,14 @@ def default_values(self) -> list[SelectDefaultValue]: .. versionadded:: 2.7 """ - return self._underlying.default_values + return self.underlying.default_values @default_values.setter def default_values( self, values: Sequence[SelectDefaultValue | Snowflake] | None ) -> None: default_values = self._handle_default_values(values, self.type) - self._underlying.default_values = default_values + self.underlying.default_values = default_values def add_default_value( self, @@ -553,7 +584,7 @@ def append_default_value( f"expected a SelectDefaultValue object, got {value.__class__.__name__}" ) - self._underlying.default_values.append(value) + self.underlying.default_values.append(value) return self def add_option( @@ -592,7 +623,7 @@ def add_option( ValueError The number of options exceeds 25. """ - if self._underlying.type is not ComponentType.string_select: + if self.underlying.type is not ComponentType.string_select: raise Exception("options can only be set on string selects") option = SelectOption( @@ -618,13 +649,13 @@ def append_option(self, option: SelectOption) -> Self: ValueError The number of options exceeds 25. """ - if self._underlying.type is not ComponentType.string_select: + if self.underlying.type is not ComponentType.string_select: raise Exception("options can only be set on string selects") - if len(self._underlying.options) > 25: + if len(self.underlying.options) > 25: raise ValueError("maximum number of options already provided") - self._underlying.options.append(option) + self.underlying.options.append(option) return self @property @@ -636,7 +667,7 @@ def values(self) -> list[ST]: if self._interaction is None or self._interaction.data is None: # The select has not been interacted with yet return [] - select_type = self._underlying.type + select_type = self.underlying.type if select_type is ComponentType.string_select: return self._selected_values # type: ignore # ST is str resolved = [] @@ -710,7 +741,7 @@ def to_component_dict(self) -> SelectMenuPayload: return super().to_component_dict() def refresh_component(self, component: SelectMenu) -> None: - self._underlying = component + self.underlying = component def refresh_state(self, interaction: Interaction | dict) -> None: data: ComponentInteractionData = ( diff --git a/discord/ui/separator.py b/discord/ui/separator.py index 2ddfae8af2..9277211f84 100644 --- a/discord/ui/separator.py +++ b/discord/ui/separator.py @@ -72,30 +72,43 @@ def __init__( ): super().__init__() - self._underlying = SeparatorComponent._raw_construct( - type=ComponentType.separator, + self._underlying = self._generate_underlying( id=id, divider=divider, spacing=spacing, ) + def _generate_underlying( + self, + divider: bool | None = None, + spacing: SeparatorSpacingSize | None = None, + id: int | None = None, + ) -> SeparatorComponent: + super()._generate_underlying(SeparatorComponent) + return SeparatorComponent._raw_construct( + type=ComponentType.separator, + id=id or self.id, + divider=divider if divider is not None else self.divider, + spacing=spacing or self.spacing, + ) + @property def divider(self) -> bool: """Whether the separator is a divider. Defaults to ``True``.""" - return self._underlying.divider + return self.underlying.divider @divider.setter def divider(self, value: bool) -> None: - self._underlying.divider = value + self.underlying.divider = value @property def spacing(self) -> SeparatorSpacingSize: """The spacing size of the separator. Defaults to :attr:`~discord.SeparatorSpacingSize.small`.""" - return self._underlying.spacing + return self.underlying.spacing @spacing.setter def spacing(self, value: SeparatorSpacingSize) -> None: - self._underlying.spacing = value + self.underlying.spacing = value def to_component_dict(self) -> SeparatorComponentPayload: return super().to_component_dict() diff --git a/discord/ui/text_display.py b/discord/ui/text_display.py index 76ab9dbc50..be5cc83324 100644 --- a/discord/ui/text_display.py +++ b/discord/ui/text_display.py @@ -70,20 +70,31 @@ def __init__( ): super().__init__() - self._underlying = TextDisplayComponent._raw_construct( - type=ComponentType.text_display, + self._underlying = self._generate_underlying( id=id, content=content, ) + def _generate_underlying( + self, + content: str | None = None, + id: int | None = None, + ) -> TextDisplayComponent: + super()._generate_underlying(TextDisplayComponent) + return TextDisplayComponent._raw_construct( + type=ComponentType.text_display, + id=id or self.id, + content=content or self.content, + ) + @property def content(self) -> str: """The text display's content.""" - return self._underlying.content + return self.underlying.content @content.setter def content(self, value: str) -> None: - self._underlying.content = value + self.underlying.content = value def to_component_dict(self) -> TextDisplayComponentPayload: return super().to_component_dict() diff --git a/discord/ui/thumbnail.py b/discord/ui/thumbnail.py index 61c44bd2ce..648efef735 100644 --- a/discord/ui/thumbnail.py +++ b/discord/ui/thumbnail.py @@ -78,41 +78,65 @@ def __init__( media = UnfurledMediaItem(url) - self._underlying = ThumbnailComponent._raw_construct( - type=ComponentType.thumbnail, + self._underlying = self._generate_underlying( id=id, media=media, description=description, spoiler=spoiler, ) + def _generate_underlying( + self, + media: UnfurledMediaItem | None = None, + description: str | None = None, + spoiler: bool | None = False, + id: int | None = None, + ) -> ThumbnailComponent: + super()._generate_underlying(ThumbnailComponent) + return ThumbnailComponent._raw_construct( + type=ComponentType.thumbnail, + id=id or self.id, + media=media or self.media, + description=description or self.description, + spoiler=spoiler if spoiler is not None else self.spoiler, + ) + + @property + def media(self) -> UnfurledMediaItem: + """The thumbnail's unerlying media item.""" + return self.underlying.media + + @media.setter + def url(self, value: UnfurledMediaItem) -> None: + self.underlying.media = value + @property def url(self) -> str: """The URL of this thumbnail's media. This can either be an arbitrary URL or an ``attachment://`` URL.""" - return self._underlying.media and self._underlying.media.url + return self.underlying.media and self.underlying.media.url @url.setter def url(self, value: str) -> None: - self._underlying.media.url = value + self.underlying.media.url = value @property def description(self) -> str | None: """The thumbnail's description, up to 1024 characters.""" - return self._underlying.description + return self.underlying.description @description.setter def description(self, description: str | None) -> None: - self._underlying.description = description + self.underlying.description = description @property def spoiler(self) -> bool: """Whether the thumbnail has the spoiler overlay. Defaults to ``False``.""" - return self._underlying.spoiler + return self.underlying.spoiler @spoiler.setter def spoiler(self, spoiler: bool) -> None: - self._underlying.spoiler = spoiler + self.underlying.spoiler = spoiler def to_component_dict(self) -> ThumbnailComponentPayload: return super().to_component_dict() diff --git a/discord/ui/view.py b/discord/ui/view.py index ec4c0cd8a7..b7c6e6214d 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -47,6 +47,8 @@ from ..components import Component from ..components import Container as ContainerComponent from ..components import FileComponent +from ..components import FileUpload as FileUploadComponent +from ..components import InputText as InputTextComponent from ..components import Label as LabelComponent from ..components import MediaGallery as MediaGalleryComponent from ..components import Section as SectionComponent @@ -55,7 +57,7 @@ from ..components import TextDisplay as TextDisplayComponent from ..components import Thumbnail as ThumbnailComponent from ..components import _component_factory -from ..enums import ChannelType +from ..enums import ChannelType, SeparatorSpacingSize from ..utils import find from .core import ItemInterface from .item import ItemCallbackType, ViewItem @@ -70,6 +72,7 @@ if TYPE_CHECKING: + from ..components import MediaGalleryItem from ..interactions import Interaction, InteractionMessage from ..message import Message from ..state import ConnectionState @@ -142,6 +145,14 @@ def _component_to_item(component: Component) -> ViewItem[V]: from .label import Label return Label.from_component(component) + if isinstance(component, InputTextComponent): + from .input_text import InputText + + return InputText.from_component(component) + if isinstance(component, FileUploadComponent): + from .file_upload import FileUpload + + return FileUpload.from_component(component) return ViewItem.from_component(component) @@ -362,7 +373,7 @@ def is_components_v2(self) -> bool: A view containing V2 components cannot be sent alongside message content or embeds. """ - return any([item._underlying.is_v2() for item in self.children]) + return any([item.underlying.is_v2() for item in self.children]) async def _scheduled_task(self, item: ViewItem[V], interaction: Interaction): try: @@ -690,11 +701,11 @@ def add_item(self, item: ViewItem[V]) -> Self: or the row the item is trying to be added to is full. """ - if item._underlying.is_v2(): + if item.underlying.is_v2(): raise ValueError( f"cannot use V2 components in View. Use DesignerView instead." ) - if isinstance(item._underlying, ActionRowComponent): + if isinstance(item.underlying, ActionRowComponent): for i in item.children: self.add_item(i) return self @@ -703,7 +714,7 @@ def add_item(self, item: ViewItem[V]) -> Self: self.__weights.add_item(item) return self - def remove_item(self, item: ViewItem[V] | int | str) -> None: + def remove_item(self, item: ViewItem[V] | int | str) -> Self: """Removes an item from the view. If an :class:`int` or :class:`str` is passed, the item will be removed by Item ``id`` or ``custom_id`` respectively. @@ -720,7 +731,7 @@ def remove_item(self, item: ViewItem[V] | int | str) -> None: pass return self - def clear_items(self) -> None: + def clear_items(self) -> Self: """Removes all items from the view.""" super().clear_items() self.__weights.clear() @@ -813,6 +824,14 @@ def __init__( *items, timeout=timeout, disable_on_timeout=disable_on_timeout, store=store ) + @property + def items(self) -> list[ViewItem[V]]: + return self.children + + @items.setter + def items(self, value: list[ViewItem[V]]) -> None: + self.children = value + @classmethod def from_message( cls, message: Message, /, *, timeout: float | None = 180.0 @@ -871,13 +890,37 @@ def from_dict( view.add_item(_component_to_item(component)) return view - def add_item(self, item: ViewItem[V]) -> Self: + def add_item( + self, + item: ViewItem[V], + *, + index: int | None = None, + before: ViewItem[V] | str | int | None = None, + after: ViewItem[V] | str | int | None = None, + into: ViewItem[V] | str | int | None = None, + ) -> Self: """Adds an item to the view. + .. warning:: + + You may specify only **one** of ``index``, ``before``, & ``after``. ``into`` will work together with those parameters. + + .. versionchanged:: 2.7.1 + Added new parameters ``index``, ``before``, ``after``, & ``into``. + Parameters ---------- item: :class:`ViewItem` The item to add to the view. + index: Optional[class:`int`] + Add the new item at the specific index of :attr:`children`. Same behavior as Python's :func:`~list.insert`. + before: Optional[Union[:class:`ViewItem`, :class:`int`, :class:`str`]] + Add the new item **before** the specified item. If an :class:`int` is provided, the item will be detected by ``id``, otherwise by ``custom_id``. + after: Optional[Union[:class:`ViewItem`, :class:`int`, :class:`str`]] + Add the new item **after** the specified item. If an :class:`int` is provided, the item will be detected by ``id``, otherwise by ``custom_id``. + into: Optional[Union[:class:`ViewItem`, :class:`int`, :class:`str`]] + Add the new item **into** the specified item. This would be equivalent to `into.add_item(item)`, where `into` is a :class:`ViewItem`. + If an :class:`int` is provided, the item will be detected by ``id``, otherwise by ``custom_id``. Raises ------ @@ -886,8 +929,17 @@ def add_item(self, item: ViewItem[V]) -> Self: ValueError Maximum number of items has been exceeded (40) """ - - if isinstance(item._underlying, (SelectComponent, ButtonComponent)): + if ( + before + and after + or before + and (index is not None) + or after + and (index is not None) + ): + raise ValueError("Can only specify one of before, after, and index.") + + if isinstance(item.underlying, (SelectComponent, ButtonComponent)): raise ValueError( f"cannot add Select or Button to DesignerView directly. Use ActionRow instead." ) @@ -895,6 +947,198 @@ def add_item(self, item: ViewItem[V]) -> Self: super().add_item(item) return self + def replace_item( + self, original_item: ViewItem[V] | str | int, new_item: ViewItem[V] + ) -> Self: + """Directly replace an item in this view. + If an :class:`int` is provided, the item will be replaced by ``id``, otherwise by ``custom_id``. + + Parameters + ---------- + original_item: Union[:class:`ViewItem`, :class:`int`, :class:`str`] + The item, item ``id``, or item ``custom_id`` to replace in the view. + new_item: :class:`ViewItem` + The new item to insert into the view. + + Returns + ------- + :class:`BaseView` + The view instance. + """ + + if not isinstance(new_item, ViewItem): + raise TypeError(f"expected ViewItem not {new_item.__class__!r}") + + if isinstance(original_item, (str, int)): + original_item = self.get_item(original_item) + if not original_item: + raise ValueError(f"Could not find original_item in view.") + try: + if original_item.parent is self: + i = self.children.index(original_item) + new_item.parent = self + self.children[i] = new_item + original_item.parent = None + else: + original_item.parent.replace_item(original_item, new_item) + except ValueError: + raise ValueError(f"Could not find original_item in view.") + return self + + def add_row( + self, + *items: ViewItem[V], + id: int | None = None, + ) -> Self: + """Adds an :class:`ActionRow` to the view. + + To append a pre-existing :class:`ActionRow`, use :meth:`add_item` instead. + + Parameters + ---------- + *items: Union[:class:`Button`, :class:`Select`] + The items this action row contains. + id: Optiona[:class:`int`] + The action row's ID. + """ + from .action_row import ActionRow + + row = ActionRow(*items, id=id) + + return self.add_item(row) + + def add_container( + self, + *items: ViewItem[V], + id: int | None = None, + ) -> Self: + """Adds a :class:`Container` to the view. + + To append a pre-existing :class:`Container`, use the + :meth:`add_item` method, instead. + + Parameters + ---------- + *items: :class:`ViewItem` + The items contained in this container. + accessory: Optional[:class:`ViewItem`] + id: Optional[:class:`int`] + The container's ID. + """ + from .container import Container + + container = Container(*items, id=id) + + return self.add_item(container) + + def add_section( + self, + *items: ViewItem[V], + accessory: ViewItem[V], + id: int | None = None, + ) -> Self: + """Adds a :class:`Section` to the view. + + To append a pre-existing :class:`Section`, use the + :meth:`add_item` method, instead. + + Parameters + ---------- + *items: :class:`ViewItem` + The items contained in this section, up to 3. + Currently only supports :class:`~discord.ui.TextDisplay`. + accessory: Optional[:class:`ViewItem`] + The section's accessory. This is displayed in the top right of the section. + Currently only supports :class:`~discord.ui.Button` and :class:`~discord.ui.Thumbnail`. + id: Optional[:class:`int`] + The section's ID. + """ + from .section import Section + + section = Section(*items, accessory=accessory, id=id) + + return self.add_item(section) + + def add_text(self, content: str, id: int | None = None) -> Self: + """Adds a :class:`TextDisplay` to the view. + + Parameters + ---------- + content: :class:`str` + The content of the TextDisplay + id: Optiona[:class:`int`] + The text displays' ID. + """ + from .text_display import TextDisplay + + text = TextDisplay(content, id=id) + + return self.add_item(text) + + def add_gallery( + self, + *items: MediaGalleryItem, + id: int | None = None, + ) -> Self: + """Adds a :class:`MediaGallery` to the view. + + To append a pre-existing :class:`MediaGallery`, use :meth:`add_item` instead. + + Parameters + ---------- + *items: :class:`MediaGalleryItem` + The media this gallery contains. + id: Optiona[:class:`int`] + The gallery's ID. + """ + from .media_gallery import MediaGallery + + g = MediaGallery(*items, id=id) + + return self.add_item(g) + + def add_file(self, url: str, spoiler: bool = False, id: int | None = None) -> Self: + """Adds a :class:`TextDisplay` to the view. + + Parameters + ---------- + url: :class:`str` + The URL of this file's media. This must be an ``attachment://`` URL that references a :class:`~discord.File`. + spoiler: Optional[:class:`bool`] + Whether the file has the spoiler overlay. Defaults to ``False``. + id: Optiona[:class:`int`] + The file's ID. + """ + from .file import File + + f = File(url, spoiler=spoiler, id=id) + + return self.add_item(f) + + def add_separator( + self, + *, + divider: bool = True, + spacing: SeparatorSpacingSize = SeparatorSpacingSize.small, + id: int | None = None, + ) -> Self: + """Adds a :class:`Separator` to the container. + + Parameters + ---------- + divider: :class:`bool` + Whether the separator is a divider. Defaults to ``True``. + spacing: :class:`~discord.SeparatorSpacingSize` + The spacing size of the separator. Defaults to :attr:`~discord.SeparatorSpacingSize.small`. + id: Optional[:class:`int`] + The separator's ID. + """ + from .separator import Separator + + s = Separator(divider=divider, spacing=spacing, id=id) + + return self.add_item(s) + def refresh(self, components: list[Component]): # Refreshes view data using discord's values # Assumes the components and items are identical