Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support user app installation context #80

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from datetime import datetime, timezone
from os import path, listdir
from discord.ext.commands import AutoShardedBot, Context
from discord import Activity, AllowedMentions, Intents
from discord import Activity, AllowedMentions, Intents, Interaction
from aiohttp import ClientSession, ClientTimeout
from discord.ext.commands.bot import when_mentioned_or

from cogs.utils.runner import Runner

class PistonBot(AutoShardedBot):
def __init__(self, *args, **options):
Expand All @@ -29,6 +29,7 @@ def __init__(self, *args, **options):

async def start(self, *args, **kwargs):
self.session = ClientSession(timeout=ClientTimeout(total=15))
self.runner = Runner(self.config['emkc_key'], self.session)
await super().start(*args, **kwargs)

async def close(self):
Expand All @@ -52,7 +53,6 @@ async def setup_hook(self):
exc = f'{type(e).__name__}: {e}'
print(f'Failed to load extension {extension}\n{exc}')


def user_is_admin(self, user):
return user.id in self.config['admins']

Expand Down
4 changes: 2 additions & 2 deletions src/cogs/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from asyncio import TimeoutError as AsyncTimeoutError
from discord import Embed, DMChannel, errors as discord_errors
from discord.ext import commands
from .utils.errors import PistonError
from .utils.errors import PistonError, NoLanguageFoundError


class ErrorHandler(commands.Cog, name='ErrorHandler'):
Expand Down Expand Up @@ -63,7 +63,7 @@ async def on_command_error(self, ctx, error):
await ctx.send(f'Sorry {usr}, you are not allowed to run this command.')
return

