diff --git a/.github/workflows/grace.yml b/.github/workflows/grace.yml index 6b415d29..692e4728 100644 --- a/.github/workflows/grace.yml +++ b/.github/workflows/grace.yml @@ -22,13 +22,13 @@ jobs: - name: Set up Python 3.10 uses: actions/setup-python@v3 with: - python-version: "3.10" + python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip pip install flake8 pytest - pip install . + pip install .[dev] - name: Lint with flake8 run: | diff --git a/bot/__init__.py b/bot/__init__.py index 9e603bc7..e0276873 100755 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,12 +1,12 @@ -from grace.application import Application from apscheduler.schedulers.asyncio import AsyncIOScheduler +from grace.application import Application try: - from nltk.downloader import Downloader from nltk import download, download_shell + from nltk.downloader import Downloader - download('vader_lexicon', quiet=True) + download("vader_lexicon", quiet=True) except ModuleNotFoundError: print("nltk module not properly installed") @@ -17,9 +17,9 @@ def _create_bot(app): Import is deferred to avoid circular dependency. """ from bot.grace import Grace + return Grace(app) app = Application() bot = _create_bot(app) - diff --git a/bot/classes/recurrence.py b/bot/classes/recurrence.py index 065cd3db..6f39adda 100644 --- a/bot/classes/recurrence.py +++ b/bot/classes/recurrence.py @@ -10,4 +10,3 @@ class Recurrence(Enum): def __str__(self): return self.name.capitalize() - diff --git a/bot/classes/state.py b/bot/classes/state.py index a5d4616a..1a3a2a64 100644 --- a/bot/classes/state.py +++ b/bot/classes/state.py @@ -1,11 +1,10 @@ -from enum import Enum, unique +from enum import IntEnum, unique @unique -class State(Enum): +class State(IntEnum): DISABLED = 0 ENABLED = 1 def __str__(self): return self.name.capitalize() - diff --git a/bot/extensions/bookmark_cog.py b/bot/extensions/bookmark_cog.py index ce4e1577..5dc6fe5a 100644 --- a/bot/extensions/bookmark_cog.py +++ b/bot/extensions/bookmark_cog.py @@ -1,56 +1,58 @@ from typing import List -from discord import Embed, Message, Interaction, File + +from discord import Embed, File, Interaction, Message from discord.app_commands import ContextMenu from discord.ext.commands import Cog + from bot.grace import Grace class BookmarkCog(Cog): - def __init__(self, bot: Grace) -> None: - self.bot: Grace = bot - - save_message_ctx_menu: ContextMenu = ContextMenu( - name='Save Message', - callback=self.save_message - ) - - self.bot.tree.add_command(save_message_ctx_menu) - - async def get_message_files(self, message: Message) -> List[File]: - """Fetch files from the message attachments - - :param message: Message to fetch files from - :type message: Message - - :return: List of files - :rtype: List[File] - """ - return list(map(lambda attachment: attachment.to_file(), message.attachments)) - - async def save_message(self, interaction: Interaction, message: Message) -> None: - """Saves the message - - :param interaction: ContextMenu command interaction - :type interaction: Interaction - :param message: Message of the interaction - :type message: Message - """ - sent_at: int = int(message.created_at.timestamp()) - files: List[File] = await self.get_message_files(message) - - save_embed: Embed = Embed( - title='Bookmark Info', - color=self.bot.default_color - ) - - save_embed.add_field(name="Sent By", value=message.author, inline=False) - save_embed.add_field(name="Sent At", value=f'', inline=False) - save_embed.add_field(name="Original Message", value=f'[Jump]({message.jump_url})', inline=False) - - await interaction.user.send(embed=save_embed) - await interaction.user.send(message.content, embeds=message.embeds, files=files) - await interaction.response.send_message("Message successfully saved.", ephemeral=True) + def __init__(self, bot: Grace) -> None: + self.bot: Grace = bot + + save_message_ctx_menu: ContextMenu = ContextMenu( + name="Save Message", callback=self.save_message + ) + + self.bot.tree.add_command(save_message_ctx_menu) + + async def get_message_files(self, message: Message) -> List[File]: + """Fetch files from the message attachments + + :param message: Message to fetch files from + :type message: Message + + :return: List of files + :rtype: List[File] + """ + return list(map(lambda attachment: attachment.to_file(), message.attachments)) + + async def save_message(self, interaction: Interaction, message: Message) -> None: + """Saves the message + + :param interaction: ContextMenu command interaction + :type interaction: Interaction + :param message: Message of the interaction + :type message: Message + """ + sent_at: int = int(message.created_at.timestamp()) + files: List[File] = await self.get_message_files(message) + + save_embed: Embed = Embed(title="Bookmark Info", color=self.bot.default_color) + + save_embed.add_field(name="Sent By", value=message.author, inline=False) + save_embed.add_field(name="Sent At", value=f"", inline=False) + save_embed.add_field( + name="Original Message", value=f"[Jump]({message.jump_url})", inline=False + ) + + await interaction.user.send(embed=save_embed) + await interaction.user.send(message.content, embeds=message.embeds, files=files) + await interaction.response.send_message( + "Message successfully saved.", ephemeral=True + ) async def setup(bot: Grace) -> None: - await bot.add_cog(BookmarkCog(bot)) + await bot.add_cog(BookmarkCog(bot)) diff --git a/bot/extensions/code_generator_cog.py b/bot/extensions/code_generator_cog.py index 5194dd10..564f61dd 100644 --- a/bot/extensions/code_generator_cog.py +++ b/bot/extensions/code_generator_cog.py @@ -1,15 +1,37 @@ -from discord.ext.commands import Cog, hybrid_command -from discord.app_commands import Choice, autocomplete +import openai from discord import Embed, Interaction +from discord.app_commands import Choice, autocomplete +from discord.ext.commands import Cog, hybrid_command from openai.api_resources.completion import Completion -import openai -from lib.config_required import cog_config_required +from lib.config_required import cog_config_required LANGUAGES = [ - "Python", "C", "C++", "Java", "Csharp", "R", "Ruby", "JavaScript", "Swift", - "Go", "Kotlin", "Rust", "PHP", "ObjectiveC", "SQL", "Lisp", "Perl", - "Haskell", "Erlang", "Scala", "Clojure", "Julia", "Elixir", "F#", "Bash" + "Python", + "C", + "C++", + "Java", + "Csharp", + "R", + "Ruby", + "JavaScript", + "Swift", + "Go", + "Kotlin", + "Rust", + "PHP", + "ObjectiveC", + "SQL", + "Lisp", + "Perl", + "Haskell", + "Erlang", + "Scala", + "Clojure", + "Julia", + "Elixir", + "F#", + "Bash", ] @@ -25,25 +47,31 @@ async def language_autocomplete(_: Interaction, current: str) -> list[Choice[str """ return [ Choice(name=lang.capitalize(), value=lang.capitalize()) - for lang in LANGUAGES if current.lower() in lang.lower() + for lang in LANGUAGES + if current.lower() in lang.lower() ] -@cog_config_required("openai", "api_key", "Generate yours [here](https://beta.openai.com/account/api-keys)") +@cog_config_required( + "openai", + "api_key", + "Generate yours [here](https://beta.openai.com/account/api-keys)", +) class CodeGenerator( Cog, name="OpenAI", - description="Generate code using OpenAI API by providing a comment and language." + description="Generate code using OpenAI API by providing a comment and language.", ): """A Cog that generate code using text.""" + def __init__(self, bot): self.bot = bot self.api_key = self.required_config @hybrid_command( - name='code', - help='Generate code by providing a comment and language.', - usage="language={programming_language} comment={sentence}" + name="code", + help="Generate code by providing a comment and language.", + usage="language={programming_language} comment={sentence}", ) @autocomplete(language=language_autocomplete) async def code_generator(self, ctx, *, language: str, comment: str) -> None: @@ -77,7 +105,7 @@ async def code_generator(self, ctx, *, language: str, comment: str) -> None: embed.add_field( name=comment.capitalize(), value=f"```{language}{code_generated}``` {ctx.author} | {language}", - inline=False + inline=False, ) await ctx.send(embed=embed) diff --git a/bot/extensions/color_cog.py b/bot/extensions/color_cog.py index 2e965a20..5ced3e0e 100644 --- a/bot/extensions/color_cog.py +++ b/bot/extensions/color_cog.py @@ -1,16 +1,24 @@ import os +from typing import Tuple, Union + +from discord import Color, Embed, File +from discord.ext.commands import ( + Cog, + CommandInvokeError, + Context, + HybridCommandError, + hybrid_group, +) from PIL import Image -from discord.ext.commands import Cog, hybrid_group, HybridCommandError, CommandInvokeError, Context -from discord import Embed, File, Color + from bot.helpers.error_helper import send_command_error -from typing import Union, Tuple def get_embed_color(color: Union[Tuple[int, int, int], str]) -> Color: """Convert a color to an Embed Color object. - - :param color: A tuple of 3 integers in the range 0-255 representing an RGB - color, or a string in the format '#RRGGBB' representing a + + :param color: A tuple of 3 integers in the range 0-255 representing an RGB + color, or a string in the format '#RRGGBB' representing a hexadecimal color. :type color: Union[Tuple[int, int, int], str] :return: An Embed Color object representing the input color. @@ -21,16 +29,19 @@ def get_embed_color(color: Union[Tuple[int, int, int], str]) -> Color: return Color.from_str(color) -class ColorCog(Cog, name="Color", description="Collection of commands to bring color in your life."): +class ColorCog( + Cog, name="Color", description="Collection of commands to bring color in your life." +): """A Discord Cog that provides a set of commands to display colors.""" + def __init__(self, bot): self.bot = bot @hybrid_group(name="color", help="Commands to bring color in your life") async def color_group(self, ctx: Context) -> None: - """Group command for the color commands. If called without a subcommand, + """Group command for the color commands. If called without a subcommand, it sends the help message. - + :param ctx: The context of the command invocation. :type ctx: Context """ @@ -39,47 +50,49 @@ async def color_group(self, ctx: Context) -> None: @color_group.group(name="show", help="Commands to display colors.") async def show_group(self, ctx: Context) -> None: - """Group command for the show subcommands. If called without a subcommand, + """Group command for the show subcommands. If called without a subcommand, it sends the help message. - + :param ctx: The context of the command invocation. :type ctx: Context """ if ctx.invoked_subcommand is None: await ctx.send_help(ctx.command) - async def display_color(self, ctx: Context, color: Union[Tuple[int, int, int], str]) -> None: + async def display_color( + self, ctx: Context, color: Union[Tuple[int, int, int], str] + ) -> None: """Display a color in an embed message. - + :param ctx: The context of the command invocation. :type ctx: Context - :param color: A tuple of 3 integers in the range 0-255 representing an - RGB color, or a string in the format '#RRGGBB' representing + :param color: A tuple of 3 integers in the range 0-255 representing an + RGB color, or a string in the format '#RRGGBB' representing a hexadecimal color. :type color: Union[Tuple[int, int, int], str] """ - colored_image = Image.new('RGB', (200, 200), color) - colored_image.save('color.png') - file = File('color.png') + colored_image = Image.new("RGB", (200, 200), color) + colored_image.save("color.png") + file = File("color.png") embed = Embed( color=get_embed_color(color), - title='Here goes your color!', - description=f"{color}" + title="Here goes your color!", + description=f"{color}", ) embed.set_image(url="attachment://color.png") await ctx.send(embed=embed, file=file) - os.remove('color.png') + os.remove("color.png") @show_group.command( - name='rgb', + name="rgb", help="Displays the RGB color entered by the user.", - usage="color show rgb {red integer} {green integer} {blue integer}" + usage="color show rgb {red integer} {green integer} {blue integer}", ) async def rgb_command(self, ctx: Context, r: int, g: int, b: int) -> None: """Display an RGB color in an embed message. - + :param ctx: The context of the command invocation. :type ctx: Context :param r: The red component of the color (0-255). @@ -93,46 +106,54 @@ async def rgb_command(self, ctx: Context, r: int, g: int, b: int) -> None: @rgb_command.error async def rgb_command_error(self, ctx: Context, error: Exception) -> None: - """Event listener for errors that occurred during the execution of the + """Event listener for errors that occurred during the execution of the 'rgb' command. It sends an error message to the user. - + :param ctx: The context of the command invocation. :type ctx: Context :param error: The error that was raised during command execution. :type error: Exception """ - if isinstance(error, HybridCommandError) or isinstance(error, CommandInvokeError): - await send_command_error(ctx, "Expected rgb color", ctx.command, "244 195 8") + if isinstance(error, HybridCommandError) or isinstance( + error, CommandInvokeError + ): + await send_command_error( + ctx, "Expected rgb color", ctx.command, "244 195 8" + ) @show_group.command( - name='hex', + name="hex", help="Displays the color of the hexcode entered by the user.", - usage="color show hex {hexadecimal string}" + usage="color show hex {hexadecimal string}", ) async def hex_command(self, ctx: Context, hex_code: str) -> None: """Display a color in an embed message using a hexadecimal color code. - + :param ctx: The context of the command invocation. :type ctx: Context :param hex_code: A string in the format '#RRGGBB' representing a hexadecimal color. :type hex_code: str """ - if not hex_code.startswith('#'): - hex_code = f'#{hex_code}' + if not hex_code.startswith("#"): + hex_code = f"#{hex_code}" await self.display_color(ctx, hex_code) @hex_command.error async def hex_command_error(self, ctx: Context, error: Exception) -> None: - """Event listener for errors that occurred during the execution of the + """Event listener for errors that occurred during the execution of the 'hex' command. It sends an error message to the user. - + :param ctx: The context of the command invocation. :type ctx: Context :param error: The error that was raised during command execution. :type error: Exception """ - if isinstance(error, HybridCommandError) or isinstance(error, CommandInvokeError): - await send_command_error(ctx, "Expected hexadecimal color", ctx.command, "#F4C308") + if isinstance(error, HybridCommandError) or isinstance( + error, CommandInvokeError + ): + await send_command_error( + ctx, "Expected hexadecimal color", ctx.command, "#F4C308" + ) async def setup(bot): diff --git a/bot/extensions/command_error_handler.py b/bot/extensions/command_error_handler.py index e13e654d..5eed9339 100644 --- a/bot/extensions/command_error_handler.py +++ b/bot/extensions/command_error_handler.py @@ -1,26 +1,33 @@ from datetime import timedelta from logging import warning -from discord.ext.commands import Cog, \ - MissingRequiredArgument, \ - CommandNotFound, \ - MissingPermissions, \ - CommandOnCooldown, \ - DisabledCommand, HybridCommandError, Context -from bot.helpers.error_helper import send_error from typing import Any, Coroutine, Optional + from discord import Interaction +from discord.ext.commands import ( + Cog, + CommandNotFound, + CommandOnCooldown, + Context, + DisabledCommand, + HybridCommandError, + MissingPermissions, + MissingRequiredArgument, +) + +from bot.helpers.error_helper import send_error from lib.config_required import MissingRequiredConfigError class CommandErrorHandler(Cog): """A Discord Cog that listens for command errors and sends an appropriate message to the user.""" + def __init__(self, bot): self.bot = bot @Cog.listener("on_command_error") async def get_command_error(self, ctx: Context, error: Exception) -> None: """Event listener for command errors. It logs the error and sends an appropriate message to the user. - + :param ctx: The context of the command invocation. :type ctx: Context :param error: The error that was raised during command execution. @@ -33,9 +40,14 @@ async def get_command_error(self, ctx: Context, error: Exception) -> None: elif isinstance(error, MissingRequiredConfigError): await send_error(ctx, error) elif isinstance(error, MissingPermissions): - await send_error(ctx, "You don't have the authorization to use that command.") + await send_error( + ctx, "You don't have the authorization to use that command." + ) elif isinstance(error, CommandOnCooldown): - await send_error(ctx, f"You're on Cooldown, wait {timedelta(seconds=int(error.retry_after))}") + await send_error( + ctx, + f"You're on Cooldown, wait {timedelta(seconds=int(error.retry_after))}", + ) elif isinstance(error, DisabledCommand): await send_error(ctx, "This command is disabled.") elif isinstance(error, MissingRequiredArgument): @@ -44,23 +56,27 @@ async def get_command_error(self, ctx: Context, error: Exception) -> None: await self.get_app_command_error(ctx.interaction, error) @Cog.listener("on_app_command_error") - async def get_app_command_error(self, interaction: Optional[Interaction], _: Exception) -> None: - """Event listener for command errors that occurred during an interaction. + async def get_app_command_error( + self, interaction: Optional[Interaction], _: Exception + ) -> None: + """Event listener for command errors that occurred during an interaction. It sends an error message to the user. - + :param interaction: The interaction where the error occurred. :type interaction: Interaction :param _ : The error that was raised during command execution. :type _: Exception """ if interaction and interaction.is_expired(): - await interaction.response.send_message("Interaction failed, please try again later!", ephemeral=True) + await interaction.response.send_message( + "Interaction failed, please try again later!", ephemeral=True + ) def send_command_help(ctx: Context) -> Coroutine[Any, Any, Any]: - """Send the help message for the command that raised an error, or + """Send the help message for the command that raised an error, or the general help message if no specific command was involved. - + :param ctx: The context of the command invocation. :type ctx: The context :return: The help message. @@ -72,4 +88,4 @@ def send_command_help(ctx: Context) -> Coroutine[Any, Any, Any]: async def setup(bot): - await bot.add_cog(CommandErrorHandler(bot)) \ No newline at end of file + await bot.add_cog(CommandErrorHandler(bot)) diff --git a/bot/extensions/extension_cog.py b/bot/extensions/extension_cog.py index a92485f6..ecead7db 100644 --- a/bot/extensions/extension_cog.py +++ b/bot/extensions/extension_cog.py @@ -1,11 +1,20 @@ +from typing import List + from discord import Embed from discord.app_commands import Choice, autocomplete -from discord.ext.commands import Cog, has_permissions, ExtensionAlreadyLoaded, ExtensionNotLoaded, hybrid_group, Context +from discord.ext.commands import ( + Cog, + Context, + ExtensionAlreadyLoaded, + ExtensionNotLoaded, + has_permissions, + hybrid_group, +) from emoji import emojize + from bot.classes.state import State from bot.extensions.command_error_handler import send_command_help from bot.models.extension import Extension -from typing import List def extension_autocomplete(state: bool): @@ -15,6 +24,7 @@ def extension_autocomplete(state: bool): :type state: bool :return: An autocomplete function. """ + async def inner_autocomplete(_, current: str) -> List[Choice]: """Autocomplete function for extensions. @@ -23,6 +33,7 @@ async def inner_autocomplete(_, current: str) -> List[Choice]: :return: A list of `Choice` objects for autocompleting the extension names. :rtype: List[Choice] """ + def create_choice(extension: Extension) -> Choice: """Creates a `Choice` object for the provided `extension`. @@ -31,18 +42,36 @@ def create_choice(extension: Extension) -> Choice: :return: A `Choice` object for the provided `extension`. :rtype: Choice """ - state_emoji = emojize(':green_circle:') if extension.is_enabled() else emojize(':red_circle:') - return Choice(name=f"{state_emoji} {extension.name}", value=extension.module_name) - return list(map(create_choice, Extension.by_state(state).filter(Extension.module_name.ilike(f"%{current}%")))) + state_emoji = ( + emojize(":green_circle:") + if extension.is_enabled() + else emojize(":red_circle:") + ) + return Choice( + name=f"{state_emoji} {extension.name}", value=extension.module_name + ) + + return list( + map( + create_choice, + Extension.by_state(state) + .where(Extension.module_name.ilike(f"%{current}%")) + .all(), + ) + ) + return inner_autocomplete class ExtensionCog(Cog, name="Extensions", description="Extensions managing cog"): """A `Cog` for managing extensions.""" + def __init__(self, bot): self.bot = bot - @hybrid_group(name="extension", aliases=["ext", "e"], help="Commands to manage extensions") + @hybrid_group( + name="extension", aliases=["ext", "e"], help="Commands to manage extensions" + ) @has_permissions(administrator=True) async def extension_group(self, ctx: Context) -> None: """The command group for managing extensions. @@ -53,28 +82,31 @@ async def extension_group(self, ctx: Context) -> None: if ctx.invoked_subcommand is None: await send_command_help(ctx) - @extension_group.command(name="list", aliases=["l"], help="Display the list of extensions") + @extension_group.command( + name="list", aliases=["l"], help="Display the list of extensions" + ) @has_permissions(administrator=True) async def list_extensions_command(self, ctx: Context) -> None: """Display the list of extensions in an embed message, indicating their current state (enabled or disabled). - + :param ctx: The context in which the command was called. :type ctx: Context. """ extensions = Extension.all() - embed = Embed( - color=self.bot.default_color, - title="Extensions" - ) + embed = Embed(color=self.bot.default_color, title="Extensions") for extension in extensions: - state_emoji = emojize(':green_circle:') if extension.is_enabled() else emojize(':red_circle:') + state_emoji = ( + emojize(":green_circle:") + if extension.is_enabled() + else emojize(":red_circle:") + ) embed.add_field( name=f"{state_emoji} {extension.name}", value=f"**State**: {extension.state}", - inline=False + inline=False, ) if not extensions: @@ -82,18 +114,23 @@ async def list_extensions_command(self, ctx: Context) -> None: await ctx.send(embed=embed, ephemeral=True) - @extension_group.command(name="enable", aliases=["e"], help="Enable a given extension", usage="{extension_id}") + @extension_group.command( + name="enable", + aliases=["e"], + help="Enable a given extension", + usage="{extension_id}", + ) @has_permissions(administrator=True) @autocomplete(extension_name=extension_autocomplete(State.DISABLED)) async def enable_extension_command(self, ctx: Context, extension_name: str) -> None: """Enable a given extension by its module name. - + :param ctx: The context in which the command was called. :type ctx: Context. :param extension_name: The module name of the extension to enable. :type extension_name: str """ - extension = Extension.get_by(module_name=extension_name) + extension = Extension.find_by(module_name=extension_name) if extension: try: @@ -107,18 +144,25 @@ async def enable_extension_command(self, ctx: Context, extension_name: str) -> N else: await ctx.send(f"Extension **{extension_name}** not found", ephemeral=True) - @extension_group.command(name="disable", aliases=["d"], help="Disable a given extension", usage="{extension_id}") + @extension_group.command( + name="disable", + aliases=["d"], + help="Disable a given extension", + usage="{extension_id}", + ) @has_permissions(administrator=True) @autocomplete(extension_name=extension_autocomplete(State.ENABLED)) - async def disable_extension_command(self, ctx: Context, extension_name: str) -> None: + async def disable_extension_command( + self, ctx: Context, extension_name: str + ) -> None: """Disable a given extension by its module name. - + :param ctx: The context in which the command was called. :type ctx: Context. :param extension_name: The module name of the extension to disable. :type extension_name: str """ - extension = Extension.get_by(module_name=extension_name) + extension = Extension.find_by(module_name=extension_name) if extension: try: diff --git a/bot/extensions/fun_cog.py b/bot/extensions/fun_cog.py index 78bfb8d0..46e966c2 100644 --- a/bot/extensions/fun_cog.py +++ b/bot/extensions/fun_cog.py @@ -1,38 +1,46 @@ from json import loads +from random import choice as random_choice + +from discord import Embed +from discord.ext.commands import Cog, Context, cooldown, hybrid_group from discord.ext.commands.cooldowns import BucketType -from discord.ext.commands import Cog, cooldown, hybrid_group, Context -from discord import Embed, Colour from requests import get -from random import choice as random_choice + from bot.extensions.command_error_handler import send_command_help from bot.models.extensions.fun.answer import Answer class FunCog(Cog, name="Fun", description="Collection of fun commands"): """A cog containing fun commands.""" + def __init__(self, bot): self.bot = bot self.goosed_gif_links = [ - 'https://media.tenor.com/XG_ZOTYukysAAAAC/goose.gif', - 'https://media.tenor.com/pSnSQRfiIP8AAAAd/birds-kid.gif', - 'https://media.tenor.com/GDkgAup55_0AAAAC/duck-bite.gif' + "https://media.tenor.com/XG_ZOTYukysAAAAC/goose.gif", + "https://media.tenor.com/pSnSQRfiIP8AAAAd/birds-kid.gif", + "https://media.tenor.com/GDkgAup55_0AAAAC/duck-bite.gif", ] @hybrid_group(name="fun", help="Fun commands") async def fun_group(self, ctx: Context) -> None: """Group of fun commands. - + :param ctx: The context in which the command was called. :type ctx: Context """ if ctx.invoked_subcommand is None: await send_command_help(ctx) - @fun_group.command(name='eightball', aliases=['8ball'], help="Ask a question and be answered.", usage="{question}") + @fun_group.command( + name="eightball", + aliases=["8ball"], + help="Ask a question and be answered.", + usage="{question}", + ) @cooldown(4, 30, BucketType.user) async def eightball_command(self, ctx: Context, question: str) -> None: """Ask a question and get an answer. - + :param ctx: The context in which the command was called. :type ctx: Context :param question: The question asked by the user. @@ -44,38 +52,40 @@ async def eightball_command(self, ctx: Context, question: str) -> None: answer = "You need to ask me a question!" answer_embed = Embed( - title=f'{ctx.author.name}, Grace says: ', + title=f"{ctx.author.name}, Grace says: ", color=self.bot.default_color, description=answer.answer, ) await ctx.send(embed=answer_embed) - @fun_group.command(name='goosed', help='Go goose yourself') + @fun_group.command(name="goosed", help="Go goose yourself") async def goose_command(self, ctx: Context) -> None: """Send a Goose image. - + :param ctx: The context in which the command was called. :type ctx: Context """ goosed_embed = Embed( color=self.bot.default_color, - title='**GET GOOSED**', + title="**GET GOOSED**", ) goosed_embed.set_image(url=random_choice(self.goosed_gif_links)) await ctx.send(embed=goosed_embed) - @fun_group.command(name='quote', help='Sends an inspirational quote') + @fun_group.command(name="quote", help="Sends an inspirational quote") async def quote_command(self, ctx: Context) -> None: """Generate a random inspirational quote. - + :param ctx: The context in which the command was called. :type ctx: Context """ - response = get('https://api.forismatic.com/api/1.0/?method=getQuote&format=json&lang=en') + response = get( + "https://api.forismatic.com/api/1.0/?method=getQuote&format=json&lang=en" + ) if response.ok: - quote = '{quoteText} \n-- {quoteAuthor}'.format(**loads(response.text)) + quote = "{quoteText} \n-- {quoteAuthor}".format(**loads(response.text)) embed = Embed( color=self.bot.default_color, diff --git a/bot/extensions/grace_cog.py b/bot/extensions/grace_cog.py index ac40367b..24b08e85 100755 --- a/bot/extensions/grace_cog.py +++ b/bot/extensions/grace_cog.py @@ -1,13 +1,18 @@ -from discord.ext.commands import Cog, hybrid_command, Context -from discord.ui import Button -from discord.app_commands import Choice, autocomplete from discord import Embed, Interaction +from discord.app_commands import Choice, autocomplete +from discord.ext.commands import Cog, Context, hybrid_command +from discord.ui import Button from emoji import emojize -from lib.config_required import command_config_required -from lib.paged_embeds import PagedEmbedView + from bot.helpers import send_error -from bot.helpers.github_helper import create_contributors_embeds, create_repository_button, available_project_names +from bot.helpers.github_helper import ( + available_project_names, + create_contributors_embeds, + create_repository_button, +) from bot.services.github_service import GithubService +from lib.config_required import command_config_required +from lib.paged_embeds import PagedEmbedView async def project_autocomplete(_: Interaction, current: str) -> list[Choice[str]]: @@ -22,25 +27,27 @@ async def project_autocomplete(_: Interaction, current: str) -> list[Choice[str] """ return [ Choice(name=project, value=project) - for project in available_project_names() if current.lower() in project.lower() + for project in available_project_names() + if current.lower() in project.lower() ] class GraceCog(Cog, name="Grace", description="Default grace commands"): """A cog that contains default commands for the Grace bot.""" + __CODE_SOCIETY_WEBSITE_BUTTON = Button( emoji=emojize(":globe_with_meridians:"), label="Website", - url="https://codesociety.xyz" + url="https://codesociety.xyz", ) def __init__(self, bot): self.bot = bot - @hybrid_command(name='info', help='Show information about the bot') + @hybrid_command(name="info", help="Show information about the bot") async def info_command(self, ctx: Context, ephemeral=True) -> None: """Show information about the bot. - + :param ctx: The context in which the command was called. :type ctx: Context :param ephemeral: A flag indicating whether the message should be sent as an ephemeral message. Default is True. @@ -58,25 +65,23 @@ async def info_command(self, ctx: Context, ephemeral=True) -> None: info_embed.add_field( name="Fun fact about me", value="I'm named after [Grace Hopper](https://en.wikipedia.org/wiki/Grace_Hopper) {emojize(':rabbit:')}", - inline=False + inline=False, ) info_embed.add_field( name=f"{emojize(':test_tube:')} Code Society Lab", value="Contribute to our [projects](https://github.com/Code-Society-Lab/grace)\n", - inline=True + inline=True, ) info_embed.add_field( name=f"{emojize(':crossed_swords:')} Codewars", value="Set your clan to **CodeSoc**\n", - inline=True + inline=True, ) info_embed.add_field( - name="Need help?", - value=f"Send '{ctx.prefix}help'", - inline=False + name="Need help?", value=f"Send '{ctx.prefix}help'", inline=False ) view = PagedEmbedView([info_embed]) @@ -91,10 +96,10 @@ async def info_command(self, ctx: Context, ephemeral=True) -> None: await view.send(ctx, ephemeral=ephemeral) - @hybrid_command(name='ping', help='Shows the bot latency') + @hybrid_command(name="ping", help="Shows the bot latency") async def ping_command(self, ctx: Context) -> None: """Show the bot latency. - + :param ctx: The context in which the command was called. :type ctx: Context """ @@ -105,10 +110,10 @@ async def ping_command(self, ctx: Context) -> None: await ctx.send(embed=embed) - @hybrid_command(name='hopper', help='The legend of Grace Hopper') + @hybrid_command(name="hopper", help="The legend of Grace Hopper") async def hopper_command(self, ctx: Context) -> None: """Show a link to a comic about Grace Hopper. - + :param ctx: The context in which the command was called. :type ctx: Context :return: None @@ -116,11 +121,14 @@ async def hopper_command(self, ctx: Context) -> None: await ctx.send("https://www.smbc-comics.com/?id=2516") @command_config_required("github", "api_key") - @hybrid_command(name="contributors", description="Show a list of Code Society Lab's contributors") + @hybrid_command( + name="contributors", + description="Show a list of Code Society Lab's contributors", + ) @autocomplete(project=project_autocomplete) async def contributors(self, ctx: Context, project: str) -> None: """Show a list of contributors for the Code Society Lab repositories. - + :param ctx: The context in which the command was called. :type ctx: Context :param project: The project's name to get contributors. diff --git a/bot/extensions/language_cog.py b/bot/extensions/language_cog.py index f281fa2d..ab588e5b 100755 --- a/bot/extensions/language_cog.py +++ b/bot/extensions/language_cog.py @@ -1,8 +1,10 @@ -from discord.ext.commands import Cog, has_permissions, hybrid_group, Context -from discord import Message, Embed from logging import warning -from nltk.tokenize import TweetTokenizer + +from discord import Embed, Message +from discord.ext.commands import Cog, Context, has_permissions, hybrid_group from nltk.sentiment.vader import SentimentIntensityAnalyzer +from nltk.tokenize import TweetTokenizer + from bot.models.extensions.language.trigger import Trigger @@ -35,23 +37,25 @@ def get_message_sentiment_polarity(self, message: Message) -> int: # negatively about something. We run the while message through vader and if the aggregated # score is ultimately negative, neutral, or positive sv = self.sid.polarity_scores(message.content) - if sv['neu'] + sv['pos'] < sv['neg'] or sv['pos'] == 0.0: - if sv['neg'] > sv['pos']: + if sv["neu"] + sv["pos"] < sv["neg"] or sv["pos"] == 0.0: + if sv["neg"] > sv["pos"]: return -1 return 0 - return 1; + return 1 async def name_react(self, message: Message) -> None: """ Checks message sentiment and if the sentiment is neutral or positive, react with a positive_emoji, otherwise react with negative_emoji """ - grace_trigger = Trigger.get_by(name="Grace") + grace_trigger = Trigger.find_by(name="Grace") if grace_trigger is None: - warning("Missing trigger entry for \"Grace\"") + warning('Missing trigger entry for "Grace"') return - if self.bot.user.mentioned_in(message) and not message.content.startswith('<@!'): + if self.bot.user.mentioned_in(message) and not message.content.startswith( + "<@!" + ): # Note: the trigger needs to have a None-condition now that it's generic if self.get_message_sentiment_polarity(message) >= 0: await message.add_reaction(grace_trigger.positive_emoji) @@ -67,35 +71,42 @@ async def penguin_react(self, message: Message) -> None: :param message: A discord message to check for references to our lord and savior. :type message: discord.Message """ - linus_trigger = Trigger.get_by(name="Linus") + linus_trigger = Trigger.find_by(name="Linus") if linus_trigger is None: - warning("Missing trigger entry for \"Linus\"") + warning('Missing trigger entry for "Linus"') return message_tokens = self.tokenizer.tokenize(message.content) tokenlist = list(map(lambda s: s.lower(), message_tokens)) - linustarget = [i for i, x in enumerate( - tokenlist) if x in linus_trigger.words] + linustarget = [i for i, x in enumerate(tokenlist) if x in linus_trigger.words] # Get the indices of all linuses in the message if linustarget: fail = False for linusindex in linustarget: try: - if tokenlist[linusindex + 1] == 'tech' and tokenlist[linusindex + 2] == 'tips': + if ( + tokenlist[linusindex + 1] == "tech" + and tokenlist[linusindex + 2] == "tips" + ): fail = True - elif tokenlist[linusindex + 1] == 'and' and tokenlist[linusindex + 2] == 'lucy': + elif ( + tokenlist[linusindex + 1] == "and" + and tokenlist[linusindex + 2] == "lucy" + ): fail = True except IndexError: pass - determined_sentiment_polarity = self.get_message_sentiment_polarity(message) + determined_sentiment_polarity = self.get_message_sentiment_polarity( + message + ) if not fail and determined_sentiment_polarity < 0: await message.add_reaction(linus_trigger.negative_emoji) return - fail = (determined_sentiment_polarity < 1) + fail = determined_sentiment_polarity < 1 if not fail: await message.add_reaction(linus_trigger.positive_emoji) @@ -104,10 +115,10 @@ async def penguin_react(self, message: Message) -> None: async def on_message(self, message: Message) -> None: """A listener function that calls the `penguin_react`, `name_react`, and `pun_react` functions when a message is received. - + :param message: The message that was received. :type message: discord.Message - """ + """ await self.penguin_react(message) await self.name_react(message) @@ -120,15 +131,15 @@ async def triggers_group(self, ctx) -> None: :type ctx: discord.ext.commands.Context """ if ctx.invoked_subcommand is None: - trigger = Trigger.get_by(name="Linus") + trigger = Trigger.find_by(name="Linus") if trigger is None: - warning("Missing trigger entry for \"Linus\"") + warning('Missing trigger entry for "Linus"') return embed = Embed( color=self.bot.default_color, - title=f"Triggers", - description="\n".join(trigger.words) + title="Triggers", + description="\n".join(trigger.words), ) await ctx.send(embed=embed) @@ -143,7 +154,7 @@ async def add_trigger_word(self, ctx: Context, new_word: str) -> None: :param new_word: The new trigger word to be added. :type new_word: str """ - trigger = Trigger.get_by(name="Linus") + trigger = Trigger.find_by(name="Linus") if trigger: if new_word in trigger.words: @@ -155,7 +166,9 @@ async def add_trigger_word(self, ctx: Context, new_word: str) -> None: else: await ctx.send(f"Unable to add **{new_word}**") - @triggers_group.command(name="remove", help="Remove a trigger word", usage="{old_word}") + @triggers_group.command( + name="remove", help="Remove a trigger word", usage="{old_word}" + ) @has_permissions(administrator=True) async def remove_trigger_word(self, ctx: Context, old_word: str) -> None: """Remove an existing trigger word. @@ -165,7 +178,7 @@ async def remove_trigger_word(self, ctx: Context, old_word: str) -> None: :param old_word: The trigger word to be removed. :type old_word: str """ - trigger = Trigger.get_by(name="Linus") + trigger = Trigger.find_by(name="Linus") if trigger: if old_word not in trigger.words: diff --git a/bot/extensions/mermaid_cog.py b/bot/extensions/mermaid_cog.py index 06014630..b2be7669 100644 --- a/bot/extensions/mermaid_cog.py +++ b/bot/extensions/mermaid_cog.py @@ -1,24 +1,19 @@ import re - from typing import Optional -from discord.ext.commands import Cog, command, Context from discord import Embed, Message +from discord.ext.commands import Cog, Context, command from bot.extensions.command_error_handler import send_command_help from bot.services.mermaid_service import generate_mermaid_diagram -class MermaidCog( - Cog, - name="Mermaid", - description="Generates mermaid diagrams" -): +class MermaidCog(Cog, name="Mermaid", description="Generates mermaid diagrams"): def __init__(self, bot): self.bot = bot self.mermaid_codeblock_pattern = r"```mermaid\n(.*?)```" self.codeblock_pattern = r"```(?:\w+)?\n(.*?)```" - + def generate_diagram_embed(self, diagram: str) -> Embed: """ Generate a Discord embed containing a Mermaid diagram image or error @@ -47,9 +42,9 @@ def generate_diagram_embed(self, diagram: str) -> Embed: return embed def extract_code_block( - self, - content: str, - require_mermaid_tag: bool = False, + self, + content: str, + require_mermaid_tag: bool = False, ) -> str: """Extracts the mermaid script from a code block. @@ -63,10 +58,10 @@ def extract_code_block( THIS IS A MERMAID CODE BLOCK ``` - + :param content: String from which the code block will be extracted :type content: str - :param require_mermaid_tag: Whether mermaid tag is required in a + :param require_mermaid_tag: Whether mermaid tag is required in a code block or not :type require_mermaid_tag: bool @@ -75,30 +70,24 @@ def extract_code_block( """ if require_mermaid_tag: if codeblock_match := re.search( - self.mermaid_codeblock_pattern, - content, - re.DOTALL + self.mermaid_codeblock_pattern, content, re.DOTALL ): return codeblock_match.group(1).strip() - elif codeblock_match := re.search( - self.codeblock_pattern, - content, - re.DOTALL - ): + elif codeblock_match := re.search(self.codeblock_pattern, content, re.DOTALL): return codeblock_match.group(1).strip() - return '' + return "" @command( name="mermaid", help="Generate a diagram from mermaid script", - usage="՝՝՝\nMermaid script goes here...\n՝՝՝" + usage="՝՝՝\nMermaid script goes here...\n՝՝՝", ) async def mermaid(self, ctx: Context, *, content: Optional[str]): """Generates a mermaid diagram Reply with this command to a message that contains a code block with - mermaid script to generate a diagram from it or attach a codeblock. + mermaid script to generate a diagram from it or attach a codeblock. :param ctx: Invocation context :type ctx: Context @@ -110,9 +99,7 @@ async def mermaid(self, ctx: Context, *, content: Optional[str]): if not ctx.message.reference and content: diagram = self.extract_code_block(content) elif ctx.message.reference and not content: - ref_msg = await ctx.channel.fetch_message( - ctx.message.reference.message_id - ) + ref_msg = await ctx.channel.fetch_message(ctx.message.reference.message_id) diagram = self.extract_code_block(ref_msg.content) if not diagram: @@ -134,16 +121,13 @@ async def on_message(self, message: Message): ctx = await self.bot.get_context(message) - # Making sure there're no messages referenced, and no mermaid command + # Making sure there're no messages referenced, and no mermaid command # being executed so that it doesn't overlap with the function that # executes the command if message.reference or ctx.command: return - diagram = self.extract_code_block( - message.content, - require_mermaid_tag=True - ) + diagram = self.extract_code_block(message.content, require_mermaid_tag=True) if diagram: await ctx.reply(embed=self.generate_diagram_embed(diagram)) @@ -158,10 +142,7 @@ async def on_message_edit(self, before: Message, after: Message): :param after: Edited message :type after: Message """ - diagram = self.extract_code_block( - after.content, - require_mermaid_tag=True - ) + diagram = self.extract_code_block(after.content, require_mermaid_tag=True) if diagram: ctx = await self.bot.get_context(after) await ctx.reply(embed=self.generate_diagram_embed(diagram)) diff --git a/bot/extensions/moderation_cog.py b/bot/extensions/moderation_cog.py index 2b8043e4..9bfefb2a 100644 --- a/bot/extensions/moderation_cog.py +++ b/bot/extensions/moderation_cog.py @@ -1,15 +1,19 @@ +from datetime import datetime +from logging import info from typing import Optional + +from discord import Member, Message, Reaction +from discord.ext.commands import Cog, Context, has_permissions, hybrid_command +from emoji import demojize + from bot import app -from logging import info -from discord import Message, Member, Reaction -from discord.ext.commands import Cog, has_permissions, hybrid_command, Context from bot.helpers.log_helper import danger, notice -from datetime import datetime -from emoji import demojize from bot.models.channel import Channel -class ModerationCog(Cog, name="Moderation", description="Collection of administrative commands."): +class ModerationCog( + Cog, name="Moderation", description="Collection of administrative commands." +): def __init__(self, bot): self.bot = bot @@ -17,9 +21,11 @@ def __init__(self, bot): def moderation_channel(self): return self.bot.get_channel_by_name("moderation_logs") - @hybrid_command(name='purge', help="Deletes n amount of messages.") + @hybrid_command(name="purge", help="Deletes n amount of messages.") @has_permissions(manage_messages=True) - async def purge(self, ctx: Context, limit: int, reason: Optional[str] = "No reason given") -> None: + async def purge( + self, ctx: Context, limit: int, reason: Optional[str] = "No reason given" + ) -> None: """Purge a specified number of messages from the channel. :param ctx: The context in which the command was called. @@ -31,7 +37,10 @@ async def purge(self, ctx: Context, limit: int, reason: Optional[str] = "No reas """ await ctx.defer() - log = danger("PURGE", f"{limit} message(s) purged by {ctx.author.mention} in {ctx.channel.mention}") + log = danger( + "PURGE", + f"{limit} message(s) purged by {ctx.author.mention} in {ctx.channel.mention}", + ) log.add_field("Reason", reason) await ctx.channel.purge(limit=int(limit) + 1, bulk=True, reason=reason) @@ -43,47 +52,65 @@ async def on_reaction_add(self, reaction: Reaction, member: Member) -> None: author: Member = message.author emojis = [":SOS_button:", ":red_question_mark:"] - is_already_reacted = any(filter(lambda r: r.me and demojize(r.emoji) in emojis and r.count > 0, message.reactions)) + is_already_reacted = any( + filter( + lambda r: r.me and demojize(r.emoji) in emojis and r.count > 0, + message.reactions, + ) + ) if author.bot or is_already_reacted: return None match demojize(str(reaction.emoji)): case ":SOS_button:": - await message.reply("[Don't ask to ask, just ask]()") + await message.reply( + "[Don't ask to ask, just ask]()" + ) case ":red_question_mark:": - guidelines: Channel = Channel.get_by(channel_name="posting_guidelines") - help: Channel = Channel.get_by(channel_name="help") + guidelines: Channel = Channel.find_by(channel_name="posting_guidelines") + help: Channel = Channel.find_by(channel_name="help") if guidelines and help: - await message.reply(f"If you need some help, read the <#{guidelines.channel_id}> and open a post in <#{help.channel_id}>!") + await message.reply( + f"If you need some help, read the <#{guidelines.channel_id}> and open a post in <#{help.channel_id}>!" + ) case _: return None # Grace also reacts and log the reaction because some people remove their reaction afterward await message.add_reaction(reaction) - log = notice("HELP REACTION", f"{member.mention} reacted to {message.jump_url} with {reaction.emoji}") + log = notice( + "HELP REACTION", + f"{member.mention} reacted to {message.jump_url} with {reaction.emoji}", + ) await log.send(self.moderation_channel or message.channel) @Cog.listener() async def on_member_join(self, member) -> None: - """A listener function that checks if a member's account age meets the minimum required age to join the server. + """A listener function that checks if a member's account age meets the minimum required age to join the server. If it doesn't, the member is kicked. :param member: The member who has just joined the server. :type member: discord.Member """ minimum_account_age = app.config.get("moderation", "minimum_account_age") - account_age_in_days = (datetime.now().replace(tzinfo=None) - member.created_at.replace(tzinfo=None)).days + account_age_in_days = ( + datetime.now().replace(tzinfo=None) - member.created_at.replace(tzinfo=None) + ).days if account_age_in_days < minimum_account_age: info(f"{member} kicked due to account age restriction!") log = danger("KICK", f"{member} has been kicked.") - log.add_field("Reason: ", "Automatically kicked due to account age restriction") + log.add_field( + "Reason: ", "Automatically kicked due to account age restriction" + ) - await member.send(f"Your account needs to be {minimum_account_age} days old or more to join the server.") + await member.send( + f"Your account needs to be {minimum_account_age} days old or more to join the server." + ) await member.guild.kick(user=member, reason="Account age restriction") if self.moderation_channel: diff --git a/bot/extensions/pun_cog.py b/bot/extensions/pun_cog.py index f229ca7f..39772c8f 100644 --- a/bot/extensions/pun_cog.py +++ b/bot/extensions/pun_cog.py @@ -1,27 +1,36 @@ -from discord.ext.commands import Cog, has_permissions, hybrid_command, hybrid_group, Context -from discord import Message, Embed +from discord import Embed, Message +from discord.ext.commands import ( + Cog, + Context, + has_permissions, + hybrid_command, + hybrid_group, +) +from emoji import demojize +from nltk.tokenize import TweetTokenizer + from bot.models.bot import BotSettings from bot.models.extensions.language.pun import Pun from bot.models.extensions.language.pun_word import PunWord -from nltk.tokenize import TweetTokenizer -from emoji import demojize -class PunCog(Cog, name="Puns", description="Automatically intrude with puns when triggered"): +class PunCog( + Cog, name="Puns", description="Automatically intrude with puns when triggered" +): def __init__(self, bot): self.bot = bot self.tokenizer = TweetTokenizer() - + @Cog.listener() async def on_message(self, message: Message) -> None: """A listener function that calls the `pun_react` functions when a message is received. - - :param message: The message that was received. - :type message: discord.Message - """ + + :param message: The message that was received. + :type message: discord.Message + """ await self.pun_react(message) - + async def pun_react(self, message: Message) -> None: """Add reactions and send a message in the channel if the message content contains any pun words. @@ -34,26 +43,25 @@ async def pun_react(self, message: Message) -> None: message_tokens = self.tokenizer.tokenize(message.content) tokenlist = set(map(str.lower, message_tokens)) - pun_words = PunWord.all() + pun_words = PunWord.distinct().all() word_set = set(map(lambda pun_word: pun_word.word, pun_words)) matches = tokenlist.intersection(word_set) invoked_at = message.created_at.replace(tzinfo=None) if matches: - matched_pun_words = set(filter(lambda pun_word: pun_word.word in matches, pun_words)) - puns = map(lambda pun_word: Pun.get(pun_word.pun_id), matched_pun_words) - puns = filter(lambda pun: pun.can_invoke_at_time(invoked_at), puns) - puns = set(puns) # remove duplicate puns + matched_pun_words = filter( + lambda pun_word: pun_word.word in matches, pun_words + ) + puns = map(lambda pun_word: Pun.find(pun_word.pun_id), matched_pun_words) + puns = list(filter(lambda pun: pun.can_invoke_at_time(invoked_at), puns)) for pun_word in matched_pun_words: await message.add_reaction(pun_word.emoji()) for pun in puns: embed = Embed( - color=self.bot.default_color, - title=f"Gotcha", - description=pun.text + color=self.bot.default_color, title="Gotcha", description=pun.text ) await message.channel.send(embed=embed) @@ -72,13 +80,14 @@ async def puns_group(self, ctx: Context) -> None: @has_permissions(administrator=True) async def list_puns(self, ctx: Context) -> None: if ctx.invoked_subcommand is None: - pun_texts_with_ids = map(lambda pun: '{}.\t{}'.format( - pun.id, pun.text), Pun.all()) + pun_texts_with_ids = map( + lambda pun: "{}.\t{}".format(pun.id, pun.text), Pun.all() + ) embed = Embed( color=self.bot.default_color, - title=f"Puns", - description="\n".join(pun_texts_with_ids) + title="Puns", + description="\n".join(pun_texts_with_ids), ) await ctx.send(embed=embed) @@ -107,7 +116,7 @@ async def remove_pun(self, ctx: Context, pun_id: int) -> None: :param pun_id: The ID of the pun to which the word will be removed. :type pun_id: str """ - pun = Pun.get(pun_id) + pun = Pun.find(pun_id) if pun: await ctx.send("Pun removed.") @@ -116,7 +125,9 @@ async def remove_pun(self, ctx: Context, pun_id: int) -> None: @puns_group.command(name="add-word", help="Add a pun word to a pun") @has_permissions(administrator=True) - async def add_pun_word(self, ctx: Context, pun_id: int, pun_word: str, emoji: str) -> None: + async def add_pun_word( + self, ctx: Context, pun_id: int, pun_word: str, emoji: str + ) -> None: """Add a new pun word. :param ctx: The context in which the command was called. @@ -128,7 +139,7 @@ async def add_pun_word(self, ctx: Context, pun_id: int, pun_word: str, emoji: st :param emoji: An emoji to be associated with the pun word. :type emoji: str """ - pun = Pun.get(pun_id) + pun = Pun.find(pun_id) if pun: if pun.has_word(pun_word): @@ -151,7 +162,7 @@ async def remove_pun_word(self, ctx: Context, id: int, pun_word: str) -> None: :param pun_word: The old pun word to be removed. :type pun_word: str """ - pun = Pun.get(id) + pun = Pun.find(id) if pun: if not pun.has_word(pun_word): @@ -163,7 +174,9 @@ async def remove_pun_word(self, ctx: Context, id: int, pun_word: str) -> None: await ctx.send(f"Pun with id **{pun.id}** does not exist.") @hybrid_command(name="cooldown", help="Set cooldown for puns feature in minutes.") - async def set_puns_cooldown_command(self, ctx: Context, cooldown_minutes: int) -> None: + async def set_puns_cooldown_command( + self, ctx: Context, cooldown_minutes: int + ) -> None: settings = BotSettings.settings() settings.puns_cooldown = cooldown_minutes settings.save() @@ -172,4 +185,4 @@ async def set_puns_cooldown_command(self, ctx: Context, cooldown_minutes: int) - async def setup(bot): - await bot.add_cog(PunCog(bot)) \ No newline at end of file + await bot.add_cog(PunCog(bot)) diff --git a/bot/extensions/reddit_cog.py b/bot/extensions/reddit_cog.py index 7e2ec2a9..087510bb 100644 --- a/bot/extensions/reddit_cog.py +++ b/bot/extensions/reddit_cog.py @@ -1,45 +1,51 @@ +import re +from typing import List + +from discord import Embed, Message from discord.ext.commands import Cog -from discord import Message, Embed + from bot import app from bot.helpers.log_helper import danger -from typing import List -import re class RedditCog(Cog, name="Reddit", description="Reddit utilities"): def __init__(self, bot): self.bot = bot - self.blacklisted_subreddits = app.config.get("reddit", "blacklist", "").split(';'); + self.blacklisted_subreddits = app.config.get("reddit", "blacklist", "").split( + ";" + ) @property def moderation_channel(self): - """ Returns the moderation channel """ + """Returns the moderation channel""" return self.bot.get_channel_by_name("moderation_logs") async def notify_moderation(self, message: Message, blacklisted: List[str]): - """ Notifies moderators about a blacklisted subreddit mention - - :param message: Message that contained blacklisted subreddits - :type message: Message - :param blacklisted: List of blacklisted subreddits - :type blacklisted: List[str] + """Notifies moderators about a blacklisted subreddit mention + + :param message: Message that contained blacklisted subreddits + :type message: Message + :param blacklisted: List of blacklisted subreddits + :type blacklisted: List[str] """ if self.moderation_channel: - log = danger("BLACKLISTED SUBREDDIT", f"{message.author.mention} mentioned blacklisted subreddits: {', '.join(blacklisted)}\n\nMessage: {message.jump_url}") + log = danger( + "BLACKLISTED SUBREDDIT", + f"{message.author.mention} mentioned blacklisted subreddits: {', '.join(blacklisted)}\n\nMessage: {message.jump_url}", + ) await log.send(self.moderation_channel) async def extract_subreddits(self, message: Message) -> List[List]: - """ Extracts and filters all mentioned subreddits from a message - - :param message: Message from which to extract subreddits - :type message: Message + """Extracts and filters all mentioned subreddits from a message + + :param message: Message from which to extract subreddits + :type message: Message - :returns: List containing both valid and blacklisted subreddits - :rtype: List[List] + :returns: List containing both valid and blacklisted subreddits + :rtype: List[List] """ subreddit_matches = re.findall( - r"(? List[List]: @Cog.listener() async def on_message(self, message: Message): - """ Listens for messages and replies with links to subreddits if any were mentioned + """Listens for messages and replies with links to subreddits if any were mentioned - :param message: Message a user has sent - :type message: Message + :param message: Message a user has sent + :type message: Message """ # Make sure that the message recieved is not sent by Grace - if message.author.id != self.bot.user.id: + if message.author.id != self.bot.user.id: subreddits, blacklisted = await self.extract_subreddits(message) if blacklisted: @@ -70,7 +76,9 @@ async def on_message(self, message: Message): if subreddits: ctx = await self.bot.get_context(message) - subreddit_links = [f"https://www.reddit.com/r/{subreddit}" for subreddit in subreddits] + subreddit_links = [ + f"https://www.reddit.com/r/{subreddit}" for subreddit in subreddits + ] answer_embed = Embed( title="Here're the subreddits you mentioned", diff --git a/bot/extensions/thank_cog.py b/bot/extensions/thank_cog.py index 429768ac..dd4a7dc0 100644 --- a/bot/extensions/thank_cog.py +++ b/bot/extensions/thank_cog.py @@ -1,19 +1,22 @@ from typing import List, Optional -from discord import Member, Embed, Message -from discord.ext.commands import Cog, Context, cooldown, BucketType, hybrid_group, has_permissions + +from discord import Embed, Member +from discord.ext.commands import BucketType, Cog, Context, cooldown, hybrid_group + from bot.extensions.command_error_handler import send_command_help from bot.grace import Grace from bot.models.extensions.thank import Thank class ThankCog(Cog): - """A cog containing thank you commands """ + """A cog containing thank you commands""" + def __init__(self, bot: Grace): self.bot: Grace = bot - @hybrid_group(name='thank', help='Thank commands', invoke_without_command=True) + @hybrid_group(name="thank", help="Thank commands") async def thank_group(self, ctx: Context) -> None: - """Event listener for the `thank` command group. If no subcommand is + """Event listener for the `thank` command group. If no subcommand is invoked, it sends the command help to the user. :param ctx: The context of the command invocation. @@ -22,9 +25,9 @@ async def thank_group(self, ctx: Context) -> None: if ctx.invoked_subcommand is None: await send_command_help(ctx) - @thank_group.command(name='send', description='Send a thank you to a person') + @thank_group.command(name="send", description="Send a thank you to a person") @cooldown(1, 3600, BucketType.user) - async def thank(self, ctx: Context, *, member: Member) -> Optional[Message]: + async def thank(self, ctx: Context, *, member: Member) -> None: """Send a "thank you" message to a member and increase their thank count by 1. :param ctx: The context of the command invocation. @@ -33,71 +36,74 @@ async def thank(self, ctx: Context, *, member: Member) -> Optional[Message]: :type member: Member :return: Message | None """ - if member.id == self.bot.user.id: - return await ctx.send(f'{ctx.author.display_name}, thank you 😊', ephemeral=True) - if ctx.author.id == member.id: - return await ctx.send('You cannot thank yourself.', ephemeral=True) + await ctx.send("You cannot thank yourself.", ephemeral=True) + return - thank: Thank = Thank.get_by(member_id=member.id) + thank: Thank = Thank.find_by(member_id=member.id) - if thank: - thank.count += 1 - thank.save() - else: - thank = Thank.create(member_id=member.id, count=1) + if not thank: + thank = Thank.create(member_id=member.id) + + thank.count += 1 + thank.save() thank_embed: Embed = Embed( - title='INFO', + title="INFO", color=self.bot.default_color, - description=f'{member.display_name}, you were thanked by **{ctx.author.display_name}**\n' - f'Now, your thank count is: **{thank.count}**' + description=f"{member.display_name}, you were thanked by **{ctx.author.display_name}**\n" + f"Now, your thank count is: **{thank.count}**", ) - await member.send(embed=thank_embed) - await ctx.interaction.response.send_message(f'Successfully thanked **@{member.display_name}**', ephemeral=True) + if member.id != self.bot.user.id: + await member.send(embed=thank_embed) + await ctx.send( + f"Successfully thanked **@{member.display_name}**", ephemeral=True + ) - @thank_group.command(name='leaderboard', description='Shows top n helpers.') - async def thank_leaderboard(self, ctx: Context, *, top: int = 10) -> Optional[Message]: + @thank_group.command(name="leaderboard", description="Shows top n helpers.") + async def thank_leaderboard(self, ctx: Context, *, top: int = 10) -> None: """Display the top n helpers, sorted by their thank count. - + :param ctx: The context of the command invocation. :type ctx: Context :param top: The number of top helpers to display. Default is 10. :type top: int (optional) - :return: Message | None """ - helpers: List[Thank] = Thank.ordered() + if top <= 0: + await ctx.reply( + "The top parameter must have value of at least 1.", ephemeral=True + ) + return - if not helpers: - return await ctx.reply('No helpers found.', ephemeral=True) + helpers: List[Thank] = Thank.order_by(count="desc").limit(top).all() - top = min(len(helpers), top) - if top <= 0: - return await ctx.reply('The top parameter must have value of at least 1.', ephemeral=True) + if not helpers: + await ctx.reply("No helpers found.", ephemeral=True) + return leaderboard_embed: Embed = Embed( - title=f'Helpers Leaderboard Top {top}', - description='', - color=self.bot.default_color + title=f"Helpers Leaderboard Top {top}", + description="", + color=self.bot.default_color, ) - for position in range(top): - member = helpers[position] - member_nickname = (await self.bot.fetch_user(member.member_id)).display_name - leaderboard_embed.description += '{}. **{}**: **{}** with {} thank(s).\n'.format( - position + 1, - member_nickname, - member.rank, - member.count + for position, helper in enumerate(helpers): + member = await self.bot.fetch_user(helper.member_id) + leaderboard_embed.description += ( + "{}. **{}**: **{}** with {} thank(s).\n".format( + position + 1, member.display_name, helper.rank, helper.count + ) ) await ctx.reply(embed=leaderboard_embed, ephemeral=True) - @thank_group.command(name='rank', description='Shows your current thank rank.') - async def thank_rank(self, ctx: Context, *, member: Optional[Member] = None) -> None: + @thank_group.command(name="rank", description="Shows your current thank rank.") + async def thank_rank( + self, ctx: Context, *, member: Optional[Member] = None + ) -> None: """Show the current rank of the member who issue this command. - + :param ctx: The context of the command invocation. :type ctx: Context :param member: The member rank. @@ -105,56 +111,49 @@ async def thank_rank(self, ctx: Context, *, member: Optional[Member] = None) -> """ if not member or member.id == ctx.author.id: await self.send_author_rank(ctx) - elif member.id == self.bot.user.id: - await self.send_bot_rank(ctx) else: await self.send_member_rank(ctx, member) - async def send_bot_rank(self, ctx: Context) -> None: - """Send a message showing the rank of the bot. - - :param ctx: The context of the command invocation. - :type ctx: Context - """ - rank_embed: Embed = Embed(title='Grace RANK', color=self.bot.default_color) - rank_embed.description = 'Grace has a range of commands that can help you greatly!\n' \ - 'Rank: **Bot**' - - await ctx.reply(embed=rank_embed, ephemeral=True) - async def send_author_rank(self, ctx: Context) -> None: """Send a message showing the rank of the user who issued the command. - + :param ctx: The context of the command invocation. :type ctx: Context :return: None """ - rank_embed: Embed = Embed(title='YOUR RANK', color=self.bot.default_color) - thank = Thank.get_by(member_id=ctx.author.id) + rank_embed: Embed = Embed(title="YOUR RANK", color=self.bot.default_color) + thank = Thank.find_by(member_id=ctx.author.id) if not thank: - rank_embed.description = 'You haven\'t been thanked yet.' + rank_embed.description = "You haven't been thanked yet." else: - rank_embed.description = f'Your rank is: **{thank.rank}**\n' \ - f'Your thank count is: {thank.count}' + rank_embed.description = ( + f"Your rank is: **{thank.rank}**\nYour thank count is: {thank.count}" + ) await ctx.reply(embed=rank_embed, ephemeral=True) async def send_member_rank(self, ctx: Context, member: Member) -> None: """Send a message showing the rank of the given member. - + :param ctx: The context of the command invocation. :type ctx: Context :param member: The member rank. :type member: Member """ - rank_embed: Embed = Embed(title=f'{member.display_name} RANK', color=self.bot.default_color) - thank = Thank.get_by(member_id=member.id) + rank_embed: Embed = Embed( + title=f"{member.display_name} RANK", color=self.bot.default_color + ) + thank = Thank.find_by(member_id=member.id) if not thank: - rank_embed.description = f'User **@{member.display_name}** hasn\'t been thanked yet.' + rank_embed.description = ( + f"User **@{member.display_name}** has not been thanked yet." + ) else: - rank_embed.description = f'User **@{member.display_name}** has rank: **{thank.rank}**' + rank_embed.description = ( + f"User **@{member.display_name}** has rank: **{thank.rank}**" + ) await ctx.reply(embed=rank_embed, ephemeral=True) diff --git a/bot/extensions/threads_cog.py b/bot/extensions/threads_cog.py index f4642ed5..4398ac1f 100644 --- a/bot/extensions/threads_cog.py +++ b/bot/extensions/threads_cog.py @@ -1,14 +1,15 @@ import traceback -from typing import Optional from logging import info -from pytz import timezone -from discord import Interaction, Embed, TextStyle + +from discord import Embed, Interaction, TextStyle from discord.app_commands import Choice, autocomplete +from discord.ext.commands import Cog, Context, has_permissions, hybrid_group from discord.ui import Modal, TextInput -from discord.ext.commands import Cog, has_permissions, hybrid_command, hybrid_group, Context -from bot.models.extensions.thread import Thread +from pytz import timezone + from bot.classes.recurrence import Recurrence from bot.extensions.command_error_handler import send_command_help +from bot.models.extensions.thread import Thread from lib.config_required import cog_config_required @@ -24,7 +25,7 @@ class ThreadModal(Modal, title="Thread"): label="Content", placeholder="The content of the thread...", min_length=10, - style=TextStyle.paragraph + style=TextStyle.paragraph, ) def __init__(self, recurrence: Recurrence, thread: Thread = None): @@ -47,34 +48,35 @@ async def create_thread(self, interaction: Interaction): thread = Thread.create( title=self.thread_title.value, content=self.thread_content.value, - recurrence=self.thread_recurrence + recurrence=self.thread_recurrence, ) await interaction.response.send_message( - f'Thread __**{thread.id}**__ created!', - ephemeral=True + f"Thread __**{thread.id}**__ created!", ephemeral=True ) async def update_thread(self, interaction: Interaction): - self.thread.title = self.thread_title.value, - self.thread.content = self.thread_content.value, + self.thread.title = (self.thread_title.value,) + self.thread.content = (self.thread_content.value,) self.thread.recurrence = self.thread_recurrence self.thread.save() await interaction.response.send_message( - f'Thread __**{self.thread.id}**__ updated!', - ephemeral=True + f"Thread __**{self.thread.id}**__ updated!", ephemeral=True ) async def on_error(self, interaction: Interaction, error: Exception): - await interaction.response.send_message('Oops! Something went wrong.', ephemeral=True) + await interaction.response.send_message( + "Oops! Something went wrong.", ephemeral=True + ) traceback.print_exception(type(error), error, error.__traceback__) async def thread_autocomplete(_: Interaction, current: str) -> list[Choice[str]]: return [ Choice(name=t.title, value=str(t.id)) - for t in Thread.all() if current.lower() in t.title + for t in Thread.all() + if current.lower() in t.title ] @@ -86,36 +88,37 @@ def __init__(self, bot): self.threads_channel_id = self.required_config self.timezone = timezone("US/Eastern") - def cog_load(self): # Runs everyday at 18:30 - self.jobs.append(self.bot.scheduler.add_job( - self.daily_post, - 'cron', - hour=18, - minute=30, - timezone=self.timezone - )) + self.jobs.append( + self.bot.scheduler.add_job( + self.daily_post, "cron", hour=18, minute=30, timezone=self.timezone + ) + ) # Runs every monday at 18:30 - self.jobs.append(self.bot.scheduler.add_job( - self.weekly_post, - 'cron', - day_of_week='mon', - hour=18, - minute=30, - timezone=self.timezone - )) + self.jobs.append( + self.bot.scheduler.add_job( + self.weekly_post, + "cron", + day_of_week="mon", + hour=18, + minute=30, + timezone=self.timezone, + ) + ) # Runs on the 1st of every month at 18:30 - self.jobs.append(self.bot.scheduler.add_job( - self.monthly_post, - 'cron', - day=1, - hour=18, - minute=30, - timezone=self.timezone - )) + self.jobs.append( + self.bot.scheduler.add_job( + self.monthly_post, + "cron", + day=1, + hour=18, + minute=30, + timezone=self.timezone, + ) + ) def cog_unload(self): for job in self.jobs: @@ -145,9 +148,7 @@ async def post_thread(self, thread: Thread): content = f"<@&{role_id}>" if role_id else None embed = Embed( - color=self.bot.default_color, - title=thread.title, - description=thread.content + color=self.bot.default_color, title=thread.title, description=thread.content ) if channel: @@ -163,17 +164,14 @@ async def threads_group(self, ctx: Context): @threads_group.command(help="List all threads") @has_permissions(administrator=True) async def list(self, ctx: Context): - embed = Embed( - color=self.bot.default_color, - title="Threads" - ) + embed = Embed(color=self.bot.default_color, title="Threads") if threads := Thread.all(): for thread in threads: embed.add_field( name=f"[{thread.id}] {thread.title}", value=f"**Recurrence**: {thread.recurrence}", - inline=False + inline=False, ) else: embed.add_field(name="No threads", value="") @@ -190,7 +188,7 @@ async def create(self, ctx: Context, recurrence: Recurrence): @has_permissions(administrator=True) @autocomplete(thread=thread_autocomplete) async def delete(self, ctx: Context, thread: int): - if thread := Thread.get(thread): + if thread := Thread.find(thread): thread.delete() await ctx.send("Thread successfully deleted!", ephemeral=True) else: @@ -200,25 +198,22 @@ async def delete(self, ctx: Context, thread: int): @has_permissions(administrator=True) @autocomplete(thread=thread_autocomplete) async def update(self, ctx: Context, thread: int, recurrence: Recurrence): - if thread := Thread.get(thread): + if thread := Thread.find(thread): modal = ThreadModal(recurrence, thread=thread) await ctx.interaction.response.send_modal(modal) else: await ctx.send("Thread not found!", ephemeral=True) - @threads_group.command(help="Post a given thread") @has_permissions(administrator=True) @autocomplete(thread=thread_autocomplete) async def post(self, ctx: Context, thread: int): if ctx.interaction: await ctx.interaction.response.send_message( - content="Opening thread!", - delete_after=0, - ephemeral=True + content="Opening thread!", delete_after=0, ephemeral=True ) - if thread := Thread.get(thread): + if thread := Thread.find(thread): await self.post_thread(thread) else: await self.send("Thread not found!") diff --git a/bot/extensions/time_cog.py b/bot/extensions/time_cog.py index 221a7ab0..17826cb0 100644 --- a/bot/extensions/time_cog.py +++ b/bot/extensions/time_cog.py @@ -1,39 +1,35 @@ -import pytz import re - -from discord.ext.commands import Cog from datetime import datetime, timedelta + +import pytz from dateutil import parser from discord import Message +from discord.ext.commands import Cog # Mapping for common timezone abbreviations to their UTC offsets timezone_abbreviations = { # North American - "pst": "America/Los_Angeles", # Pacific Standard Time - "pdt": "America/Los_Angeles", # Pacific Daylight Time - "mst": "America/Denver", # Mountain Standard Time - "mdt": "America/Denver", # Mountain Daylight Time - "cst": "America/Chicago", # Central Standard Time - "cdt": "America/Chicago", # Central Daylight Time - "est": "America/New_York", # Eastern Standard Time - "edt": "America/New_York", # Eastern Daylight Time - + "pst": "America/Los_Angeles", # Pacific Standard Time + "pdt": "America/Los_Angeles", # Pacific Daylight Time + "mst": "America/Denver", # Mountain Standard Time + "mdt": "America/Denver", # Mountain Daylight Time + "cst": "America/Chicago", # Central Standard Time + "cdt": "America/Chicago", # Central Daylight Time + "est": "America/New_York", # Eastern Standard Time + "edt": "America/New_York", # Eastern Daylight Time # International standards - "gmt": "Etc/GMT", # Greenwich Mean Time - "utc": "UTC", # Coordinated Universal Time - + "gmt": "Etc/GMT", # Greenwich Mean Time + "utc": "UTC", # Coordinated Universal Time # European - "bst": "Europe/London", # British Summer Time - "cet": "Europe/Paris", # Central European Time - "cest": "Europe/Paris", # Central European Summer Time - + "bst": "Europe/London", # British Summer Time + "cet": "Europe/Paris", # Central European Time + "cest": "Europe/Paris", # Central European Summer Time # Asia-Pacific - "hkt": "Asia/Hong_Kong", # Hong Kong Time - "ist": "Asia/Kolkata", # India Standard Time - "jst": "Asia/Tokyo", # Japan Standard Time - "aest": "Australia/Sydney", # Australian Eastern Standard Time - "aedt": "Australia/Sydney", # Australian Eastern Daylight Time - + "hkt": "Asia/Hong_Kong", # Hong Kong Time + "ist": "Asia/Kolkata", # India Standard Time + "jst": "Asia/Tokyo", # Japan Standard Time + "aest": "Australia/Sydney", # Australian Eastern Standard Time + "aedt": "Australia/Sydney", # Australian Eastern Daylight Time # TODO: find a way to fetch all timezones dynamically } @@ -41,7 +37,7 @@ class TimeCog( Cog, name="Time", - description="Convert time in messages into UTC-based Discord timestamps." + description="Convert time in messages into UTC-based Discord timestamps.", ): """ A Discord Cog that listens for messages containing time expressions @@ -51,6 +47,7 @@ class TimeCog( This allows users to share time references that automatically display correctly in each user's local timezone within Discord. """ + def __init__(self, bot): self.bot = bot @@ -99,10 +96,10 @@ def _build_relative_date(self, time_str: str, now_utc: datetime) -> str: """ # Handle relative dates if "today" in time_str: - date_str = now_utc.strftime('%Y-%m-%d') + date_str = now_utc.strftime("%Y-%m-%d") time_str = time_str.replace("today", date_str) elif "tomorrow" in time_str: - date_str = (now_utc + timedelta(days=1)).strftime('%Y-%m-%d') + date_str = (now_utc + timedelta(days=1)).strftime("%Y-%m-%d") time_str = time_str.replace("tomorrow", date_str) return time_str diff --git a/bot/extensions/translator_cog.py b/bot/extensions/translator_cog.py index 29447ce8..c807299b 100644 --- a/bot/extensions/translator_cog.py +++ b/bot/extensions/translator_cog.py @@ -1,7 +1,7 @@ -from discord.ext.commands import Cog, hybrid_command, Context, CommandError -from googletrans import Translator, LANGUAGES from discord import Embed, Interaction from discord.app_commands import Choice, autocomplete +from discord.ext.commands import Cog, CommandError, Context, hybrid_command +from googletrans import LANGUAGES, Translator from bot.helpers.error_helper import get_original_exception @@ -20,19 +20,20 @@ async def language_autocomplete(_: Interaction, current: str) -> list[Choice[str return [ Choice(name=language.capitalize(), value=language) - for language in languages[:25] if current.lower() in language.lower() + for language in languages[:25] + if current.lower() in language.lower() ] class TranslatorCog( Cog, name="Translator", - description="Translate a sentence/word from any languages into any languages." + description="Translate a sentence/word from any languages into any languages.", ): @hybrid_command( - name='translator', - help='Translate a sentence/word from any languages into any languages', - usage="sentence={sentence}" + name="translator", + help="Translate a sentence/word from any languages into any languages", + usage="sentence={sentence}", ) @autocomplete(translate_into=language_autocomplete) async def translator(self, ctx: Context, *, sentence: str, translate_into: str): @@ -56,12 +57,12 @@ async def translator(self, ctx: Context, *, sentence: str, translate_into: str): embed.add_field( name=f"{LANGUAGES[translated_text.src].capitalize()} Original", value=sentence.capitalize(), - inline=False + inline=False, ) embed.add_field( name=f"{translate_into} Translation", value=translated_text.text, - inline=False + inline=False, ) await ctx.send(embed=embed) diff --git a/bot/extensions/weather_cog.py b/bot/extensions/weather_cog.py index f02c13b4..66c3beb8 100644 --- a/bot/extensions/weather_cog.py +++ b/bot/extensions/weather_cog.py @@ -1,16 +1,23 @@ -from timezonefinder import TimezoneFinder -from pytz import timezone from datetime import datetime +from string import capwords + +from discord import Embed from discord.ext.commands import Cog, hybrid_command +from pytz import timezone from requests import get -from discord import Embed -from string import capwords +from timezonefinder import TimezoneFinder + from lib.config_required import cog_config_required -@cog_config_required("openweather", "api_key", "Generate yours [here](https://openweathermap.org/api)") -class WeatherCog(Cog, name="Weather", description="get current weather information from a city"): +@cog_config_required( + "openweather", "api_key", "Generate yours [here](https://openweathermap.org/api)" +) +class WeatherCog( + Cog, name="Weather", description="get current weather information from a city" +): """A cog that retrieves current weather information for a given city.""" + OPENWEATHER_BASE_URL = "https://api.openweathermap.org/data/2.5/" def __init__(self, bot): @@ -26,13 +33,11 @@ def get_timezone(data: any) -> datetime: :return: The timezone based on Longitude and Latitude. :rtype: datetime.tzinfo """ - longitude = float(data["coord"]['lon']) - latitude = float(data["coord"]['lat']) + longitude = float(data["coord"]["lon"]) + latitude = float(data["coord"]["lat"]) timezone_finder = TimezoneFinder() - - result = timezone_finder.timezone_at( - lng=longitude, - lat=latitude) + + result = timezone_finder.timezone_at(lng=longitude, lat=latitude) return datetime.now(timezone(str(result))) @staticmethod @@ -66,14 +71,18 @@ async def get_weather(self, city: str): :rtype: dict """ # complete_url to retreive weather info - response = get(f"{self.OPENWEATHER_BASE_URL}/weather?appid={self.api_key}&q={city}") + response = get( + f"{self.OPENWEATHER_BASE_URL}/weather?appid={self.api_key}&q={city}" + ) # code 200 means the city is found otherwise, city is not found if response.status_code == 200: return response.json() return None - @hybrid_command(name='weather', help='Show weather information in your city', usage="{city}") + @hybrid_command( + name="weather", help="Show weather information in your city", usage="{city}" + ) async def weather(self, ctx, *, city_input: str): """Display weather information for the specified city. @@ -86,69 +95,61 @@ async def weather(self, ctx, *, city_input: str): if ctx.interaction: await ctx.interaction.response.defer() - city = capwords(city_input) - data_weather = await self.get_weather(city) + city = capwords(city_input) + data_weather = await self.get_weather(city) timezone_city = self.get_timezone(data_weather) # Now data_weather contains lists of data # from the city inputer by the user if data_weather: - icon_id = data_weather["weather"][0]["icon"] - main = data_weather["main"] - visibility = data_weather['visibility'] + icon_id = data_weather["weather"][0]["icon"] + main = data_weather["main"] + visibility = data_weather["visibility"] current_temperature = main["temp"] fahrenheit = self.kelvin_to_fahrenheit(int(current_temperature)) - celsius = self.kelvin_to_celsius(int(current_temperature)) + celsius = self.kelvin_to_celsius(int(current_temperature)) - feels_like = main["feels_like"] + feels_like = main["feels_like"] feels_like_fahrenheit = self.kelvin_to_fahrenheit(int(feels_like)) - feels_like_celsius = self.kelvin_to_celsius(int(feels_like)) + feels_like_celsius = self.kelvin_to_celsius(int(feels_like)) - current_pressure = main["pressure"] - current_humidity = main["humidity"] - forcast = data_weather["weather"] + current_pressure = main["pressure"] + current_humidity = main["humidity"] + forcast = data_weather["weather"] weather_description = forcast[0]["description"] embed = Embed( color=self.bot.default_color, title=city, - description=timezone_city.strftime('%m/%d/%Y %H:%M'), + description=timezone_city.strftime("%m/%d/%Y %H:%M"), ) - embed.set_image( - url=f'https://openweathermap.org/img/wn/{icon_id}@2x.png' - ) + embed.set_image(url=f"https://openweathermap.org/img/wn/{icon_id}@2x.png") embed.add_field( - name="Description", - value=capwords(weather_description), - inline=False + name="Description", value=capwords(weather_description), inline=False ) embed.add_field( name="Visibility", value=f"{visibility}m | {round(visibility * 3.280839895)}ft", - inline=False + inline=False, ) embed.add_field( name="Temperature", value=f"{round(fahrenheit, 2)}°F | {round(celsius, 2)}°C", - inline=False + inline=False, ) embed.add_field( name="Feels Like", value=f"{round(feels_like_fahrenheit, 2)}°F | {round(feels_like_celsius, 2)}°C", - inline=False + inline=False, ) embed.add_field( name="Atmospheric Pressure", value=f"{current_pressure} hPa", - inline=False - ) - embed.add_field( - name="Humidity", - value=f"{current_humidity}%", - inline=False + inline=False, ) + embed.add_field(name="Humidity", value=f"{current_humidity}%", inline=False) else: embed = Embed( color=self.bot.default_color, diff --git a/bot/extensions/welcome_cog.py b/bot/extensions/welcome_cog.py index 5701513a..01a99665 100755 --- a/bot/extensions/welcome_cog.py +++ b/bot/extensions/welcome_cog.py @@ -1,7 +1,9 @@ -from discord.ext.commands import Cog, hybrid_command from logging import info -from bot.models.channel import Channel + from discord import Embed +from discord.ext.commands import Cog, hybrid_command + +from bot.models.channel import Channel class WelcomeCog(Cog, name="Welcome", description="Welcomes new members"): @@ -18,7 +20,7 @@ def help_section(self): ["posting_guidelines", "help", "resources"], "### Looking for help?\n" "If you need help, read the <#{}> and open a post in <#{}>." - "If you're looking for resources, checkout <#{}> or our [website]()." + "If you're looking for resources, checkout <#{}> or our [website]().", ) @property @@ -30,7 +32,7 @@ def project_section(self): "feel free to come chat with us in <#{}> or visite our [GitHub]().\n" "\n**Our latest projects**:\n" "- [Grace Framework]()\n" - "- [Matrix.py]()\n" + "- [Matrix.py]()\n", ) def get_welcome_message(self, member): @@ -42,11 +44,20 @@ def get_welcome_message(self, member): :return: The welcome message for the given member. :rtype: str """ - return "\n\n".join(filter(None, [ - self.BASE_WELCOME_MESSAGE, - self.help_section, - self.project_section, - ])).strip().format(member_name=member.display_name) + return ( + "\n\n".join( + filter( + None, + [ + self.BASE_WELCOME_MESSAGE, + self.help_section, + self.project_section, + ], + ) + ) + .strip() + .format(member_name=member.display_name) + ) def __build_section(self, channel_names, message): """Builds a section of the welcome message by replacing placeholders with corresponding channel IDs. @@ -66,7 +77,10 @@ def __build_section(self, channel_names, message): :return: The constructed section of the welcome message with channel IDs inserted. :rtype: str """ - channel_ids = [getattr(Channel.get_by(channel_name=n), "channel_id", "") for n in channel_names] + channel_ids = [ + getattr(Channel.find_by(channel_name=n), "channel_id", "") + for n in channel_names + ] return message.format(*channel_ids) if all(channel_ids) else "" @Cog.listener() @@ -90,9 +104,12 @@ async def on_member_update(self, before, after): embed.add_field( name="Welcome to **The Code Society Server**", value=self.get_welcome_message(after), - inline=False + inline=False, + ) + embed.set_footer( + text="https://github.com/Code-Society-Lab/grace", + icon_url="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png", ) - embed.set_footer(text="https://github.com/Code-Society-Lab/grace", icon_url="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png") await welcome_channel.send(f"<@{after.id}>", embed=embed) @@ -105,7 +122,9 @@ async def on_member_join(self, member): """ info(f"{member.display_name} joined the server!") - @hybrid_command(name="welcome", description="Welcomes the person who issues the command") + @hybrid_command( + name="welcome", description="Welcomes the person who issues the command" + ) async def welcome_command(self, ctx): """Send a welcome message to the person who issued the command. @@ -119,7 +138,10 @@ async def welcome_command(self, ctx): title="Welcome to **The Code Society Server**", description=self.get_welcome_message(ctx.author), ) - embed.set_footer(text="https://github.com/Code-Society-Lab/grace", icon_url="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png") + embed.set_footer( + text="https://github.com/Code-Society-Lab/grace", + icon_url="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png", + ) await ctx.send(embed=embed, ephemeral=True) diff --git a/bot/extensions/wikipedia_cog.py b/bot/extensions/wikipedia_cog.py index 9d36e155..2799446f 100644 --- a/bot/extensions/wikipedia_cog.py +++ b/bot/extensions/wikipedia_cog.py @@ -1,15 +1,16 @@ -from typing import List, Any -from discord.ext.commands import Cog, hybrid_command, Context -from discord.ui import View -from discord import ButtonStyle, ui, Embed, Interaction, Button -from urllib.request import urlopen -from urllib.parse import quote_plus from json import loads +from typing import Any, List +from urllib.parse import quote_plus +from urllib.request import urlopen + +from discord import Button, ButtonStyle, Embed, Interaction, ui +from discord.ext.commands import Cog, Context, hybrid_command +from discord.ui import View def search_results(search: str) -> List[Any]: """Return search results from Wikipedia for the given search query. - + :param search: The search query to be used to search Wikipedia. :type search: str @@ -30,7 +31,9 @@ def __init__(self, search: str, result: List[Any]) -> None: self.search: str = search self.result: List[Any] = result - async def wiki_result(self, interaction: Interaction, _: Button, index: int) -> None: + async def wiki_result( + self, interaction: Interaction, _: Button, index: int + ) -> None: """Send the selected search result to the user. :param _: The Button clicked @@ -41,23 +44,24 @@ async def wiki_result(self, interaction: Interaction, _: Button, index: int) -> :type index: int """ if len(self.result[3]) >= index: - await interaction.response.send_message("{mention} requested:\n {request}".format( - mention=interaction.user.mention, - request=self.result[3][index-1] - )) + await interaction.response.send_message( + "{mention} requested:\n {request}".format( + mention=interaction.user.mention, request=self.result[3][index - 1] + ) + ) self.stop() else: await interaction.response.send_message("Invalid choice.", ephemeral=True) - @ui.button(label='1', style=ButtonStyle.primary) + @ui.button(label="1", style=ButtonStyle.primary) async def first_wiki_result(self, interaction: Interaction, button: Button): await self.wiki_result(interaction, button, 1) - @ui.button(label='2', style=ButtonStyle.primary) + @ui.button(label="2", style=ButtonStyle.primary) async def second_wiki_result(self, interaction: Interaction, button: Button): await self.wiki_result(interaction, button, 2) - @ui.button(label='3', style=ButtonStyle.primary) + @ui.button(label="3", style=ButtonStyle.primary) async def third_wiki_result(self, interaction: Interaction, button: Button): await self.wiki_result(interaction, button, 3) @@ -66,7 +70,10 @@ class Wikipedia(Cog, name="Wikipedia", description="Search on Wikipedia."): def __init__(self, bot): self.bot = bot - @hybrid_command(name="wiki", description="Searches and displays the first 3 results from Wikipedia.") + @hybrid_command( + name="wiki", + description="Searches and displays the first 3 results from Wikipedia.", + ) async def wiki(self, ctx: Context, *, search: str) -> None: """Search Wikipedia and display the first 3 search results to the user. @@ -79,7 +86,9 @@ async def wiki(self, ctx: Context, *, search: str) -> None: view: Buttons = Buttons(search, result) if len(result[1]) == 0: - await ctx.interaction.response.send_message("No result found.", ephemeral=True) + await ctx.interaction.response.send_message( + "No result found.", ephemeral=True + ) else: result_view = "" search_count = 1 @@ -88,8 +97,8 @@ async def wiki(self, ctx: Context, *, search: str) -> None: search_count += 1 embed = Embed( - color=0x2376ff, - title=f"Top 3 Wikipedia Search", + color=0x2376FF, + title="Top 3 Wikipedia Search", description=result_view, ) await ctx.send(embed=embed, view=view, ephemeral=True) diff --git a/bot/grace.py b/bot/grace.py index 619ec510..1e9a61dc 100755 --- a/bot/grace.py +++ b/bot/grace.py @@ -1,9 +1,11 @@ -from grace.bot import Bot from logging import info, warning -from discord import Intents, Colour, Activity, ActivityType + +from discord import Activity, ActivityType, Colour, Intents from pretty_help import PrettyHelp + from bot.models.channel import Channel from bot.models.extension import Extension +from grace.bot import Bot class Grace(Bot): @@ -11,7 +13,7 @@ def __init__(self, app): super().__init__( app, intents=Intents.all(), - activity=Activity(type=ActivityType.playing, name="::help") + activity=Activity(type=ActivityType.playing, name="::help"), ) self.help_command = PrettyHelp(color=self.default_color) @@ -21,7 +23,7 @@ def default_color(self): return Colour.from_str(self.config.get("default_color")) def get_channel_by_name(self, name): - channel = Channel.get_by(channel_name=name) + channel = Channel.find_by(channel_name=name) if channel: return self.get_channel(channel.channel_id) @@ -29,12 +31,10 @@ def get_channel_by_name(self, name): async def load_extensions(self): for module in self.app.extension_modules: - extension = Extension.get_by(module_name=module) + extension = Extension.where(module_name=module).first() if not extension: - warning( - f"{module} is not registered. Registering the extension." - ) + warning(f"{module} is not registered. Registering the extension.") extension = Extension.create(module_name=module) if not extension.should_be_loaded(): diff --git a/bot/helpers/error_helper.py b/bot/helpers/error_helper.py index a34b5e54..4b776b10 100644 --- a/bot/helpers/error_helper.py +++ b/bot/helpers/error_helper.py @@ -1,11 +1,11 @@ -from discord import Embed, Color, DiscordException +from discord import Color, DiscordException, Embed async def send_error(ctx, error_description, **kwargs): embed = Embed( title="Oops! An error occurred", color=Color.red(), - description=error_description + description=error_description, ) for key, value in kwargs.items(): @@ -15,11 +15,13 @@ async def send_error(ctx, error_description, **kwargs): async def send_command_error(ctx, error_description, command, argument_example=None): - await send_error(ctx, error_description, example=f"```/{command} {argument_example}```") + await send_error( + ctx, error_description, example=f"```/{command} {argument_example}```" + ) # This might be the right place for this function def get_original_exception(error: DiscordException) -> Exception: - while hasattr(error, 'original'): + while hasattr(error, "original"): error = error.original return error diff --git a/bot/helpers/github_helper.py b/bot/helpers/github_helper.py index ae5e50bd..4df835fb 100644 --- a/bot/helpers/github_helper.py +++ b/bot/helpers/github_helper.py @@ -1,10 +1,12 @@ +from math import ceil from typing import Iterable, List -from discord import Embed, Color + +from discord import Color, Embed from discord.ui import Button from emoji import emojize -from github import Repository, Organization +from github import Organization, Repository + from bot.services.github_service import GithubService -from math import ceil def available_project_names() -> Iterable[str]: @@ -25,7 +27,7 @@ def create_contributors_embeds(repository: Repository) -> List[Embed]: for i in range(page_count): embed: Embed = Embed( - color=Color.from_str("#171515"), # github color + color=Color.from_str("#171515"), # github color title=f"{repository.name.capitalize()}'s Contributors", ) @@ -33,7 +35,7 @@ def create_contributors_embeds(repository: Repository) -> List[Embed]: embed.add_field( name=contributor.login, value=f"{contributor.contributions} Contributions", - inline=True + inline=True, ) embeds.append(embed) @@ -43,7 +45,5 @@ def create_contributors_embeds(repository: Repository) -> List[Embed]: def create_repository_button(repository: Repository) -> Button: return Button( - emoji=emojize(":file_folder:"), - label=f"Repository", - url=repository.html_url + emoji=emojize(":file_folder:"), label="Repository", url=repository.html_url ) diff --git a/bot/helpers/log_helper.py b/bot/helpers/log_helper.py index bc73a737..8857eb96 100644 --- a/bot/helpers/log_helper.py +++ b/bot/helpers/log_helper.py @@ -1,14 +1,15 @@ -from discord import Embed, Color from datetime import datetime +from discord import Color, Embed + def info(title, description): - # Will be deprected in favor of notice + # Will be deprected in favor of notice return LogHelper(title, description, "info") def notice(title, description): - return LogHelper(title, description, "info") + return LogHelper(title, description, "info") def warning(title, description): @@ -32,15 +33,11 @@ def __init__(self, title, description, log_level="info"): title=title, description=description, color=self.COLORS_BY_LOG_LEVEL.get(log_level, self.__DEFAULT_COLOR), - timestamp=datetime.utcnow() + timestamp=datetime.utcnow(), ) def add_field(self, name, value): - self.embed.add_field( - name=name, - value=value, - inline=False - ) + self.embed.add_field(name=name, value=value, inline=False) async def send(self, channel): await channel.send(embed=self.embed) diff --git a/bot/models/bot.py b/bot/models/bot.py index 2f2b3b49..75bb04a8 100644 --- a/bot/models/bot.py +++ b/bot/models/bot.py @@ -1,17 +1,16 @@ -from sqlalchemy import Integer, Column, BigInteger -from grace.model import Model -from bot import app +from grace.model import Field, Model -class BotSettings(app.base, Model): +class BotSettings(Model): """Configurable settings for each server""" - __tablename__ = 'bot_settings' - id = Column(Integer, primary_key=True) - puns_cooldown = Column(BigInteger, default=60) + __tablename__ = "bot_settings" + + id: int | None = Field(default=None, primary_key=True) + puns_cooldown: int = Field(default=60) @classmethod def settings(self): - '''Since grace runs on only one settings record per bot, - this is a semantic shortcut to get the first record.''' + """Since grace runs on only one settings record per bot, + this is a semantic shortcut to get the first record.""" return self.first() diff --git a/bot/models/channel.py b/bot/models/channel.py index e5f2c6c0..635cb908 100644 --- a/bot/models/channel.py +++ b/bot/models/channel.py @@ -1,12 +1,12 @@ -from sqlalchemy import String, Column, UniqueConstraint, BigInteger -from grace.model import Model -from bot import app +from sqlalchemy import UniqueConstraint +from grace.model import Field, Model -class Channel(app.base, Model): - __tablename__ = 'channels' - channel_name = Column(String(255), primary_key=True) - channel_id = Column(BigInteger, primary_key=True) +class Channel(Model): + __tablename__ = "channels" + + channel_name: str = Field(primary_key=True) + channel_id: int = Field(primary_key=True) UniqueConstraint("channel_name", "channel_id", name="uq_id_cn_cid") diff --git a/bot/models/extension.py b/bot/models/extension.py index 1aade78f..8252bd18 100644 --- a/bot/models/extension.py +++ b/bot/models/extension.py @@ -1,20 +1,19 @@ -from sqlalchemy import Integer, Column, String -from grace.model import Model from bot import app from bot.classes.state import State +from grace.model import Field, Model +from lib.fields import EnumField -class Extension(app.base, Model): - """Extension model (With SQLAlchemy ORM)""" +class Extension(Model): __tablename__ = "extensions" - id = Column(Integer, primary_key=True) - module_name = Column(String(255), nullable=False, unique=True) - _state = Column("state", Integer, default=1) + id: int | None = Field(default=None, primary_key=True) + module_name: str = Field(nullable=False, unique=True) + state: State = EnumField(State, default=State.ENABLED) @classmethod def by_state(cls, state): - return cls.where(_state=state.value) + return cls.where(state=state) @property def name(self): @@ -24,14 +23,6 @@ def name(self): def short_module_name(self): return self.module_name.removeprefix("bot.extensions.") - @property - def state(self): - return State(self._state) - - @state.setter - def state(self, new_state): - self._state = new_state.value - @property def module(self): return app.get_extension_module(self.module_name) diff --git a/bot/models/extensions/fun/answer.py b/bot/models/extensions/fun/answer.py index ced60e15..79374a3f 100644 --- a/bot/models/extensions/fun/answer.py +++ b/bot/models/extensions/fun/answer.py @@ -1,11 +1,10 @@ -from sqlalchemy import Column, Integer, String -from grace.model import Model -from bot import app +from typing import Optional +from grace.model import Field, Model -class Answer(app.base, Model): - """Answer model (With SQLAlchemy ORM)""" + +class Answer(Model): __tablename__ = "answers" - id = Column(Integer, primary_key=True) - answer = Column(String(255), nullable=False) + id: Optional[int] = Field(default=None, primary_key=True) + answer: str = Field(max_length=255) diff --git a/bot/models/extensions/language/pun.py b/bot/models/extensions/language/pun.py index f6ebb765..bc9ecb6c 100644 --- a/bot/models/extensions/language/pun.py +++ b/bot/models/extensions/language/pun.py @@ -1,19 +1,20 @@ -from datetime import timedelta -from sqlalchemy import Text, Column, Integer, DateTime -from sqlalchemy.orm import relationship -from grace.model import Model -from bot import app -from bot.models.extensions.language.pun_word import PunWord +from datetime import datetime, timedelta +from typing import List + from bot.models.bot import BotSettings +from bot.models.extensions.language.pun_word import PunWord +from grace.model import Field, Model, Relationship -class Pun(app.base, Model): +class Pun(Model): __tablename__ = "puns" - id = Column(Integer, primary_key=True) - text = Column(Text(), unique=True) - last_invoked = Column(DateTime) - pun_words = relationship("PunWord", lazy="dynamic", cascade="all, delete-orphan") + id: int | None = Field(default=None, primary_key=True) + text: str = Field(sa_column_kwargs={"unique": True}) + last_invoked: datetime | None = Field(default=None) + pun_words: List["PunWord"] = Relationship( + back_populates="pun", sa_relationship_kwargs={"lazy": "selectin"} + ) @property def words(self): @@ -21,7 +22,7 @@ def words(self): yield pun_word.word def has_word(self, word): - return self.pun_words.filter(PunWord.word == word).count() > 0 + return self.pun_words.where(word=word).count() > 0 def add_pun_word(self, pun_word, emoji_code): PunWord(pun_id=self.id, word=pun_word, emoji_code=emoji_code).save() diff --git a/bot/models/extensions/language/pun_word.py b/bot/models/extensions/language/pun_word.py index dc5e6d05..4e5edb37 100644 --- a/bot/models/extensions/language/pun_word.py +++ b/bot/models/extensions/language/pun_word.py @@ -1,16 +1,23 @@ +from typing import TYPE_CHECKING + from emoji import emojize -from sqlalchemy import Integer, String, Column, ForeignKey -from grace.model import Model -from bot import app + +from grace.model import Field, Model, Relationship + +if TYPE_CHECKING: + from bot.models.extensions.language.pun import Pun -class PunWord(app.base, Model): - __tablename__ = 'pun_words' +class PunWord(Model): + __tablename__ = "pun_words" - id = Column(Integer, primary_key=True) - pun_id = Column(ForeignKey("puns.id")) - word = Column(String(255), nullable=False) - emoji_code = Column(String(255)) + id: int | None = Field(default=None, primary_key=True) + pun_id: int = Field(foreign_key="puns.id") + word: str = Field(max_length=255) + emoji_code: str | None = Field(default=None, max_length=255) + pun: "Pun" = Relationship( + back_populates="pun_words", sa_relationship_kwargs={"lazy": "selectin"} + ) def emoji(self): - return emojize(self.emoji_code, language='alias') \ No newline at end of file + return emojize(self.emoji_code, language="alias") diff --git a/bot/models/extensions/language/trigger.py b/bot/models/extensions/language/trigger.py index 1cb6bef9..fdbdacb4 100644 --- a/bot/models/extensions/language/trigger.py +++ b/bot/models/extensions/language/trigger.py @@ -1,19 +1,22 @@ +from typing import List + from emoji import emojize -from sqlalchemy import String, Column, Integer -from sqlalchemy.orm import relationship -from grace.model import Model -from bot import app + from bot.models.extensions.language.trigger_word import TriggerWord +from grace.model import Field, Model, Relationship + +class Trigger(Model): + __tablename__ = "triggers" -class Trigger(app.base, Model): - __tablename__ = 'triggers' + id: int | None = Field(default=None, primary_key=True) + name: str = Field(max_length=255, unique=True) + positive_emoji_code: str = Field(max_length=255) + negative_emoji_code: str = Field(max_length=255) - id = Column(Integer, primary_key=True) - name = Column(String(255), unique=True) - positive_emoji_code = Column(String(255), nullable=False) - negative_emoji_code = Column(String(255), nullable=False) - trigger_words = relationship("TriggerWord") + trigger_words: List[TriggerWord] = Relationship( + back_populates="trigger", sa_relationship_kwargs={"lazy": "selectin"} + ) @property def words(self): @@ -22,11 +25,11 @@ def words(self): @property def positive_emoji(self): - return emojize(self.positive_emoji_code, language='alias') + return emojize(self.positive_emoji_code, language="alias") @property def negative_emoji(self): - return emojize(self.negative_emoji_code, language='alias') + return emojize(self.negative_emoji_code, language="alias") def add_trigger_word(self, trigger_word): TriggerWord(trigger_id=self.id, word=trigger_word).save() diff --git a/bot/models/extensions/language/trigger_word.py b/bot/models/extensions/language/trigger_word.py index 8d5ba336..770d4ba4 100644 --- a/bot/models/extensions/language/trigger_word.py +++ b/bot/models/extensions/language/trigger_word.py @@ -1,10 +1,17 @@ -from sqlalchemy import String, Column, ForeignKey -from grace.model import Model -from bot import app +from typing import TYPE_CHECKING, Optional +from grace.model import Field, Model, Relationship -class TriggerWord(app.base, Model): - __tablename__ = 'trigger_words' +if TYPE_CHECKING: + from .trigger import Trigger - trigger_id = Column(ForeignKey("triggers.id"), primary_key=True) - word = Column(String(255), primary_key=True) + +class TriggerWord(Model): + __tablename__ = "trigger_words" + + trigger_id: int = Field(foreign_key="triggers.id", primary_key=True) + word: str = Field(max_length=255, primary_key=True) + + trigger: Optional["Trigger"] = Relationship( + back_populates="trigger_words", sa_relationship_kwargs={"lazy": "selectin"} + ) diff --git a/bot/models/extensions/thank.py b/bot/models/extensions/thank.py index b549565c..be8053d1 100644 --- a/bot/models/extensions/thank.py +++ b/bot/models/extensions/thank.py @@ -1,42 +1,32 @@ -from typing import Optional, List -from sqlalchemy import desc, Column, Integer, BigInteger -from grace.model import Model -from bot import app +from typing import Optional +from grace.model import Field, Model -class Thank(app.base, Model): + +class Thank(Model): """A class representing a Thank record in the database.""" - __tablename__ = 'thanks' - id = Column(Integer, primary_key=True) - member_id = Column(BigInteger, nullable=False, unique=True) - count = Column(Integer, default=0) + __tablename__ = "thanks" + + id: int | None = Field(default=None, primary_key=True) + member_id: int = Field(unique=True) + count: int = Field(default=0) @property def rank(self) -> Optional[str]: """Returns the rank of the member based on the number of times they have been thanked. - + :return: The rank of the member. :rtype: Optional[str] """ if self.count in range(1, 11): - return 'Intern' + return "Intern" elif self.count in range(11, 21): - return 'Helper' + return "Helper" elif self.count in range(21, 31): - return 'Vetted helper' + return "Vetted helper" elif self.count > 30: - return 'Expert' + return "Expert" else: return None - - @classmethod - def ordered(cls) -> List['Thank']: - """Returns a list of all `Thank` objects in the database, ordered by - the `count` attribute in descending order. - - :return: A list of `Thank` objects. - :rtype: List[Thank] - """ - return cls.query().order_by(desc(cls.count)).all() diff --git a/bot/models/extensions/thread.py b/bot/models/extensions/thread.py index 6321ef33..217b0f96 100644 --- a/bot/models/extensions/thread.py +++ b/bot/models/extensions/thread.py @@ -1,25 +1,17 @@ -from sqlalchemy import Column, Integer, String, Text -from grace.model import Model -from bot import app -from bot.classes.recurrence import Recurrence - +from sqlalchemy import Text -class Thread(app.base, Model): - __tablename__ = 'threads' - - id = Column(Integer, primary_key=True) - title = Column(String, nullable=False,) - content = Column(Text, nullable=False,) - _recurrence = Column("recurrence", Integer, nullable=False, default=0) +from bot.classes.recurrence import Recurrence +from grace.model import Field, Model +from lib.fields import EnumField - @property - def recurrence(self) -> Recurrence: - return Recurrence(self._recurrence) +class Thread(Model): + __tablename__ = "threads" - @recurrence.setter - def recurrence(self, new_recurrence: Recurrence): - self._recurrence = new_recurrence.value + id: int | None = Field(default=None, primary_key=True) + title: str + content: str = Field(sa_type=Text) + recurrence: Recurrence = EnumField(Recurrence, default=Recurrence.NONE) @classmethod - def find_by_recurrence(cls, recurrence: Recurrence) -> 'Recurrence': - return cls.where(_recurrence=recurrence.value) \ No newline at end of file + def find_by_recurrence(cls, recurrence: Recurrence) -> "Recurrence": + return cls.where(recurrence=recurrence.value) diff --git a/bot/services/github_service.py b/bot/services/github_service.py index b20d5b08..c9978207 100644 --- a/bot/services/github_service.py +++ b/bot/services/github_service.py @@ -1,6 +1,8 @@ -from typing import Union, Optional +from typing import Optional, Union + from github import Github, Organization from github.Repository import Repository + from bot import app diff --git a/bot/services/mermaid_service.py b/bot/services/mermaid_service.py index 0be6ad8a..08589be9 100644 --- a/bot/services/mermaid_service.py +++ b/bot/services/mermaid_service.py @@ -1,9 +1,9 @@ import base64 import json import zlib -import requests -from logging import info, critical +from logging import critical, info +import requests MERMAID_API = "https://mermaid.ink" @@ -17,19 +17,16 @@ def _encode_diagram(diagram: str) -> str: :returns: Pako-compressed and Base64-encoded diagram string. :rtype: str """ - graph_json = { - "code": diagram, - "mermaid": {"theme": "default"} - } + graph_json = {"code": diagram, "mermaid": {"theme": "default"}} - byte_data = json.dumps(graph_json).encode('ascii') + byte_data = json.dumps(graph_json).encode("ascii") compressed_data = zlib.compress(byte_data, level=9) - b64_encoded = base64.b64encode(compressed_data).decode('ascii') + b64_encoded = base64.b64encode(compressed_data).decode("ascii") - return b64_encoded.replace('+', '-').replace('/', '_').strip('=') + return b64_encoded.replace("+", "-").replace("/", "_").strip("=") -def _build_url(diagram: str, type: str = 'img') -> str: +def _build_url(diagram: str, type: str = "img") -> str: """Build the Mermaid.ink API URL for a given diagram. :param diagram: Mermaid diagram definition. diff --git a/db/alembic/env.py b/db/alembic/env.py index 1c4b4b8c..17f18629 100644 --- a/db/alembic/env.py +++ b/db/alembic/env.py @@ -1,9 +1,7 @@ from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool - from alembic import context +from sqlalchemy import engine_from_config, pool from bot import app @@ -23,6 +21,7 @@ # target_metadata = None target_metadata = app.base.metadata + # other values from the config, defined by the needs of env.py, # can be acquired: # my_important_option = config.get_main_option("my_important_option") @@ -69,13 +68,11 @@ def run_migrations_online() -> None: config.get_section(config.config_ini_section), prefix="sqlalchemy.", poolclass=pool.NullPool, - url=app.config.database_uri + url=app.config.database_uri, ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/db/alembic/versions/11f3c9cd0977_added_puns_tables.py b/db/alembic/versions/11f3c9cd0977_added_puns_tables.py index ae8194e4..f07ed708 100644 --- a/db/alembic/versions/11f3c9cd0977_added_puns_tables.py +++ b/db/alembic/versions/11f3c9cd0977_added_puns_tables.py @@ -1,16 +1,16 @@ """Added puns tables Revision ID: 11f3c9cd0977 -Revises: +Revises: Create Date: 2022-11-08 19:39:27.524172 """ -from alembic import op -import sqlalchemy as sa +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = '11f3c9cd0977' +revision = "11f3c9cd0977" down_revision = None branch_labels = None depends_on = None @@ -19,28 +19,29 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( - 'puns', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('text', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('text'), - if_not_exists=True + "puns", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("text", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("text"), ) op.create_table( - 'pun_words', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('pun_id', sa.Integer(), nullable=True), - sa.Column('word', sa.String(length=255), nullable=False), - sa.Column('emoji_code', sa.String(length=255), nullable=True), - sa.ForeignKeyConstraint(['pun_id'], ['puns.id'], ), - sa.PrimaryKeyConstraint('id'), - if_not_exists=True + "pun_words", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("pun_id", sa.Integer(), nullable=True), + sa.Column("word", sa.String(length=255), nullable=False), + sa.Column("emoji_code", sa.String(length=255), nullable=True), + sa.ForeignKeyConstraint( + ["pun_id"], + ["puns.id"], + ), + sa.PrimaryKeyConstraint("id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('pun_words') - op.drop_table('puns') + op.drop_table("pun_words") + op.drop_table("puns") # ### end Alembic commands ### diff --git a/db/alembic/versions/381d2407fcf3_create_thanks_tables.py b/db/alembic/versions/381d2407fcf3_create_thanks_tables.py index ae90e917..bb81fb14 100644 --- a/db/alembic/versions/381d2407fcf3_create_thanks_tables.py +++ b/db/alembic/versions/381d2407fcf3_create_thanks_tables.py @@ -5,13 +5,13 @@ Create Date: 2022-12-10 01:52:25.646625 """ -from alembic import op -import sqlalchemy as sa +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = '381d2407fcf3' -down_revision = '11f3c9cd0977' +revision = "381d2407fcf3" +down_revision = "11f3c9cd0977" branch_labels = None depends_on = None @@ -19,18 +19,17 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( - 'thanks', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('member_id', sa.BigInteger(), nullable=False), - sa.Column('count', sa.Integer(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('member_id'), - if_not_exists=True + "thanks", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("member_id", sa.BigInteger(), nullable=False), + sa.Column("count", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("member_id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('thanks') + op.drop_table("thanks") # ### end Alembic commands ### diff --git a/db/alembic/versions/614bb9e370d8_add_settings_for_puns_cooldown.py b/db/alembic/versions/614bb9e370d8_add_settings_for_puns_cooldown.py index bcde8056..b8f8631a 100644 --- a/db/alembic/versions/614bb9e370d8_add_settings_for_puns_cooldown.py +++ b/db/alembic/versions/614bb9e370d8_add_settings_for_puns_cooldown.py @@ -5,14 +5,14 @@ Create Date: 2023-05-29 20:55:26.456843 """ -from alembic import op + import sqlalchemy as sa +from alembic import op from sqlalchemy.engine.reflection import Inspector - # revision identifiers, used by Alembic. -revision = '614bb9e370d8' -down_revision = '381d2407fcf3' +revision = "614bb9e370d8" +down_revision = "381d2407fcf3" branch_labels = None depends_on = None @@ -20,39 +20,32 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( - 'bot_settings', - sa.Column('id', sa.Integer(), nullable=False), + "bot_settings", + sa.Column("id", sa.Integer(), nullable=False), sa.Column( - 'puns_cooldown', - sa.BigInteger(), - nullable=False, - server_default="60" + "puns_cooldown", sa.BigInteger(), nullable=False, server_default="60" ), - sa.PrimaryKeyConstraint('id'), - if_not_exists=True + sa.PrimaryKeyConstraint("id"), ) # check if column exists before adding bind = op.get_bind() inspector = Inspector.from_engine(bind) - columns = [col['name'] for col in inspector.get_columns('puns')] - if 'last_invoked' not in columns: - op.add_column( - 'puns', - sa.Column('last_invoked', sa.DateTime(), nullable=True) - ) + columns = [col["name"] for col in inspector.get_columns("puns")] + if "last_invoked" not in columns: + op.add_column("puns", sa.Column("last_invoked", sa.DateTime(), nullable=True)) result = bind.execute( sa.text("SELECT id FROM bot_settings WHERE id = 1") ).fetchone() - + if not result: op.execute("INSERT INTO bot_settings (id) VALUES (1)") def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('puns', 'last_invoked') - op.drop_table('bot_settings') + op.drop_column("puns", "last_invoked") + op.drop_table("bot_settings") # ### end Alembic commands ### diff --git a/db/alembic/versions/cc8da39749e7_missing_pre_migrations_tables.py b/db/alembic/versions/cc8da39749e7_missing_pre_migrations_tables.py index 5613edf5..6c591f76 100644 --- a/db/alembic/versions/cc8da39749e7_missing_pre_migrations_tables.py +++ b/db/alembic/versions/cc8da39749e7_missing_pre_migrations_tables.py @@ -5,13 +5,13 @@ Create Date: 2025-09-16 00:17:25.001017 """ -from alembic import op -import sqlalchemy as sa +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = 'cc8da39749e7' -down_revision = 'f8ac0bbc34ac' +revision = "cc8da39749e7" +down_revision = "f8ac0bbc34ac" branch_labels = None depends_on = None @@ -23,26 +23,14 @@ def upgrade() -> None: sa.Column("id", sa.Integer(), primary_key=True, nullable=False), sa.Column("module_name", sa.String(255), nullable=False, unique=True), sa.Column("state", sa.Integer(), nullable=True, server_default="1"), - if_not_exists=True, ) # --- channels table --- op.create_table( "channels", - sa.Column( - "channel_name", - sa.String(255), - primary_key=True, - nullable=False - ), - sa.Column( - "channel_id", - sa.BigInteger(), - primary_key=True, - nullable=False - ), + sa.Column("channel_name", sa.String(255), primary_key=True, nullable=False), + sa.Column("channel_id", sa.BigInteger(), primary_key=True, nullable=False), sa.UniqueConstraint("channel_name", "channel_id", name="uq_id_cn_cid"), - if_not_exists=True, ) # ensure there’s always a single settings row @@ -59,7 +47,6 @@ def upgrade() -> None: "answers", sa.Column("id", sa.Integer(), primary_key=True, nullable=False), sa.Column("answer", sa.String(255), nullable=False), - if_not_exists=True, ) # --- triggers table --- @@ -69,7 +56,6 @@ def upgrade() -> None: sa.Column("name", sa.String(255), unique=True), sa.Column("positive_emoji_code", sa.String(255), nullable=False), sa.Column("negative_emoji_code", sa.String(255), nullable=False), - if_not_exists=True, ) # --- trigger_words table --- @@ -80,10 +66,9 @@ def upgrade() -> None: sa.Integer(), sa.ForeignKey("triggers.id"), primary_key=True, - nullable=False + nullable=False, ), sa.Column("word", sa.String(255), primary_key=True, nullable=False), - if_not_exists=True, ) diff --git a/db/alembic/versions/f8ac0bbc34ac_create_threads.py b/db/alembic/versions/f8ac0bbc34ac_create_threads.py index ee3914ba..79fe2c75 100644 --- a/db/alembic/versions/f8ac0bbc34ac_create_threads.py +++ b/db/alembic/versions/f8ac0bbc34ac_create_threads.py @@ -5,28 +5,27 @@ Create Date: 2025-03-10 20:34:24.702582 """ -from alembic import op -import sqlalchemy as sa +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = 'f8ac0bbc34ac' -down_revision = '614bb9e370d8' +revision = "f8ac0bbc34ac" +down_revision = "614bb9e370d8" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( - 'threads', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('title', sa.String(255), nullable=False), - sa.Column('content', sa.Text(), nullable=False), - sa.Column('recurrence', sa.Integer(), nullable=False, default=0), - sa.PrimaryKeyConstraint('id'), - if_not_exists=True + "threads", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("title", sa.String(255), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("recurrence", sa.Integer(), nullable=False, default=0), + sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: - op.drop_table('threads') + op.drop_table("threads") diff --git a/db/seed.py b/db/seed.py index d4d65b3b..3c077bc9 100644 --- a/db/seed.py +++ b/db/seed.py @@ -18,6 +18,7 @@ def seed_database(): import importlib import pkgutil + from db import seeds diff --git a/db/seeds/answer.py b/db/seeds/answer.py index 25ef1751..987ea193 100644 --- a/db/seeds/answer.py +++ b/db/seeds/answer.py @@ -28,7 +28,7 @@ def seed_database(): "My reply is No.", "My sources say No.", "Outlook not so good.", - "Very Doubtful" + "Very Doubtful", ] for answer in initial_answers: diff --git a/db/seeds/channels.py b/db/seeds/channels.py index 619a5ae0..8924f2ac 100644 --- a/db/seeds/channels.py +++ b/db/seeds/channels.py @@ -17,5 +17,6 @@ def seed_database(): } for channel_name in initial_channels: - Channel.create(channel_name=channel_name, channel_id=initial_channels.get(channel_name)) - + Channel.create( + channel_name=channel_name, channel_id=initial_channels.get(channel_name) + ) diff --git a/db/seeds/puns.py b/db/seeds/puns.py index b7d4477f..469bce37 100644 --- a/db/seeds/puns.py +++ b/db/seeds/puns.py @@ -2,28 +2,32 @@ def seed_database(): - pun_specs = [{ - 'text': "What do you call a person who hates hippos because they're so hateful? Hippo-critical.", - 'pun_words': [ - {'word': 'hippo', 'emoji_code': ':hippopotamus:'}, - {'word': 'critical', 'emoji_code': ':thumbs_down:'} - ] - }, { - 'text': "You call a bad discord mod an admin-is-traitor.", - 'pun_words': [ - {'word': 'admin', 'emoji_code': ':hammer:'}, - {'word': 'traitor', 'emoji_code': ':hammer:'} - ] - }, { - 'text': "Games like nerdlegame are a form of mathochism.", - 'pun_words': [ - {'word': 'math', 'emoji_code': ':1234:'}, - {'word': 'masochism', 'emoji_code': ':knife:'} - ] - }] + pun_specs = [ + { + "text": "What do you call a person who hates hippos because they're so hateful? Hippo-critical.", + "pun_words": [ + {"word": "hippo", "emoji_code": ":hippopotamus:"}, + {"word": "critical", "emoji_code": ":thumbs_down:"}, + ], + }, + { + "text": "You call a bad discord mod an admin-is-traitor.", + "pun_words": [ + {"word": "admin", "emoji_code": ":hammer:"}, + {"word": "traitor", "emoji_code": ":hammer:"}, + ], + }, + { + "text": "Games like nerdlegame are a form of mathochism.", + "pun_words": [ + {"word": "math", "emoji_code": ":1234:"}, + {"word": "masochism", "emoji_code": ":knife:"}, + ], + }, + ] for pun_spec in pun_specs: - pun = Pun.create(text=pun_spec['text']) + pun = Pun.create(text=pun_spec["text"]) - for pun_word in pun_spec['pun_words']: - pun.add_pun_word(pun_word['word'], pun_word['emoji_code']) + for pun_word in pun_spec["pun_words"]: + pun.add_pun_word(pun_word["word"], pun_word["emoji_code"]) diff --git a/db/seeds/trigger.py b/db/seeds/trigger.py index 7a83432a..1cf2ebdb 100644 --- a/db/seeds/trigger.py +++ b/db/seeds/trigger.py @@ -2,18 +2,12 @@ def seed_database(): - trigger_words = [ - "linus", - "#linus", - "#torvalds", - "#linustorvalds", - "torvalds" - ] + trigger_words = ["linus", "#linus", "#torvalds", "#linustorvalds", "torvalds"] linus_trigger = Trigger.create( name="Linus", positive_emoji_code=":penguin:", - negative_emoji_code=':pouting_face:', + negative_emoji_code=":pouting_face:", ) for trigger_word in trigger_words: linus_trigger.add_trigger_word(trigger_word) @@ -23,5 +17,3 @@ def seed_database(): positive_emoji_code=":blush:", negative_emoji_code=":cry:", ) - - diff --git a/lib/bidirectional_iterator.py b/lib/bidirectional_iterator.py index 13de2a0a..346a5843 100644 --- a/lib/bidirectional_iterator.py +++ b/lib/bidirectional_iterator.py @@ -1,4 +1,4 @@ -from typing import List, TypeVar, Generic, Iterator, Optional +from typing import Generic, Iterator, List, Optional, TypeVar T = TypeVar("T") @@ -10,6 +10,7 @@ class BidirectionalIterator(Generic[T]): :param collection: An optional collection of items, default to an empty List. :type collection: Optional[List[T]] """ + def __init__(self, collection: Optional[List[T]]): self.__collection: List[T] = collection or [] self.__position: int = 0 diff --git a/lib/config_required.py b/lib/config_required.py index 049d8279..c794a59a 100644 --- a/lib/config_required.py +++ b/lib/config_required.py @@ -1,8 +1,10 @@ -from bot import app from typing import Callable, Optional + from discord.ext import commands from discord.ext.commands import CogMeta, Context, DisabledCommand +from bot import app + class ConfigRequiredError(DisabledCommand): """The base exception type for errors to required config check @@ -10,6 +12,7 @@ class ConfigRequiredError(DisabledCommand): Inherit from `discord.ext.commands.CommandError` and can be handled like other CommandError exception in `on_command_error` """ + pass @@ -21,10 +24,14 @@ class MissingRequiredConfigError(ConfigRequiredError): def __init__(self, section_key: str, value_key: str, message: Optional[str] = None): base_error_message = f"Missing config '{value_key}' in section '{section_key}'" - super().__init__(f"{base_error_message}\n{message}" if message else base_error_message) + super().__init__( + f"{base_error_message}\n{message}" if message else base_error_message + ) -def cog_config_required(section_key: str, value_key: str, message: Optional[str] = None) -> Callable: +def cog_config_required( + section_key: str, value_key: str, message: Optional[str] = None +) -> Callable: """Validates the presences of a given configuration before each invocation of a `discord.ext.commands.Cog` commands :param section_key: @@ -46,10 +53,13 @@ async def _cog_before_invoke(self, _: Context): setattr(cls, "cog_before_invoke", _cog_before_invoke) return cls + return wrapper -def command_config_required(section_key: str, value_key: str, message: Optional[str] = None) -> Callable[[Context], bool]: +def command_config_required( + section_key: str, value_key: str, message: Optional[str] = None +) -> Callable[[Context], bool]: """Validates the presences of a given configuration before running the `discord.ext.commands.Command` @@ -67,4 +77,5 @@ async def predicate(_: Context) -> bool: if not app.config.get(section_key, value_key): raise MissingRequiredConfigError(section_key, value_key, message) return True + return commands.check(predicate) diff --git a/lib/fields.py b/lib/fields.py new file mode 100644 index 00000000..e1a05b93 --- /dev/null +++ b/lib/fields.py @@ -0,0 +1,29 @@ +from typing import Type + +from sqlalchemy import Column +from sqlalchemy.types import Integer, TypeDecorator +from sqlmodel import Field + + +class IntEnumType(TypeDecorator): + impl = Integer + cache_ok = True + + def __init__(self, enumtype): + self.enumtype = enumtype + super().__init__() + + def process_bind_param(self, value, dialect): + if value is None: + return None + if isinstance(value, int): + return value + return value.value + + def process_result_value(self, value, dialect): + return self.enumtype(value) if value is not None else None + + +def EnumField(enum_cls: Type, **kwargs): + sa_column = Column(IntEnumType(enum_cls)) + return Field(sa_column=sa_column, **kwargs) diff --git a/lib/paged_embeds.py b/lib/paged_embeds.py index ab8388ac..0802b3bb 100644 --- a/lib/paged_embeds.py +++ b/lib/paged_embeds.py @@ -1,8 +1,10 @@ -from typing import List, Any, Callable, Optional +from typing import Any, Callable, List, Optional + from discord import Embed, Interaction, Message from discord.ext.commands import Context -from discord.ui import View, Button +from discord.ui import Button, View from emoji.core import emojize + from lib.bidirectional_iterator import BidirectionalIterator @@ -25,8 +27,12 @@ def __init__(self, embeds: List[Embed]): self.__message: Optional[Message] = None self.__embeds: BidirectionalIterator[Embed] = BidirectionalIterator(embeds) self.__arrow_button: List[EmbedButton] = [ - EmbedButton(self.__embeds.previous, emoji=emojize(":left_arrow:"), disabled=True), - EmbedButton(self.__embeds.next, emoji=emojize(":right_arrow:"), disabled=True) + EmbedButton( + self.__embeds.previous, emoji=emojize(":left_arrow:"), disabled=True + ), + EmbedButton( + self.__embeds.next, emoji=emojize(":right_arrow:"), disabled=True + ), ] self.add_item(self.previous_arrow) @@ -60,4 +66,6 @@ async def on_timeout(self): await self.__message.edit(embed=self.__embeds.current, view=self) async def send(self, ctx: Context, ephemeral: bool = True): - self.__message = await ctx.send(embed=self.__embeds.current, view=self, ephemeral=ephemeral) \ No newline at end of file + self.__message = await ctx.send( + embed=self.__embeds.current, view=self, ephemeral=ephemeral + ) diff --git a/lib/timed_view.py b/lib/timed_view.py index fd23fde2..c9c3cdee 100644 --- a/lib/timed_view.py +++ b/lib/timed_view.py @@ -1,6 +1,8 @@ -from asyncio import sleep as async_sleep, create_task, Task +from asyncio import Task, create_task +from asyncio import sleep as async_sleep from datetime import timedelta -from typing import Any, Optional +from typing import Optional + from discord.ui import View @@ -16,7 +18,7 @@ class TimedView(View): def __init__(self, seconds: int = 900): super().__init__(timeout=None) - + self.seconds: int = seconds self.__timer_task: Optional[Task[None]] = None @@ -53,7 +55,9 @@ def remaining_time(self) -> str: def start_timer(self): """Starts the view's timer task""" - self.__timer_task = create_task(self.__impl_timer_task(), name=f"grace-timed-view-timer-{self.id}") + self.__timer_task = create_task( + self.__impl_timer_task(), name=f"grace-timed-view-timer-{self.id}" + ) def cancel_timer(self): """Cancels the view's timer task""" diff --git a/pyproject.toml b/pyproject.toml index eb56957e..846f91c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,48 @@ [build-system] -requires = ["setuptools", "wheel", "Cython"] \ No newline at end of file +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "Grace" +version = "2.0.0" +description = "The Code Society community Bot" +authors = [ + { name = "Code Society Lab" } +] +license = { text = "GNU General Public License v3.0" } +requires-python = ">=3.11" +readme = "README.md" + +dependencies = [ + "grace-framework @ git+https://github.com/Code-Society-Lab/grace-framework.git@main", + "discord-pretty-help==2.0.4", + "emoji>=2.1.0", + "nltk", + "requests", + "pillow", + "geopy", + "pytz", + "tzdata", + "timezonefinder", + "pygithub", + "googletrans==4.0.0-rc1", + "openai==0.26.1", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-asyncio", + "flake8", + "mypy", + "ruff", +] + +[project.urls] +Homepage = "https://github.com/Code-Society-Lab/grace" +Documentation = "https://github.com/Code-Society-Lab/grace/wiki" +"Issue Tracker" = "https://github.com/Code-Society-Lab/grace/issues" +"Discord Server" = "https://discord.gg/code-society-823178343943897088" + +[tool.setuptools.packages.find] +include = ["*"] diff --git a/setup.py b/setup.py deleted file mode 100755 index a6cdbfe3..00000000 --- a/setup.py +++ /dev/null @@ -1,38 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name='Grace', - version='2.0.0', - author='Code Society Lab', - description='The Code Society community Bot', - url="https://github.com/Code-Society-Lab/grace", - project_urls={ - "Documentation": "https://github.com/Code-Society-Lab/grace/wiki", - "Issue tracker": "https://github.com/Code-Society-Lab/grace/issues", - "Discord server": "https://discord.gg/code-society-823178343943897088", - }, - license="GNU General Public License v3.0", - python_requires='>=3.10.0', - - packages=find_packages(), - - include_package_data=True, - install_requires=[ - # For now we always want the latest version on github - 'grace-framework @ git+https://github.com/Code-Society-Lab/grace-framework.git@main', - 'emoji>=2.1.0', - 'nltk', - 'discord-pretty-help==2.0.4', - 'requests', - 'pillow', - 'geopy', - 'pytz', - 'tzdata', - 'timezonefinder', - 'pygithub', - 'googletrans==4.0.0-rc1', - 'openai==0.26.1', - 'apscheduler', - 'pytest-asyncio' - ] -) diff --git a/tests/extensions/test_time_cog.py b/tests/extensions/test_time_cog.py index 4259b803..ea06e3a1 100644 --- a/tests/extensions/test_time_cog.py +++ b/tests/extensions/test_time_cog.py @@ -1,8 +1,9 @@ -import pytest from datetime import datetime from unittest.mock import AsyncMock, MagicMock +import pytest import pytz + from bot.extensions.time_cog import TimeCog diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 1f8b81f7..441cd77e 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -1,7 +1,8 @@ -from bot import app from logging import info -from db.seed import get_seed_modules +from bot import app +from db.seed import get_seed_modules +from grace.database import up_migration app.load("test") @@ -14,6 +15,8 @@ app.create_database() app.create_tables() +up_migration(app, "head") + for seed_module in get_seed_modules(): info(f"Seeding {seed_module.__name__}") seed_module.seed_database() diff --git a/tests/models/test_extension.py b/tests/models/test_extension.py index bb45c917..b892c650 100644 --- a/tests/models/test_extension.py +++ b/tests/models/test_extension.py @@ -4,10 +4,7 @@ def test_create_extension(): """Test creating an extension""" - extension = Extension.create( - module_name="test_extension", - state=State.ENABLED - ) + extension = Extension.create(module_name="test_extension", state=State.ENABLED) assert extension.module_name == "test_extension" assert extension.state == State.ENABLED @@ -15,14 +12,14 @@ def test_create_extension(): def test_get_extension(): """Test getting an extension""" - extension = Extension.get_by(module_name="test_extension") + extension = Extension.find_by(module_name="test_extension") - assert Extension.get(extension.id) == extension + assert Extension.find(extension.id) == extension def test_disable_extension(): """Test disabling an extension""" - extension = Extension.get_by(module_name="test_extension") + extension = Extension.find_by(module_name="test_extension") extension.state = State.DISABLED assert extension.state == State.DISABLED @@ -30,7 +27,7 @@ def test_disable_extension(): def test_enable_extension(): """Test enabling an extension""" - extension = Extension.get_by(module_name="test_extension") + extension = Extension.find_by(module_name="test_extension") extension.state = State.ENABLED assert extension.state == State.ENABLED @@ -45,7 +42,7 @@ def test_get_by_state(): def test_delete_extension(): """Test deleting an extension""" - extension = Extension.get_by(module_name="test_extension") + extension = Extension.find_by(module_name="test_extension") extension.delete() - assert Extension.get(extension.id) is None + assert Extension.find(extension.id) is None