Skip to content

Commit

Permalink
Adds support for player changes over the API
Browse files Browse the repository at this point in the history
* Refactors rate limiter and creates middleware for rate limiting
* Adds socket rooms for voice channel emissions
* Refactors player utils to support non-interaction invocations
* Adds environment variables PLAYER_USER_LIMIT and PLAYER_GUILD_LIMIT
* Adds several API endpoints to control the player
  • Loading branch information
mikeyaworski committed Oct 7, 2023
1 parent 59c1c3b commit 1a4e0b0
Show file tree
Hide file tree
Showing 22 changed files with 620 additions and 104 deletions.
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ CHATGPT_GUILD_LIMIT=10,60
# in seconds
CHATGPT_CONVERSATION_TIME_LIMIT=600

# Player
# 1 request every 2 seconds
PLAYER_USER_LIMIT=1,2
# 1 request per second
PLAYER_GUILD_LIMIT=1,1

# Slash Commands in development. Remove this during deployment.
SLASH_COMMANDS_GUILD_ID=...

Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ CHATGPT_GUILD_LIMIT=10,60
# in seconds
CHATGPT_CONVERSATION_TIME_LIMIT=10
# Player
# 1 request every 2 seconds
PLAYER_USER_LIMIT=1,2
# 1 request per second
PLAYER_GUILD_LIMIT=1,1
# Create your own webhook secret if you intend to use the webhook API routes and want them protected
WEBHOOK_SECRET=...
Expand Down
2 changes: 2 additions & 0 deletions deploy/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ services:
OPENAI_SECRET_KEY: ${OPENAI_SECRET_KEY}
CHATGPT_USER_LIMIT: ${CHATGPT_USER_LIMIT}
CHATGPT_GUILD_LIMIT: ${CHATGPT_GUILD_LIMIT}
PLAYER_USER_LIMIT: ${PLAYER_USER_LIMIT}
PLAYER_GUILD_LIMIT: ${PLAYER_GUILD_LIMIT}
CHATGPT_CONVERSATION_TIME_LIMIT: ${CHATGPT_CONVERSATION_TIME_LIMIT}
SLASH_COMMANDS_GUILD_ID: ${SLASH_COMMANDS_GUILD_ID}
WEBHOOK_SECRET: ${WEBHOOK_SECRET}
Expand Down
2 changes: 2 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import remindersRouter from 'src/api/routes/reminders';
import guildsRouter from 'src/api/routes/guilds';
import dmsRouter from 'src/api/routes/dms';
import chessRouter from 'src/api/routes/chess';
import playerRouter from 'src/api/routes/player';
import webhooksRouter from 'src/api/routes/webhooks';

import { WAKE_INTERVAL } from 'src/constants';
Expand Down Expand Up @@ -61,6 +62,7 @@ export function initApi(): void {
app.use('/guilds', guildsRouter);
app.use('/dms', dmsRouter);
app.use('/chess', chessRouter);
app.use('/player', playerRouter);
app.use('/webhooks', webhooksRouter);
const port = process.env.PORT ? Number(process.env.PORT) : 3000;
httpServer.listen(port, () => {
Expand Down
12 changes: 12 additions & 0 deletions src/api/middlewares/rate-limiter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { Response, NextFunction } from 'express';
import { AuthRequest } from 'src/api/middlewares/auth';
import { RateLimiter } from 'src/types';

export const getRateLimiterMiddleware = (rateLimiter: RateLimiter) => async (req: AuthRequest, res: Response, next: NextFunction): Promise<void> => {
try {
await rateLimiter.attempt({ userId: req.user.id, guildId: req.params.guild_id });
next();
} catch (err) {
res.status(401).send((err as Error).message);
}
};
253 changes: 253 additions & 0 deletions src/api/routes/player.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import express, { Response, NextFunction } from 'express';
import authMiddleware, { AuthRequest } from 'src/api/middlewares/auth';
import Session from 'src/commands/player/session';
import sessions from 'src/commands/player/sessions';
import { checkVoiceErrors, getErrorMsg, getRateLimiterFromEnv } from 'src/discord-utils';
import { getRateLimiterMiddleware } from 'src/api/middlewares/rate-limiter';
import { PlayerFavorites } from 'src/models/player-favorites';
import { play } from 'src/commands/player/play';
import { guildMiddleware } from '../middlewares/guild';

const rateLimiter = getRateLimiterFromEnv('PLAYER_USER_LIMIT', 'PLAYER_GUILD_LIMIT');
const rateLimiterMiddleware = getRateLimiterMiddleware(rateLimiter);

const router = express.Router();

type SessionRequest = AuthRequest & {
playerSession: Session,
}

async function sessionPermissionMiddleware(req: AuthRequest, res: Response, next: NextFunction) {
const session = sessions.get(req.params.guildId);
if (session) {
try {
await checkVoiceErrors({
userId: req.user.id,
guildId: req.params.guildId,
});
} catch (err) {
return res.status(401).send((err as Error).message);
}
}
return next();
}

async function sessionMiddleware(req: AuthRequest, res: Response, next: NextFunction) {
const session = sessions.get(req.params.guildId);
if (!session) {
return res.status(404).send('Player session not active on this guild');
}
// @ts-expect-error
req.playerSession = session;
return next();
}

// @ts-expect-error
router.get('/:guildId', authMiddleware, sessionPermissionMiddleware, sessionMiddleware, async (req: SessionRequest, res) => {
const data = await req.playerSession.getPlayerStatus();
res.status(200).json(data);
});

router.post(
'/:guildId/resume',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
if (req.playerSession.resume()) {
res.status(204).end();
} else {
res.status(400).end();
}
},
);

router.post(
'/:guildId/pause',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
if (req.playerSession.pause()) {
res.status(204).end();
} else {
res.status(400).end();
}
},
);

router.post(
'/:guildId/skip',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
await req.playerSession.skip();
res.status(204).end();
},
);

