# database.py
"""
Async SQLite (aiosqlite) storage for reminders, user stats, polls and guild settings.
"""
import aiosqlite
from datetime import datetime
from typing import Optional, List
import logging
# Module-level logger; Database uses it to report when the connection is opened and closed.
logger = logging.getLogger(__name__)
class Database:
    """Async SQLite database wrapper.

    Holds a single ``aiosqlite`` connection opened by :meth:`initialize`.
    Until that has been awaited, ``self.connection`` is ``None`` and every
    query method will fail.  The connection is opened without
    ``detect_types``, so TIMESTAMP/BOOLEAN columns come back exactly as
    stored (text / integers).
    """

    def __init__(self, db_path: str):
        """Remember the database location; no connection is opened yet.

        Args:
            db_path: Target passed to ``aiosqlite.connect`` by :meth:`initialize`.
        """
        self.db_path = db_path
        self.connection: Optional[aiosqlite.Connection] = None

    @staticmethod
    def _to_db_timestamp(value):
        """Convert a ``datetime`` to the text SQLite stores for it.

        Produces exactly what the stdlib ``sqlite3`` default datetime adapter
        produces (``value.isoformat(" ")``) without relying on that adapter,
        which is deprecated since Python 3.12.  Any other value (including
        ``None``) is returned unchanged.
        """
        if isinstance(value, datetime):
            return value.isoformat(" ")
        return value

    async def initialize(self):
        """Open the connection and create any missing tables."""
        self.connection = await aiosqlite.connect(self.db_path)
        await self._create_tables()
        logger.info("Database initialized: %s", self.db_path)

    async def _create_tables(self):
        """Create all tables if they do not exist yet, then commit."""
        await self.connection.executescript("""
            -- Reminders table
            CREATE TABLE IF NOT EXISTS reminders (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                channel_id INTEGER NOT NULL,
                guild_id INTEGER,
                message TEXT NOT NULL,
                remind_at TIMESTAMP NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed BOOLEAN DEFAULT FALSE
            );
            -- User stats table
            CREATE TABLE IF NOT EXISTS user_stats (
                user_id INTEGER NOT NULL,
                guild_id INTEGER NOT NULL,
                messages INTEGER DEFAULT 0,
                commands_used INTEGER DEFAULT 0,
                last_active TIMESTAMP,
                PRIMARY KEY (user_id, guild_id)
            );
            -- Polls table
            CREATE TABLE IF NOT EXISTS polls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                message_id INTEGER UNIQUE NOT NULL,
                channel_id INTEGER NOT NULL,
                guild_id INTEGER NOT NULL,
                author_id INTEGER NOT NULL,
                question TEXT NOT NULL,
                options TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                ends_at TIMESTAMP,
                active BOOLEAN DEFAULT TRUE
            );
            -- Guild settings table
            CREATE TABLE IF NOT EXISTS guild_settings (
                guild_id INTEGER PRIMARY KEY,
                prefix TEXT DEFAULT '!',
                welcome_channel_id INTEGER,
                mod_role_id INTEGER,
                mute_role_id INTEGER
            );
        """)
        await self.connection.commit()

    async def close(self):
        """Close the connection if one is open; calling it again is a no-op."""
        if self.connection is not None:
            await self.connection.close()
            # Drop the reference so the closed state is visible to later checks.
            self.connection = None
            logger.info("Database connection closed")

    # =========================================================================
    # Reminder Methods
    # =========================================================================

    async def add_reminder(self, user_id: int, channel_id: int, guild_id: Optional[int],
                           message: str, remind_at: datetime) -> int:
        """Insert an uncompleted reminder and return its new row id.

        Args:
            user_id: Owner of the reminder.
            channel_id: Channel id stored with the reminder.
            guild_id: Guild id; the column is nullable, so ``None`` is accepted.
            message: Reminder text.
            remind_at: When the reminder becomes due, stored as
                ``remind_at.isoformat(" ")``.  NOTE(review): it is later
                compared with SQLite's UTC ``datetime('now')``, so this
                presumably must be a naive UTC datetime; confirm with callers.
        """
        cursor = await self.connection.execute(
            """INSERT INTO reminders (user_id, channel_id, guild_id, message, remind_at)
               VALUES (?, ?, ?, ?, ?)""",
            (user_id, channel_id, guild_id, message, self._to_db_timestamp(remind_at))
        )
        await self.connection.commit()
        return cursor.lastrowid

    async def get_due_reminders(self) -> List[dict]:
        """Return every uncompleted reminder whose ``remind_at`` has passed.

        ``remind_at`` text is compared as text with SQLite's ``datetime('now')``,
        which is UTC in ``YYYY-MM-DD HH:MM:SS`` form.
        NOTE(review): a naive *local* ``remind_at`` would fire early or late
        by the UTC offset; confirm callers store UTC.

        Returns:
            Dicts with keys ``id``, ``user_id``, ``channel_id``, ``guild_id``
            and ``message``.
        """
        cursor = await self.connection.execute(
            """SELECT id, user_id, channel_id, guild_id, message
               FROM reminders
               WHERE remind_at <= datetime('now') AND completed = FALSE"""
        )
        rows = await cursor.fetchall()
        return [
            {"id": r[0], "user_id": r[1], "channel_id": r[2],
             "guild_id": r[3], "message": r[4]}
            for r in rows
        ]

    async def get_user_reminders(self, user_id: int) -> List[dict]:
        """Return a user's uncompleted reminders ordered by ``remind_at``.

        Returns:
            Dicts with keys ``id``, ``message`` and ``remind_at`` (stored text).
        """
        cursor = await self.connection.execute(
            """SELECT id, message, remind_at FROM reminders
               WHERE user_id = ? AND completed = FALSE
               ORDER BY remind_at""",
            (user_id,)
        )
        rows = await cursor.fetchall()
        return [{"id": r[0], "message": r[1], "remind_at": r[2]} for r in rows]

    async def complete_reminder(self, reminder_id: int):
        """Mark a reminder as completed so it is no longer returned as due."""
        await self.connection.execute(
            "UPDATE reminders SET completed = TRUE WHERE id = ?",
            (reminder_id,)
        )
        await self.connection.commit()

    async def delete_reminder(self, reminder_id: int, user_id: int) -> bool:
        """Delete a reminder only if ``user_id`` owns it.

        Returns:
            True if a row was deleted, False if no matching row existed.
        """
        cursor = await self.connection.execute(
            "DELETE FROM reminders WHERE id = ? AND user_id = ?",
            (reminder_id, user_id)
        )
        await self.connection.commit()
        return cursor.rowcount > 0

    # =========================================================================
    # Stats Methods
    # =========================================================================

    async def increment_messages(self, user_id: int, guild_id: int):
        """Add one to the user's message count, creating the row if needed.

        Also sets ``last_active`` to SQLite's ``datetime('now')``.
        """
        await self.connection.execute(
            """INSERT INTO user_stats (user_id, guild_id, messages, last_active)
               VALUES (?, ?, 1, datetime('now'))
               ON CONFLICT(user_id, guild_id) DO UPDATE SET
               messages = messages + 1, last_active = datetime('now')""",
            (user_id, guild_id)
        )
        await self.connection.commit()

    async def increment_commands(self, user_id: int, guild_id: int):
        """Add one to the user's command count, creating the row if needed.

        Also sets ``last_active`` to SQLite's ``datetime('now')``.
        """
        await self.connection.execute(
            """INSERT INTO user_stats (user_id, guild_id, commands_used, last_active)
               VALUES (?, ?, 1, datetime('now'))
               ON CONFLICT(user_id, guild_id) DO UPDATE SET
               commands_used = commands_used + 1, last_active = datetime('now')""",
            (user_id, guild_id)
        )
        await self.connection.commit()

    async def get_user_stats(self, user_id: int, guild_id: int) -> Optional[dict]:
        """Return one user's stats in one guild, or None if there is no row.

        Returns:
            Dict with keys ``messages``, ``commands_used`` and ``last_active``.
        """
        cursor = await self.connection.execute(
            """SELECT messages, commands_used, last_active FROM user_stats
               WHERE user_id = ? AND guild_id = ?""",
            (user_id, guild_id)
        )
        row = await cursor.fetchone()
        if row:
            return {"messages": row[0], "commands_used": row[1], "last_active": row[2]}
        return None

    async def get_leaderboard(self, guild_id: int, limit: int = 10) -> List[dict]:
        """Return up to ``limit`` users of a guild, most messages first.

        Returns:
            Dicts with keys ``user_id``, ``messages`` and ``commands_used``.
        """
        cursor = await self.connection.execute(
            """SELECT user_id, messages, commands_used FROM user_stats
               WHERE guild_id = ?
               ORDER BY messages DESC LIMIT ?""",
            (guild_id, limit)
        )
        rows = await cursor.fetchall()
        return [
            {"user_id": r[0], "messages": r[1], "commands_used": r[2]}
            for r in rows
        ]

    # =========================================================================
    # Poll Methods
    # =========================================================================

    async def create_poll(self, message_id: int, channel_id: int, guild_id: int,
                          author_id: int, question: str, options: str,
                          ends_at: Optional[datetime] = None) -> int:
        """Insert an active poll and return its new row id.

        Args:
            message_id: Unique message id the poll is keyed by.
            options: Stored verbatim as text (presumably pre-serialized by
                the caller; confirm format).
            ends_at: Optional end time, stored as ``isoformat(" ")`` text.
        """
        cursor = await self.connection.execute(
            """INSERT INTO polls (message_id, channel_id, guild_id, author_id,
                                  question, options, ends_at)
               VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (message_id, channel_id, guild_id, author_id, question, options,
             self._to_db_timestamp(ends_at))
        )
        await self.connection.commit()
        return cursor.lastrowid

    async def get_poll(self, message_id: int) -> Optional[dict]:
        """Return the poll attached to ``message_id``, or None if absent.

        Returns:
            Dict with keys ``id``, ``question``, ``options``, ``author_id`` and
            ``active`` (stored integer, 1 or 0).
        """
        cursor = await self.connection.execute(
            """SELECT id, question, options, author_id, active FROM polls
               WHERE message_id = ?""",
            (message_id,)
        )
        row = await cursor.fetchone()
        if row:
            return {
                "id": row[0], "question": row[1], "options": row[2],
                "author_id": row[3], "active": row[4]
            }
        return None

    async def close_poll(self, message_id: int):
        """Deactivate the poll attached to ``message_id``."""
        await self.connection.execute(
            "UPDATE polls SET active = FALSE WHERE message_id = ?",
            (message_id,)
        )
        await self.connection.commit()

    async def get_active_polls(self) -> List[tuple]:
        """Return raw rows for all active polls.

        Each tuple is ``(id, channel_id, message_id, question, options, ends_at)``;
        ``ends_at`` is the stored text or None.
        """
        cursor = await self.connection.execute(
            """SELECT id, channel_id, message_id, question, options, ends_at
               FROM polls WHERE active = TRUE"""
        )
        return await cursor.fetchall()

    async def end_poll(self, poll_id: int):
        """Deactivate a poll by its row id."""
        await self.connection.execute(
            "UPDATE polls SET active = FALSE WHERE id = ?",
            (poll_id,)
        )
        await self.connection.commit()

    async def execute(self, query: str, params: tuple = ()) -> List[tuple]:
        """Run a raw parameterized query and return all fetched rows.

        No commit is issued here; a write made through this method stays in
        the open transaction until something commits the connection.
        """
        cursor = await self.connection.execute(query, params)
        return await cursor.fetchall()

    # =========================================================================
    # Extended Stats Methods
    # =========================================================================

    async def update_user_stats(self, user_id: int, guild_id: int, stat_type: str):
        """Increment one stat; ``stat_type`` is ``"messages"`` or ``"commands"``.

        Any other ``stat_type`` is silently ignored.
        """
        if stat_type == "messages":
            await self.increment_messages(user_id, guild_id)
        elif stat_type == "commands":
            await self.increment_commands(user_id, guild_id)

    async def get_server_total_stat(self, guild_id: int, stat_type: str) -> int:
        """Return the guild-wide sum of one stat column.

        Args:
            guild_id: Guild to aggregate over.
            stat_type: ``"messages"`` sums ``messages``; any other value sums
                ``commands_used``.

        Returns:
            The total, or 0 when the guild has no rows.
        """
        # The column comes from a fixed two-value choice, never from the
        # caller's string, so interpolating it into the SQL is safe.
        column = "messages" if stat_type == "messages" else "commands_used"
        cursor = await self.connection.execute(
            f"SELECT SUM({column}) FROM user_stats WHERE guild_id = ?",
            (guild_id,)
        )
        row = await cursor.fetchone()
        # SUM() over zero rows is NULL.  The parentheses make a missing row
        # return 0 instead of evaluating row[0] on None.
        return (row[0] or 0) if row else 0

    async def get_user_global_stats(self, user_id: int) -> Optional[dict]:
        """Return a user's stats summed across every guild.

        Returns:
            Dict with keys ``total_messages``, ``total_commands`` and
            ``server_count``, or None when the user has no stats rows at all.
        """
        cursor = await self.connection.execute(
            """SELECT SUM(messages), SUM(commands_used), COUNT(DISTINCT guild_id)
               FROM user_stats WHERE user_id = ?""",
            (user_id,)
        )
        row = await cursor.fetchone()
        # Decide existence by the row count, not SUM(messages): a user with
        # only command activity has 0 messages but still has stats.
        if row and row[2]:
            return {
                "total_messages": row[0] or 0,
                "total_commands": row[1] or 0,
                "server_count": row[2],
            }
        return None