PYTHONPython

database

real world projects / discord bot / bot

PYTHON
database.py🐍
"""
SQLite database for persistent storage.
"""

import aiosqlite
from datetime import datetime
from typing import Optional, List
import logging

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class Database:
    """Async SQLite database wrapper.

    Holds a single aiosqlite connection opened by ``initialize()`` and exposes
    small helpers for reminders, per-guild user stats and polls. Every write
    helper commits immediately after its statement.
    """

    def __init__(self, db_path: str):
        """Remember the database path; no connection is opened until initialize().

        Args:
            db_path: Path handed to ``aiosqlite.connect``.
        """
        self.db_path = db_path
        self.connection: Optional[aiosqlite.Connection] = None

    @staticmethod
    def _to_db_timestamp(value: Optional[datetime]):
        """Return the text SQLite should store for a timestamp parameter.

        Produces exactly what sqlite3's default datetime adapter produced
        (``value.isoformat(" ")``), but explicitly: the default adapters are
        deprecated as of Python 3.12. ``None`` and non-datetime values are
        returned unchanged so sqlite3 binds them as before.
        """
        if isinstance(value, datetime):
            return value.isoformat(" ")
        return value

    async def initialize(self):
        """Open the connection and create any tables that do not exist yet."""
        self.connection = await aiosqlite.connect(self.db_path)
        await self._create_tables()
        logger.info("Database initialized: %s", self.db_path)

    async def _create_tables(self):
        """Create the reminders, user_stats, polls and guild_settings tables."""
        await self.connection.executescript("""
            -- Reminders table
            CREATE TABLE IF NOT EXISTS reminders (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                channel_id INTEGER NOT NULL,
                guild_id INTEGER,
                message TEXT NOT NULL,
                remind_at TIMESTAMP NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed BOOLEAN DEFAULT FALSE
            );
            
            -- User stats table
            CREATE TABLE IF NOT EXISTS user_stats (
                user_id INTEGER NOT NULL,
                guild_id INTEGER NOT NULL,
                messages INTEGER DEFAULT 0,
                commands_used INTEGER DEFAULT 0,
                last_active TIMESTAMP,
                PRIMARY KEY (user_id, guild_id)
            );
            
            -- Polls table
            CREATE TABLE IF NOT EXISTS polls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                message_id INTEGER UNIQUE NOT NULL,
                channel_id INTEGER NOT NULL,
                guild_id INTEGER NOT NULL,
                author_id INTEGER NOT NULL,
                question TEXT NOT NULL,
                options TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                ends_at TIMESTAMP,
                active BOOLEAN DEFAULT TRUE
            );
            
            -- Guild settings table
            CREATE TABLE IF NOT EXISTS guild_settings (
                guild_id INTEGER PRIMARY KEY,
                prefix TEXT DEFAULT '!',
                welcome_channel_id INTEGER,
                mod_role_id INTEGER,
                mute_role_id INTEGER
            );
        """)
        await self.connection.commit()

    async def close(self):
        """Close the connection if one is open; calling it again is a no-op."""
        if self.connection is not None:
            await self.connection.close()
            # Drop the closed handle so a second close() does nothing.
            self.connection = None
            logger.info("Database connection closed")

    # =========================================================================
    # Reminder Methods
    # =========================================================================

    async def add_reminder(self, user_id: int, channel_id: int, guild_id: int,
                           message: str, remind_at: datetime) -> int:
        """Insert a reminder and return its new row id.

        NOTE(review): get_due_reminders compares the stored text against
        SQLite's datetime('now'), which is UTC, so remind_at is presumably
        meant to be a naive UTC datetime -- confirm with callers.
        """
        async with self.connection.execute(
            """INSERT INTO reminders (user_id, channel_id, guild_id, message, remind_at)
               VALUES (?, ?, ?, ?, ?)""",
            (user_id, channel_id, guild_id, message,
             self._to_db_timestamp(remind_at)),
        ) as cursor:
            reminder_id = cursor.lastrowid
        await self.connection.commit()
        return reminder_id

    async def get_due_reminders(self) -> List[dict]:
        """Return every uncompleted reminder whose remind_at is not in the future.

        Each dict has the keys id, user_id, channel_id, guild_id and message.
        """
        # The comparison is textual against datetime('now') (UTC,
        # "YYYY-MM-DD HH:MM:SS"); it is only correct if remind_at was stored
        # in the same UTC format.
        async with self.connection.execute(
            """SELECT id, user_id, channel_id, guild_id, message
               FROM reminders 
               WHERE remind_at <= datetime('now') AND completed = FALSE"""
        ) as cursor:
            rows = await cursor.fetchall()
        return [
            {"id": r[0], "user_id": r[1], "channel_id": r[2],
             "guild_id": r[3], "message": r[4]}
            for r in rows
        ]

    async def get_user_reminders(self, user_id: int) -> List[dict]:
        """Return a user's uncompleted reminders, earliest remind_at first."""
        async with self.connection.execute(
            """SELECT id, message, remind_at FROM reminders 
               WHERE user_id = ? AND completed = FALSE
               ORDER BY remind_at""",
            (user_id,)
        ) as cursor:
            rows = await cursor.fetchall()
        return [{"id": r[0], "message": r[1], "remind_at": r[2]} for r in rows]

    async def complete_reminder(self, reminder_id: int):
        """Mark the reminder with this id as completed."""
        await self.connection.execute(
            "UPDATE reminders SET completed = TRUE WHERE id = ?",
            (reminder_id,)
        )
        await self.connection.commit()

    async def delete_reminder(self, reminder_id: int, user_id: int) -> bool:
        """Delete a reminder owned by user_id.

        Returns:
            True if a row was deleted, False if no reminder matched both the
            id and the owner.
        """
        async with self.connection.execute(
            "DELETE FROM reminders WHERE id = ? AND user_id = ?",
            (reminder_id, user_id)
        ) as cursor:
            deleted = cursor.rowcount > 0
        await self.connection.commit()
        return deleted

    # =========================================================================
    # Stats Methods
    # =========================================================================

    async def increment_messages(self, user_id: int, guild_id: int):
        """Add one to the user's message count and refresh last_active (upsert)."""
        await self.connection.execute(
            """INSERT INTO user_stats (user_id, guild_id, messages, last_active)
               VALUES (?, ?, 1, datetime('now'))
               ON CONFLICT(user_id, guild_id) DO UPDATE SET
               messages = messages + 1, last_active = datetime('now')""",
            (user_id, guild_id)
        )
        await self.connection.commit()

    async def increment_commands(self, user_id: int, guild_id: int):
        """Add one to the user's command count and refresh last_active (upsert)."""
        await self.connection.execute(
            """INSERT INTO user_stats (user_id, guild_id, commands_used, last_active)
               VALUES (?, ?, 1, datetime('now'))
               ON CONFLICT(user_id, guild_id) DO UPDATE SET
               commands_used = commands_used + 1, last_active = datetime('now')""",
            (user_id, guild_id)
        )
        await self.connection.commit()

    async def get_user_stats(self, user_id: int, guild_id: int) -> Optional[dict]:
        """Return messages/commands_used/last_active for a user in a guild, or None."""
        async with self.connection.execute(
            """SELECT messages, commands_used, last_active FROM user_stats
               WHERE user_id = ? AND guild_id = ?""",
            (user_id, guild_id)
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        return {"messages": row[0], "commands_used": row[1], "last_active": row[2]}

    async def get_leaderboard(self, guild_id: int, limit: int = 10) -> List[dict]:
        """Return up to `limit` users of a guild, ordered by message count descending."""
        async with self.connection.execute(
            """SELECT user_id, messages, commands_used FROM user_stats
               WHERE guild_id = ?
               ORDER BY messages DESC LIMIT ?""",
            (guild_id, limit)
        ) as cursor:
            rows = await cursor.fetchall()
        return [
            {"user_id": r[0], "messages": r[1], "commands_used": r[2]}
            for r in rows
        ]

    # =========================================================================
    # Poll Methods
    # =========================================================================

    async def create_poll(self, message_id: int, channel_id: int, guild_id: int,
                          author_id: int, question: str, options: str,
                          ends_at: Optional[datetime] = None) -> int:
        """Insert a poll and return its new row id.

        `options` is stored verbatim as TEXT; its encoding is up to the caller.
        """
        async with self.connection.execute(
            """INSERT INTO polls (message_id, channel_id, guild_id, author_id, 
                                  question, options, ends_at)
               VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (message_id, channel_id, guild_id, author_id, question, options,
             self._to_db_timestamp(ends_at)),
        ) as cursor:
            poll_id = cursor.lastrowid
        await self.connection.commit()
        return poll_id

    async def get_poll(self, message_id: int) -> Optional[dict]:
        """Return the poll attached to a Discord message id, or None."""
        async with self.connection.execute(
            """SELECT id, question, options, author_id, active FROM polls
               WHERE message_id = ?""",
            (message_id,)
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        return {
            "id": row[0], "question": row[1], "options": row[2],
            "author_id": row[3], "active": row[4]
        }

    async def close_poll(self, message_id: int):
        """Deactivate the poll attached to a Discord message id."""
        await self.connection.execute(
            "UPDATE polls SET active = FALSE WHERE message_id = ?",
            (message_id,)
        )
        await self.connection.commit()

    async def get_active_polls(self) -> List[tuple]:
        """Return raw (id, channel_id, message_id, question, options, ends_at) rows."""
        async with self.connection.execute(
            """SELECT id, channel_id, message_id, question, options, ends_at
               FROM polls WHERE active = TRUE"""
        ) as cursor:
            return await cursor.fetchall()

    async def end_poll(self, poll_id: int):
        """Deactivate a poll by its row id."""
        await self.connection.execute(
            "UPDATE polls SET active = FALSE WHERE id = ?",
            (poll_id,)
        )
        await self.connection.commit()

    async def execute(self, query: str, params: tuple = ()) -> List[tuple]:
        """Run a raw parameterized query and return all result rows.

        This helper never calls commit(); callers issuing writes through it
        must commit themselves.
        """
        async with self.connection.execute(query, params) as cursor:
            return await cursor.fetchall()

    # =========================================================================
    # Extended Stats Methods
    # =========================================================================

    async def update_user_stats(self, user_id: int, guild_id: int, stat_type: str):
        """Increment the "messages" or "commands" counter for a user.

        Any other stat_type performs no update and is logged as a warning, so
        a typo at a call site is visible rather than silently lost.
        """
        if stat_type == "messages":
            await self.increment_messages(user_id, guild_id)
        elif stat_type == "commands":
            await self.increment_commands(user_id, guild_id)
        else:
            logger.warning("update_user_stats: unknown stat_type %r ignored", stat_type)

    async def get_server_total_stat(self, guild_id: int, stat_type: str) -> int:
        """Return the guild-wide sum of messages ("messages") or commands_used (anything else).

        Returns 0 when the guild has no stats rows.
        """
        # The column name comes from this fixed two-way mapping, never from
        # caller text, so interpolating it into the SQL is not injectable.
        column = "messages" if stat_type == "messages" else "commands_used"
        async with self.connection.execute(
            f"SELECT SUM({column}) FROM user_stats WHERE guild_id = ?",
            (guild_id,)
        ) as cursor:
            row = await cursor.fetchone()
        # SUM() over zero rows yields NULL; report that as 0.
        if row is None or row[0] is None:
            return 0
        return row[0]

    async def get_user_global_stats(self, user_id: int) -> Optional[dict]:
        """Return a user's totals across all guilds, or None if they have no stats.

        Returns:
            {"total_messages", "total_commands", "server_count"} or None.
        """
        async with self.connection.execute(
            """SELECT SUM(messages), SUM(commands_used), COUNT(DISTINCT guild_id)
               FROM user_stats WHERE user_id = ?""",
            (user_id,)
        ) as cursor:
            row = await cursor.fetchone()
        # Decide existence by the row count, not SUM(messages): a user who has
        # only run commands has 0 messages but still has stats.
        if not row or not row[2]:
            return None
        return {
            "total_messages": row[0] or 0,
            "total_commands": row[1] or 0,
            "server_count": row[2],
        }
PreviousNext