if isinstance(error, commands.BadArgument):
if isinstance(error, commands.BadArgument) or isinstance(error, NoLanguageFoundError):
# It's in an embed to prevent mentions from working
embed = Embed(
title='Error',
Expand Down
8 changes: 8 additions & 0 deletions src/cogs/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@ async def maintenance(self, ctx):
self.client.maintenance_mode = True
await self.client.change_presence(activity=self.client.maintenance_activity)

# ----------------------------------------------
# Command to sync slash commands
# ----------------------------------------------
@commands.command(name='synccmds', hidden=True)
async def sync_commands(self, ctx: commands.Context):
await self.client.tree.sync()
await ctx.send('Commands synced.')


async def setup(client):
await client.add_cog(Management(client))
260 changes: 41 additions & 219 deletions src/cogs/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@

"""
# pylint: disable=E0402
import json
import re, sys
import sys
from dataclasses import dataclass
from discord import Embed, Message, errors as discord_errors
from discord.ext import commands, tasks
from discord.utils import escape_mentions
from aiohttp import ContentTypeError
from .utils.codeswap import add_boilerplate
from .utils.errors import PistonInvalidContentType, PistonInvalidStatus, PistonNoOutput
#pylint: disable=E1101
from discord.ext import commands
# pylint: disable=E1101


@dataclass
Expand Down Expand Up @@ -46,218 +41,45 @@ def get_size(obj, seen=None):
class Run(commands.Cog, name='CodeExecution'):
def __init__(self, client):
self.client = client
self.run_IO_store = dict() # Store the most recent /run message for each user.id
self.languages = dict() # Store the supported languages and aliases
self.versions = dict() # Store version for each language
self.run_regex_code = re.compile(
r'(?s)/(?:edit_last_)?run'
r'(?: +(?P<language>\S*?)\s*|\s*)'
r'(?:-> *(?P<output_syntax>\S*)\s*|\s*)'
r'(?:\n(?P<args>(?:[^\n\r\f\v]*\n)*?)\s*|\s*)'
r'```(?:(?P<syntax>\S+)\n\s*|\s*)(?P<source>.*)```'
r'(?:\n?(?P<stdin>(?:[^\n\r\f\v]\n?)+)+|)'
)
self.run_regex_file = re.compile(
r'/run(?: *(?P<language>\S*)\s*?|\s*?)?'
r'(?: *-> *(?P<output>\S*)\s*?|\s*?)?'
r'(?:\n(?P<args>(?:[^\n\r\f\v]+\n?)*)\s*|\s*)?'
r'(?:\n*(?P<stdin>(?:[^\n\r\f\v]\n*)+)+|)?'
)
self.get_available_languages.start()

@tasks.loop(count=1)
async def get_available_languages(self):
async with self.client.session.get(
'https://emkc.org/api/v2/piston/runtimes'
) as response:
runtimes = await response.json()
for runtime in runtimes:
language = runtime['language']
self.languages[language] = language
self.versions[language] = runtime['version']
for alias in runtime['aliases']:
self.languages[alias] = language
self.versions[alias] = runtime['version']

async def send_to_log(self, ctx, language, source):
logging_data = {
'server': ctx.guild.name if ctx.guild else 'DMChannel',
'server_id': str(ctx.guild.id) if ctx.guild else '0',
'user': f'{ctx.author.name}#{ctx.author.discriminator}',
'user_id': str(ctx.author.id),
'language': language,
'source': source
}
headers = {'Authorization': self.client.config["emkc_key"]}

async with self.client.session.post(
'https://emkc.org/api/internal/piston/log',
headers=headers,
data=json.dumps(logging_data)
) as response:
if response.status != 200:
await self.client.log_error(
commands.CommandError(f'Error sending log. Status: {response.status}'),
ctx
)
return False

return True

async def get_api_parameters_with_codeblock(self, ctx):
if ctx.message.content.count('```') != 2:
raise commands.BadArgument('Invalid command format (missing codeblock?)')

match = self.run_regex_code.search(ctx.message.content)

if not match:
raise commands.BadArgument('Invalid command format')

language, output_syntax, args, syntax, source, stdin = match.groups()

if not language:
language = syntax

if language:
language = language.lower()

if language not in self.languages:
raise commands.BadArgument(
f'Unsupported language: **{str(language)[:1000]}**\n'
'[Request a new language](https://github.com/engineer-man/piston/issues)'
)

return language, output_syntax, source, args, stdin

async def get_api_parameters_with_file(self, ctx):
if len(ctx.message.attachments) != 1:
raise commands.BadArgument('Invalid number of attachments')

file = ctx.message.attachments[0]

MAX_BYTES = 65535
if file.size > MAX_BYTES:
raise commands.BadArgument(f'Source file is too big ({file.size}>{MAX_BYTES})')

filename_split = file.filename.split('.')

if len(filename_split) < 2:
raise commands.BadArgument('Please provide a source file with a file extension')

match = self.run_regex_file.search(ctx.message.content)

if not match:
raise commands.BadArgument('Invalid command format')

language, output_syntax, args, stdin = match.groups()

if not language:
language = filename_split[-1]

if language:
language = language.lower()
self.run_IO_store: dict[int, RunIO] = dict()
# Store the most recent /run message for each user.id

if language not in self.languages:
raise commands.BadArgument(
f'Unsupported file extension: **{language}**\n'
'[Request a new language](https://github.com/engineer-man/piston/issues)'
)

source = await file.read()
try:
source = source.decode('utf-8')
except UnicodeDecodeError as e:
raise commands.BadArgument(str(e))

return language, output_syntax, source, args, stdin

async def get_run_output(self, ctx):
async def get_run_output(self, ctx: commands.Context):
# Get parameters to call api depending on how the command was called (file <> codeblock)
if ctx.message.attachments:
alias, output_syntax, source, args, stdin = await self.get_api_parameters_with_file(ctx)
else:
alias, output_syntax, source, args, stdin = await self.get_api_parameters_with_codeblock(ctx)

# Resolve aliases for language
language = self.languages[alias]

version = self.versions[alias]

# Add boilerplate code to supported languages
source = add_boilerplate(language, source)

# Split args at newlines
if args:
args = [arg for arg in args.strip().split('\n') if arg]

if not source:
raise commands.BadArgument(f'No source code found')

# Call piston API
data = {
'language': alias,
'version': version,
'files': [{'content': source}],
'args': args,
'stdin': stdin or "",
'log': 0
}
headers = {'Authorization': self.client.config["emkc_key"]}
async with self.client.session.post(
'https://emkc.org/api/v2/piston/execute',
headers=headers,
json=data
) as response:
try:
r = await response.json()
except ContentTypeError:
raise PistonInvalidContentType('invalid content type')
if not response.status == 200:
raise PistonInvalidStatus(f'status {response.status}: {r.get("message", "")}')

comp_stderr = r['compile']['stderr'] if 'compile' in r else ''
run = r['run']

if run['output'] is None:
raise PistonNoOutput('no output')

# Logging
await self.send_to_log(ctx, language, source)

language_info=f'{alias}({version})'

# Return early if no output was received
if len(run['output'] + comp_stderr) == 0:
return f'Your {language_info} code ran without output {ctx.author.mention}'

# Limit output to 30 lines maximum
output = '\n'.join((comp_stderr + run['output']).split('\n')[:30])

# Prevent mentions in the code output
output = escape_mentions(output)

# Prevent code block escaping by adding zero width spaces to backticks
output = output.replace("`", "`\u200b")

# Truncate output to be below 2000 char discord limit.
if len(comp_stderr) > 0:
introduction = f'{ctx.author.mention} I received {language_info} compile errors\n'
elif len(run['stdout']) == 0 and len(run['stderr']) > 0:
introduction = f'{ctx.author.mention} I only received {language_info} error output\n'
else:
introduction = f'Here is your {language_info} output {ctx.author.mention}\n'
truncate_indicator = '[...]'
len_codeblock = 7 # 3 Backticks + newline + 3 Backticks
available_chars = 2000-len(introduction)-len_codeblock
if len(output) > available_chars:
output = output[:available_chars-len(truncate_indicator)] + truncate_indicator
source, language, output_syntax, args, stdin = self.client.runner.get_api_params_with_file(
input_language="",
output_syntax="",
args="",
stdin="",
content=ctx.message.content,
file=ctx.message.attachments[0],
)
return await self.client.runner.get_run_output(
ctx.guild,
ctx.author,
source=source,
language=language,
output_syntax=output_syntax,
args=args,
stdin=stdin,
mention_author=True,
)

# Use an empty string if no output language is selected
return (
introduction
+ f'```{output_syntax or ""}\n'
+ output.replace('\0', '')
+ '```'
source, language, output_syntax, args, stdin = self.client.runner.get_api_params_with_codeblock(
content=ctx.message.content,
mention_author=True,
needs_strict_re=True,
)
return await self.client.runner.get_run_output(
ctx.guild,
ctx.author,
source=source,
language=language,
output_syntax=output_syntax,
args=args,
stdin=stdin,
mention_author=True,
)

async def delete_last_output(self, user_id):
Expand Down Expand Up @@ -300,7 +122,7 @@ async def run(self, ctx, *, source=None):
await self.send_howto(ctx)
return
try:
run_output = await self.get_run_output(ctx)
run_output, _ = await self.get_run_output(ctx)
msg = await ctx.send(run_output)
except commands.BadArgument as error:
embed = Embed(
Expand All @@ -320,7 +142,7 @@ async def edit_last_run(self, ctx, *, content=None):
return
try:
msg_to_edit = self.run_IO_store[ctx.author.id].output
run_output = await self.get_run_output(ctx)
run_output, _ = await self.get_run_output(ctx)
await msg_to_edit.edit(content=run_output, embed=None)
except KeyError:
# Message no longer exists in output store
Expand Down Expand Up @@ -388,7 +210,7 @@ async def on_message_delete(self, message):
await self.delete_last_output(message.author.id)

async def send_howto(self, ctx):
languages = sorted(set(self.languages.values()))
languages = self.client.runner.get_languages()

run_instructions = (
'**Update: Discord changed their client to prevent sending messages**\n'
Expand Down
Loading