router.post(
'/:guildId/queue/shuffle',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
req.playerSession.shuffle();
res.status(204).end();
},
);

router.post(
'/:guildId/queue/unshuffle',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
req.playerSession.unshuffle();
res.status(204).end();
},
);

router.post(
'/:guildId/queue/loop',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
req.playerSession.loop();
res.status(204).end();
},
);

router.post(
'/:guildId/queue/unloop',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
req.playerSession.unloop();
res.status(204).end();
},
);

router.post(
'/:guildId/queue/move',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
if (req.body.from == null || req.body.to == null) {
res.status(400).send('to and from indices are required.');
} else {
req.playerSession.move(req.body.from, req.body.to);
res.status(204).end();
}
},
);

router.post(
'/:guildId/queue/clear',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
req.playerSession.clear();
res.status(204).end();
},
);

router.post(
'/:guildId/queue/:trackId/remove',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
req.playerSession.remove(req.params.trackId);
res.status(204).end();
},
);

router.post(
'/:guildId/queue/:trackId/play_immediately',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
sessionMiddleware,
async (req: SessionRequest, res) => {
const idx = req.playerSession.queue.findIndex(track => track.id === req.params.trackId);
if (idx < 0) {
return res.status(404).end();
}
req.playerSession.move(idx, 0);
await req.playerSession.skip();
return res.status(204).end();
},
);

router.get(
'/:guildId/favorites',
authMiddleware,
// @ts-expect-error
guildMiddleware,
async (req: AuthRequest, res) => {
const favorites = await PlayerFavorites.findAll({
where: {
guild_id: req.params.guildId,
},
});
res.status(200).json(favorites);
},
);

router.post(
'/:guildId/play',
authMiddleware,
// @ts-expect-error
rateLimiterMiddleware,
sessionPermissionMiddleware,
async (req: SessionRequest, res) => {
try {
await play({
invoker: {
userId: req.user.id,
guildId: req.params.guildId,
},
inputs: {
vodLink: req.body.vodLink,
favoriteId: req.body.favoriteId,
streamLink: req.body.streamLink,
queryStr: req.body.queryStr,
pushToFront: req.body.pushToFront,
shuffle: req.body.shuffle,
},
});
res.status(204).end();
} catch (err) {
res.status(400).send(getErrorMsg(err));
}
},
);

export default router;
9 changes: 6 additions & 3 deletions src/api/sockets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { SocketEvent } from 'src/types/sockets';
import { error } from 'src/logging';
import { getUserFromAuthToken } from 'src/api/middlewares/auth';
import { client } from 'src/client';
import { isGuildChannel, userCanManageChannel, userCanViewChannel } from 'src/discord-utils';
import { isGuildChannel, checkUserCanManageChannel, checkUserCanViewChannel, checkUserCanConnectToChannel } from 'src/discord-utils';
import { socketIoServer } from 'src/api';
import { Reminder } from 'src/models/reminders';

Expand Down Expand Up @@ -39,12 +39,15 @@ socketIoServer.on('connection', async socket => {
}).catch(error);
client.channels.cache.forEach(async channel => {
if (!isGuildChannel(channel)) return;
userCanViewChannel({ userId: user.id, channelId: channel.id }).then(canViewChannel => {
checkUserCanViewChannel({ userId: user.id, channelId: channel.id }).then(canViewChannel => {
if (canViewChannel) socket.join(`${channel.guildId}_${channel.id}_VIEW`);
});
userCanManageChannel({ userId: user.id, channelId: channel.id }).then(canManageChannel => {
checkUserCanManageChannel({ userId: user.id, channelId: channel.id }).then(canManageChannel => {
if (canManageChannel) socket.join(`${channel.guildId}_${channel.id}_MANAGE`);
});
checkUserCanConnectToChannel({ userId: user.id, channelId: channel.id }).then(canConnectToChannel => {
if (canConnectToChannel) socket.join(`${channel.guildId}_${channel.id}_CONNECT`);
});
});
} catch (err) {
error('Error during socket connection', err);
Expand Down
5 changes: 4 additions & 1 deletion src/commands/player/leave.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ const LeaveCommand: Command = {
const guild = interaction.guild!;

const botMember = await guild.members.fetch(client.user!.id);
const isInSameChannelAsBot = await getIsInSameChannelAsBot(interaction);
const isInSameChannelAsBot = await getIsInSameChannelAsBot({
userId: interaction.user.id,
guildId: guild.id,
});

if (!botMember.voice.channel) {
await interaction.editReply('Bot is not connected to a voice channel.');
Expand Down
4 changes: 2 additions & 2 deletions src/commands/player/pause.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Command, ContextMenuTypes } from 'src/types';

import { SlashCommandBuilder } from '@discordjs/builders';
import { checkVoiceErrors } from 'src/discord-utils';
import { checkVoiceErrorsByInteraction } from 'src/discord-utils';
import { attachPlayerButtons } from './utils';
import sessions from './sessions';

Expand Down Expand Up @@ -45,7 +45,7 @@ const NowPlayingCommand: Command = {
});
return;
}
await checkVoiceErrors(interaction);
await checkVoiceErrorsByInteraction(interaction);

const success = session.pause();
await interaction.editReply({
Expand Down
Loading

0 comments on commit 1a4e0b0

Please sign in to comment.