"""Async PostgreSQL database layer using asyncpg.""" from __future__ import annotations import json import logging from datetime import datetime, timezone from typing import Any import asyncpg from .models import Account, CollectionState, Mention, Post logger = logging.getLogger(__name__) class Database: def __init__(self, dsn: str): self._dsn = dsn self._pool: asyncpg.Pool | None = None async def connect(self) -> None: self._pool = await asyncpg.create_pool(self._dsn, min_size=2, max_size=5) logger.info("Database connection pool created") async def close(self) -> None: if self._pool: await self._pool.close() logger.info("Database connection pool closed") # ── Account operations ────────────────────────────────────────────── async def upsert_account(self, account: Account) -> None: await self._pool.execute( """ INSERT INTO accounts (did, handle, display_name) VALUES ($1, $2, $3) ON CONFLICT (did) DO UPDATE SET handle = EXCLUDED.handle, display_name = EXCLUDED.display_name """, account.did, account.handle, account.display_name, ) async def deactivate_removed_accounts(self, active_dids: set[str]) -> None: """Set active=false for accounts no longer in the config.""" if not active_dids: return await self._pool.execute( """ UPDATE accounts SET active = false WHERE did != ALL($1::text[]) AND active = true """, list(active_dids), ) async def update_account_last_feed(self, did: str) -> None: await self._pool.execute( "UPDATE accounts SET last_feed_collected = now() WHERE did = $1", did ) async def update_account_last_mention(self, did: str) -> None: await self._pool.execute( "UPDATE accounts SET last_mention_collected = now() WHERE did = $1", did ) # ── Post operations ───────────────────────────────────────────────── async def upsert_posts(self, posts: list[Post]) -> int: """Batch upsert posts. Returns the number of rows affected.""" if not posts: return 0 count = 0 async with self._pool.acquire() as conn: async with conn.transaction(): for p in posts: result = await conn.execute( """ INSERT INTO posts ( uri, cid, author_did, text, created_at, indexed_at, reply_parent, reply_root, post_type, has_media, has_embed, like_count, reply_count, repost_count, quote_count, langs, raw_json ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 ) ON CONFLICT (uri) DO UPDATE SET cid = EXCLUDED.cid, like_count = EXCLUDED.like_count, reply_count = EXCLUDED.reply_count, repost_count = EXCLUDED.repost_count, quote_count = EXCLUDED.quote_count, collected_at = now() """, p.uri, p.cid, p.author_did, p.text, p.created_at, p.indexed_at, p.reply_parent, p.reply_root, p.post_type, p.has_media, p.has_embed, p.like_count, p.reply_count, p.repost_count, p.quote_count, p.langs, json.dumps(p.raw_json), ) # asyncpg returns e.g. "INSERT 0 1" count += 1 return count # ── Mention operations ────────────────────────────────────────────── async def upsert_mentions(self, mentions: list[Mention]) -> int: if not mentions: return 0 count = 0 async with self._pool.acquire() as conn: async with conn.transaction(): for m in mentions: result = await conn.execute( """ INSERT INTO mentions ( post_uri, mentioned_did, mentioning_did, post_text, post_created_at, raw_json ) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (post_uri, mentioned_did) DO NOTHING """, m.post_uri, m.mentioned_did, m.mentioning_did, m.post_text, m.post_created_at, json.dumps(m.raw_json), ) if "INSERT 0 1" in result: count += 1 return count # ── Collection state ──────────────────────────────────────────────── async def get_collection_state( self, account_did: str, collection_type: str ) -> CollectionState | None: row = await self._pool.fetchrow( """ SELECT account_did, collection_type, last_post_at FROM collection_state WHERE account_did = $1 AND collection_type = $2 """, account_did, collection_type, ) if not row: return None return CollectionState( account_did=row["account_did"], collection_type=row["collection_type"], last_post_at=row["last_post_at"], ) async def save_collection_state( self, account_did: str, collection_type: str, last_post_at: datetime | None ) -> None: await self._pool.execute( """ INSERT INTO collection_state (account_did, collection_type, last_post_at, updated_at) VALUES ($1, $2, $3, now()) ON CONFLICT (account_did, collection_type) DO UPDATE SET last_post_at = EXCLUDED.last_post_at, updated_at = now() """, account_did, collection_type, last_post_at, ) # ── Collection run tracking ───────────────────────────────────────── async def start_run(self, accounts_total: int) -> int: row = await self._pool.fetchrow( """ INSERT INTO collection_runs (accounts_total) VALUES ($1) RETURNING id """, accounts_total, ) return row["id"] async def update_run_progress( self, run_id: int, *, accounts_done: int | None = None, posts_collected: int | None = None, mentions_collected: int | None = None, ) -> None: parts = [] args: list[Any] = [] idx = 1 if accounts_done is not None: idx += 1 parts.append(f"accounts_done = ${idx}") args.append(accounts_done) if posts_collected is not None: idx += 1 parts.append(f"posts_collected = ${idx}") args.append(posts_collected) if mentions_collected is not None: idx += 1 parts.append(f"mentions_collected = ${idx}") args.append(mentions_collected) if not parts: return sql = f"UPDATE collection_runs SET {', '.join(parts)} WHERE id = $1" await self._pool.execute(sql, run_id, *args) async def finish_run( self, run_id: int, status: str, errors: list[dict] | None = None ) -> None: await self._pool.execute( """ UPDATE collection_runs SET finished_at = now(), status = $2, errors = $3, duration_secs = EXTRACT(EPOCH FROM (now() - started_at)) WHERE id = $1 """, run_id, status, json.dumps(errors or []), ) # ── Stats (useful for verification) ───────────────────────────────── async def get_stats(self) -> dict[str, int]: row = await self._pool.fetchrow( """ SELECT (SELECT count(*) FROM accounts WHERE active) AS accounts, (SELECT count(*) FROM posts) AS posts, (SELECT count(*) FROM mentions) AS mentions, (SELECT count(*) FROM collection_runs) AS runs """ ) return dict(row)