bluesky-collector/src/db.py

266 lines
9.6 KiB
Python
Raw Normal View History

"""Async PostgreSQL database layer using asyncpg."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any
import asyncpg
from .models import Account, CollectionState, Mention, Post
logger = logging.getLogger(__name__)
class Database:
    """Async data-access layer backed by an asyncpg connection pool.

    Call :meth:`connect` once before using any query method and
    :meth:`close` when finished. Query methods that need the pool raise
    ``RuntimeError`` when it has not been created (or was already closed).
    """

    def __init__(self, dsn: str):
        """Store the connection string; no connection is opened here."""
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None

    def _require_pool(self) -> asyncpg.Pool:
        """Return the live pool, or raise ``RuntimeError`` if not connected."""
        if self._pool is None:
            raise RuntimeError("Database is not connected; call connect() first")
        return self._pool

    async def connect(self) -> None:
        """Create the connection pool (2 to 5 connections).

        Calling this while a pool already exists is a no-op, so a second
        call cannot orphan the first pool.
        """
        if self._pool is not None:
            return
        self._pool = await asyncpg.create_pool(self._dsn, min_size=2, max_size=5)
        logger.info("Database connection pool created")

    async def close(self) -> None:
        """Close the pool if one exists and forget it."""
        if self._pool is None:
            return
        # Drop the reference first so later calls fail with the clear
        # "not connected" error instead of hitting a closed pool.
        pool, self._pool = self._pool, None
        await pool.close()
        logger.info("Database connection pool closed")

    # ── Account operations ──────────────────────────────────────────────

    async def upsert_account(self, account: Account) -> None:
        """Insert an account, or refresh handle/display_name if the DID exists."""
        await self._require_pool().execute(
            """
            INSERT INTO accounts (did, handle, display_name)
            VALUES ($1, $2, $3)
            ON CONFLICT (did) DO UPDATE SET
                handle = EXCLUDED.handle,
                display_name = EXCLUDED.display_name
            """,
            account.did,
            account.handle,
            account.display_name,
        )

    async def deactivate_removed_accounts(self, active_dids: set[str]) -> None:
        """Set active=false for accounts no longer in the config.

        An empty ``active_dids`` is a no-op: nothing is deactivated.
        """
        if not active_dids:
            return
        await self._require_pool().execute(
            """
            UPDATE accounts SET active = false
            WHERE did != ALL($1::text[]) AND active = true
            """,
            list(active_dids),
        )

    async def update_account_last_feed(self, did: str) -> None:
        """Stamp ``last_feed_collected`` with the database's current time."""
        await self._require_pool().execute(
            "UPDATE accounts SET last_feed_collected = now() WHERE did = $1", did
        )

    async def update_account_last_mention(self, did: str) -> None:
        """Stamp ``last_mention_collected`` with the database's current time."""
        await self._require_pool().execute(
            "UPDATE accounts SET last_mention_collected = now() WHERE did = $1", did
        )

    # ── Post operations ─────────────────────────────────────────────────

    async def upsert_posts(self, posts: list[Post]) -> int:
        """Batch upsert posts inside one transaction.

        New URIs are inserted; for existing URIs only ``cid``, the four
        engagement counters and ``collected_at`` are refreshed.

        Returns the number of posts written, which is ``len(posts)`` because
        every row is either inserted or updated. Returns 0 for an empty list
        without touching the database.
        """
        if not posts:
            return 0

        upsert_sql = """
            INSERT INTO posts (
                uri, cid, author_did, text, created_at, indexed_at,
                reply_parent, reply_root, post_type,
                has_media, has_embed,
                like_count, reply_count, repost_count, quote_count,
                langs, raw_json
            ) VALUES (
                $1, $2, $3, $4, $5, $6,
                $7, $8, $9,
                $10, $11,
                $12, $13, $14, $15,
                $16, $17
            )
            ON CONFLICT (uri) DO UPDATE SET
                cid = EXCLUDED.cid,
                like_count = EXCLUDED.like_count,
                reply_count = EXCLUDED.reply_count,
                repost_count = EXCLUDED.repost_count,
                quote_count = EXCLUDED.quote_count,
                collected_at = now()
        """
        # One parameter tuple per post, in the same order as $1..$17.
        rows = [
            (
                p.uri,
                p.cid,
                p.author_did,
                p.text,
                p.created_at,
                p.indexed_at,
                p.reply_parent,
                p.reply_root,
                p.post_type,
                p.has_media,
                p.has_embed,
                p.like_count,
                p.reply_count,
                p.repost_count,
                p.quote_count,
                p.langs,
                json.dumps(p.raw_json),
            )
            for p in posts
        ]
        async with self._require_pool().acquire() as conn:
            # Explicit transaction keeps the batch all-or-nothing.
            async with conn.transaction():
                await conn.executemany(upsert_sql, rows)
        return len(posts)

    # ── Mention operations ──────────────────────────────────────────────

    async def upsert_mentions(self, mentions: list[Mention]) -> int:
        """Insert mentions in one transaction, skipping ones already stored.

        Returns the number of rows actually inserted; rows that hit the
        ``(post_uri, mentioned_did)`` conflict are not counted. Returns 0 for
        an empty list without touching the database.
        """
        if not mentions:
            return 0

        insert_sql = """
            INSERT INTO mentions (
                post_uri, mentioned_did, mentioning_did,
                post_text, post_created_at, raw_json
            ) VALUES ($1, $2, $3, $4, $5, $6)
            ON CONFLICT (post_uri, mentioned_did) DO NOTHING
        """
        inserted = 0
        async with self._require_pool().acquire() as conn:
            async with conn.transaction():
                # Row-by-row on purpose: the per-statement status string is
                # what tells an insert apart from a skipped duplicate.
                for m in mentions:
                    status = await conn.execute(
                        insert_sql,
                        m.post_uri,
                        m.mentioned_did,
                        m.mentioning_did,
                        m.post_text,
                        m.post_created_at,
                        json.dumps(m.raw_json),
                    )
                    if "INSERT 0 1" in status:
                        inserted += 1
        return inserted

    # ── Collection state ────────────────────────────────────────────────

    async def get_collection_state(
        self, account_did: str, collection_type: str
    ) -> CollectionState | None:
        """Return the stored state for this account/type pair, or ``None``."""
        row = await self._require_pool().fetchrow(
            """
            SELECT account_did, collection_type, last_post_at
            FROM collection_state
            WHERE account_did = $1 AND collection_type = $2
            """,
            account_did,
            collection_type,
        )
        if not row:
            return None
        return CollectionState(
            account_did=row["account_did"],
            collection_type=row["collection_type"],
            last_post_at=row["last_post_at"],
        )

    async def save_collection_state(
        self, account_did: str, collection_type: str, last_post_at: datetime | None
    ) -> None:
        """Create or overwrite the state row and bump ``updated_at``."""
        await self._require_pool().execute(
            """
            INSERT INTO collection_state (account_did, collection_type, last_post_at, updated_at)
            VALUES ($1, $2, $3, now())
            ON CONFLICT (account_did, collection_type) DO UPDATE SET
                last_post_at = EXCLUDED.last_post_at,
                updated_at = now()
            """,
            account_did,
            collection_type,
            last_post_at,
        )

    # ── Collection run tracking ─────────────────────────────────────────

    async def start_run(self, accounts_total: int) -> int:
        """Insert a new ``collection_runs`` row and return its ``id``."""
        row = await self._require_pool().fetchrow(
            """
            INSERT INTO collection_runs (accounts_total)
            VALUES ($1)
            RETURNING id
            """,
            accounts_total,
        )
        return row["id"]

    async def update_run_progress(
        self,
        run_id: int,
        *,
        accounts_done: int | None = None,
        posts_collected: int | None = None,
        mentions_collected: int | None = None,
    ) -> None:
        """Update whichever progress counters were supplied for a run.

        Arguments left as ``None`` are not touched. If all are ``None`` the
        call returns without issuing a query.
        """
        candidates: dict[str, Any] = {
            "accounts_done": accounts_done,
            "posts_collected": posts_collected,
            "mentions_collected": mentions_collected,
        }
        # Column names come only from the fixed mapping above, never from
        # caller input, so interpolating them into the SQL text is safe.
        supplied = {col: val for col, val in candidates.items() if val is not None}
        if not supplied:
            return
        # $1 is reserved for run_id, so counters are numbered from $2.
        assignments = ", ".join(
            f"{col} = ${pos}" for pos, col in enumerate(supplied, start=2)
        )
        sql = f"UPDATE collection_runs SET {assignments} WHERE id = $1"
        await self._require_pool().execute(sql, run_id, *supplied.values())

    async def finish_run(
        self, run_id: int, status: str, errors: list[dict] | None = None
    ) -> None:
        """Mark a run finished: set status, errors, end time and duration.

        ``errors`` is stored as a JSON string; ``None`` is stored as ``[]``.
        The duration is computed by the database from ``started_at``.
        """
        await self._require_pool().execute(
            """
            UPDATE collection_runs SET
                finished_at = now(),
                status = $2,
                errors = $3,
                duration_secs = EXTRACT(EPOCH FROM (now() - started_at))
            WHERE id = $1
            """,
            run_id,
            status,
            json.dumps(errors or []),
        )

    # ── Stats (useful for verification) ─────────────────────────────────

    async def get_stats(self) -> dict[str, int]:
        """Return row counts keyed ``accounts`` (active only), ``posts``,
        ``mentions`` and ``runs``."""
        row = await self._require_pool().fetchrow(
            """
            SELECT
                (SELECT count(*) FROM accounts WHERE active) AS accounts,
                (SELECT count(*) FROM posts) AS posts,
                (SELECT count(*) FROM mentions) AS mentions,
                (SELECT count(*) FROM collection_runs) AS runs
            """
        )
        return dict(row)