Initial commit

Aodhan Collins
2026-03-02 23:29:58 +00:00
commit 08c6e14616
12 changed files with 2121 additions and 0 deletions

152
scripts/import_tags.py Normal file

@@ -0,0 +1,152 @@
#!/usr/bin/env python3
"""
One-time script to import data/all_tags.csv into db/tags.db (SQLite).
Creates:
- `tags` table with a UNIQUE index on `name`
- `tags_fts` FTS5 virtual table for fast prefix/full-text searches
Usage:
python scripts/import_tags.py [--csv data/all_tags.csv] [--db db/tags.db]
"""
import argparse
import csv
import os
import sqlite3
import sys
import time
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Import Danbooru tags CSV into SQLite")
parser.add_argument(
"--csv",
default=os.path.join(os.path.dirname(__file__), "..", "data", "all_tags.csv"),
help="Path to the tags CSV file (default: data/all_tags.csv)",
)
parser.add_argument(
"--db",
default=os.path.join(os.path.dirname(__file__), "..", "db", "tags.db"),
help="Path for the output SQLite database (default: db/tags.db)",
)
parser.add_argument(
"--batch-size",
type=int,
default=10_000,
help="Number of rows to insert per transaction (default: 10000)",
)
return parser.parse_args()
def create_schema(conn: sqlite3.Connection) -> None:
conn.executescript("""
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE
);
-- FTS5 virtual table for fast prefix and full-text search.
-- content= links it to the `tags` table so the index is kept lean.
CREATE VIRTUAL TABLE IF NOT EXISTS tags_fts USING fts5(
name,
content='tags',
content_rowid='id',
tokenize='unicode61 remove_diacritics 1'
);
""")
conn.commit()
def import_csv(conn: sqlite3.Connection, csv_path: str, batch_size: int) -> int:
"""Insert tags from CSV and return the total count inserted."""
inserted = 0
batch: list[tuple[str]] = []
with open(csv_path, newline="", encoding="utf-8") as fh:
reader = csv.reader(fh)
        # Skip the header row if present; otherwise treat the first row as data.
        header = next(reader, None)
        if header and header[0].strip().lower() not in ("tag", "tags"):
            tag = header[0].rstrip(",").strip()
            if tag:
                batch.append((tag,))
for row in reader:
if not row:
continue
# The CSV has values like "some_tag," — strip trailing comma and whitespace
tag = row[0].rstrip(",").strip()
if not tag:
continue
batch.append((tag,))
if len(batch) >= batch_size:
_flush(conn, batch)
inserted += len(batch)
batch = []
print(f"\r {inserted:,} tags imported…", end="", flush=True)
if batch:
_flush(conn, batch)
inserted += len(batch)
return inserted
def _flush(conn: sqlite3.Connection, batch: list[tuple[str]]) -> None:
conn.executemany(
"INSERT OR IGNORE INTO tags (name) VALUES (?)",
batch,
)
conn.commit()
def rebuild_fts(conn: sqlite3.Connection) -> None:
"""Populate the FTS5 index from the `tags` table."""
print("\n Rebuilding FTS5 index…", flush=True)
conn.execute("INSERT INTO tags_fts(tags_fts) VALUES('rebuild')")
conn.commit()
def main() -> None:
args = parse_args()
csv_path = os.path.abspath(args.csv)
db_path = os.path.abspath(args.db)
if not os.path.isfile(csv_path):
print(f"ERROR: CSV file not found: {csv_path}", file=sys.stderr)
sys.exit(1)
os.makedirs(os.path.dirname(db_path), exist_ok=True)
print(f"Source : {csv_path}")
print(f"Target : {db_path}")
print(f"Batch : {args.batch_size:,} rows per transaction")
print()
t0 = time.perf_counter()
conn = sqlite3.connect(db_path)
try:
create_schema(conn)
total = import_csv(conn, csv_path, args.batch_size)
rebuild_fts(conn)
finally:
conn.close()
elapsed = time.perf_counter() - t0
print(f"\nDone. {total:,} tags imported in {elapsed:.1f}s → {db_path}")
if __name__ == "__main__":
main()

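A quick sanity check of the resulting database (an illustrative sketch, not part of this commit): because the unicode61 tokenizer treats underscores as separators, a tag like blue_hair is indexed as the tokens "blue" and "hair", so prefix lookups use FTS5 phrase-prefix syntax.

import sqlite3

conn = sqlite3.connect("db/tags.db")
# Phrase-prefix query: tags whose tokens are "blue" followed by one starting "ha"
rows = conn.execute(
    "SELECT name FROM tags_fts WHERE tags_fts MATCH ? LIMIT 10",
    ('"blue ha"*',),
).fetchall()
for (name,) in rows:
    print(name)
conn.close()
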
359
scripts/scrape_tags.py Normal file

@@ -0,0 +1,359 @@
#!/usr/bin/env python3
"""
Scrape Danbooru tags from the public API, sorted by post count (descending).
All tags with ≥10 posts fit within approximately the first 300 pages
(1000 tags/page, sorted by post_count DESC). The scraper stops automatically
as soon as an entire page consists only of tags below --min-posts.
This approach is fast (~35 minutes), clean, and requires no complex cursor
or ID-based pagination — just standard page-offset requests.
The scrape is resumable: if interrupted, re-run and it will continue from
the last completed page.
Usage:
python scripts/scrape_tags.py [--db db/tags.db]
Environment (from .env or shell):
DANBOORU_USER Danbooru login name
DANBOORU_API_KEY Danbooru API key
"""
from __future__ import annotations
import argparse
import os
import sqlite3
import sys
import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
try:
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
except ImportError:
print("ERROR: 'requests' not installed. Run: pip install requests")
sys.exit(1)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
BASE_URL = "https://danbooru.donmai.us"
PAGE_LIMIT = 1000
DEFAULT_DB = Path(__file__).parent.parent / "db" / "tags.db"
REQUEST_DELAY = 0.25 # seconds between requests per worker
MIN_POST_COUNT = 10
MAX_PAGES = 500  # safety cap (tags with ≥10 posts fit within ~300 pages)
CATEGORY_NAMES: dict[int, str] = {
0: "general",
1: "artist",
3: "copyright",
4: "character",
5: "meta",
}
# ---------------------------------------------------------------------------
# .env loader
# ---------------------------------------------------------------------------
def _load_env() -> dict[str, str]:
env: dict[str, str] = {}
for candidate in [
Path(__file__).parent.parent / ".env",
Path.home() / ".env",
]:
if candidate.exists():
for line in candidate.read_text().splitlines():
line = line.strip()
if line and not line.startswith("#") and "=" in line:
k, _, v = line.partition("=")
env.setdefault(k.strip(), v.strip())
break
return env
# ---------------------------------------------------------------------------
# Database
# ---------------------------------------------------------------------------
CREATE_SQL = """
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
post_count INTEGER NOT NULL DEFAULT 0,
category INTEGER NOT NULL DEFAULT 0,
category_name TEXT NOT NULL DEFAULT 'general',
is_deprecated INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_tags_name ON tags (name);
CREATE INDEX IF NOT EXISTS idx_tags_post_count ON tags (post_count DESC);
CREATE INDEX IF NOT EXISTS idx_tags_category ON tags (category);
CREATE VIRTUAL TABLE IF NOT EXISTS tags_fts USING fts5(
name,
content='tags',
content_rowid='id',
tokenize='unicode61 remove_diacritics 1'
);
CREATE TABLE IF NOT EXISTS completed_pages (
page INTEGER PRIMARY KEY
);
"""
def init_db(db_path: Path) -> sqlite3.Connection:
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(db_path), check_same_thread=False)
conn.executescript(CREATE_SQL)
conn.commit()
return conn
def get_completed_pages(conn: sqlite3.Connection) -> set[int]:
rows = conn.execute("SELECT page FROM completed_pages").fetchall()
return {r[0] for r in rows}
def mark_page_done(conn: sqlite3.Connection, page: int) -> None:
conn.execute(
"INSERT OR IGNORE INTO completed_pages (page) VALUES (?)", (page,)
)
conn.commit()
def upsert_tags(conn: sqlite3.Connection, tags: list[dict], min_post_count: int) -> int:
rows = [
(
t["id"],
t["name"],
t.get("post_count", 0),
t.get("category", 0),
CATEGORY_NAMES.get(t.get("category", 0), "general"),
1 if t.get("is_deprecated") else 0,
)
for t in tags
if t.get("post_count", 0) >= min_post_count
]
if not rows:
return 0
conn.executemany(
"""
INSERT INTO tags (id, name, post_count, category, category_name, is_deprecated)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
post_count = excluded.post_count,
category = excluded.category,
category_name = excluded.category_name,
is_deprecated = excluded.is_deprecated
""",
rows,
)
conn.commit()
return len(rows)
def rebuild_fts(conn: sqlite3.Connection) -> None:
print("Rebuilding FTS5 index…", flush=True)
conn.execute("INSERT INTO tags_fts(tags_fts) VALUES('rebuild')")
conn.commit()
print("FTS5 index built.", flush=True)
# ---------------------------------------------------------------------------
# HTTP
# ---------------------------------------------------------------------------
def make_session(api_key: str | None, username: str | None) -> requests.Session:
session = requests.Session()
if api_key and username:
session.auth = (username, api_key)
session.headers.update({"User-Agent": "danbooru-mcp/0.1"})
retry = Retry(
total=6, backoff_factor=2.0,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["GET"],
)
session.mount("https://", HTTPAdapter(max_retries=retry))
return session
def fetch_page(session: requests.Session, page: int) -> list[dict]:
params = {
"limit": PAGE_LIMIT,
"search[order]": "count",
"page": page,
}
resp = session.get(f"{BASE_URL}/tags.json", params=params, timeout=30)
resp.raise_for_status()
return resp.json()
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Scrape Danbooru tags (sorted by post count) into SQLite"
)
parser.add_argument("--db", default=str(DEFAULT_DB), help="Output SQLite DB path")
parser.add_argument(
"--min-posts", type=int, default=MIN_POST_COUNT,
help=f"Stop when a page has no tags above this threshold (default: {MIN_POST_COUNT})"
)
parser.add_argument(
"--workers", type=int, default=4,
help="Parallel HTTP workers (default: 4)"
)
parser.add_argument("--no-resume", action="store_true", help="Start from scratch")
parser.add_argument("--no-fts", action="store_true", help="Skip FTS5 rebuild")
return parser.parse_args()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _do_page(session: requests.Session, page: int) -> tuple[int, list[dict]]:
"""Fetch a page and return (page, tags). Runs in thread pool."""
time.sleep(REQUEST_DELAY)
tags = fetch_page(session, page)
return page, tags
def main() -> None:
args = parse_args()
db_path = Path(args.db).resolve()
env = _load_env()
api_key = env.get("DANBOORU_API_KEY") or os.environ.get("DANBOORU_API_KEY")
username = env.get("DANBOORU_USER") or os.environ.get("DANBOORU_USER")
if not username:
username = env.get("DANBOORU_USERNAME") or os.environ.get("DANBOORU_USERNAME")
print(f"Database : {db_path}")
print(f"Min posts : {args.min_posts} (skip pages where all tags are below this)")
print(f"Workers : {args.workers}")
print(f"Auth : {'yes (' + username + ')' if (api_key and username) else 'none (public API)'}")
print()
conn = init_db(db_path)
session = make_session(api_key if (api_key and username) else None, username)
if args.no_resume:
print("Resetting…")
conn.execute("DELETE FROM tags")
conn.execute("DELETE FROM completed_pages")
conn.commit()
done_pages: set[int] = set()
else:
done_pages = get_completed_pages(conn)
existing = conn.execute("SELECT COUNT(*) FROM tags").fetchone()[0]
if done_pages:
print(f"Resuming — {len(done_pages)} pages done ({existing:,} tags stored)")
else:
print(f"Starting fresh ({existing:,} tags in DB)")
print()
total_tags = conn.execute("SELECT COUNT(*) FROM tags").fetchone()[0]
pages_done = 0
t0 = time.perf_counter()
stop_flag = False
REPORT_EVERY = 10
print(f"{'Page':>6} {'Done':>6} {'Min posts':>10} {'Tags':>10} {'Rate':>7} {'Elapsed':>8}")
print("-" * 60)
# We submit pages in batches of `workers`, process results in page order,
# and stop as soon as we get a page where all tags are below min_posts.
page = 1
with ThreadPoolExecutor(max_workers=args.workers) as pool:
while not stop_flag and page <= MAX_PAGES:
# Submit a window of pages
batch_pages = []
for _ in range(args.workers):
while page <= MAX_PAGES and page in done_pages:
page += 1
if page > MAX_PAGES:
break
batch_pages.append(page)
page += 1
if not batch_pages:
break
futures = {
pool.submit(_do_page, session, p): p
for p in batch_pages
}
# Collect results in page order
results: dict[int, list[dict]] = {}
for fut in as_completed(futures):
pg, tags = fut.result()
results[pg] = tags
for pg in sorted(results.keys()):
tags = results[pg]
if not tags:
print(f"\nPage {pg}: empty response. Stopping.")
stop_flag = True
break
max_in_page = max(t.get("post_count", 0) for t in tags)
min_in_page = min(t.get("post_count", 0) for t in tags)
if max_in_page < args.min_posts:
print(f"\nPage {pg}: all tags have <{args.min_posts} posts (min={min_in_page}). Stopping.")
stop_flag = True
break
                upsert_tags(conn, tags, min_post_count=args.min_posts)
                mark_page_done(conn, pg)
                # Recount from the DB so a resumed run doesn't double-count
                # rows already upserted by a previously interrupted page.
                total_tags = conn.execute("SELECT COUNT(*) FROM tags").fetchone()[0]
pages_done += 1
elapsed = time.perf_counter() - t0
rate = pages_done / elapsed if elapsed > 0 else 0
line = (
f"{pg:>6} {pages_done:>6} {min_in_page:>10,} "
f"{total_tags:>10,} {rate:>5.1f}/s {elapsed/60:>6.1f}m"
)
                if pages_done % REPORT_EVERY == 0:
                    # Overwrite the in-progress line, then persist it with a newline.
                    print(f"\r{line}", flush=True)
                else:
                    print(f"\r{line}", end="", flush=True)
elapsed_total = time.perf_counter() - t0
print(f"\n{'='*60}")
print(f"Scraping complete:")
print(f" Tags stored : {total_tags:,}")
print(f" Pages done : {pages_done:,}")
print(f" Time : {elapsed_total/60:.1f} minutes")
print()
if not args.no_fts:
rebuild_fts(conn)
conn.close()
print(f"Database saved to {db_path}")
if __name__ == "__main__":
main()

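Once a scrape completes, the idx_tags_category and idx_tags_post_count indexes make ad-hoc inspection cheap. A sketch of the kind of queries the schema is built for (illustrative, not part of this commit):

import sqlite3

conn = sqlite3.connect("db/tags.db")
# Top character tags (category 4), served by the category and post_count indexes
for name, count in conn.execute(
    "SELECT name, post_count FROM tags "
    "WHERE category = 4 ORDER BY post_count DESC LIMIT 5"
):
    print(f"{name:30s} {count:>10,}")
# Active (non-deprecated) tags that cleared the threshold
total = conn.execute(
    "SELECT COUNT(*) FROM tags WHERE is_deprecated = 0"
).fetchone()[0]
print(f"{total:,} active tags")
conn.close()
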
221
scripts/test_danbooru_api.py Normal file

@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
Test script for the Danbooru API.
Verifies:
1. Authentication with the API key works
2. Tag listing endpoint returns expected fields
3. Pagination works (multiple pages)
4. Tag search / filtering by category works
Usage:
python scripts/test_danbooru_api.py
Reads DANBOORU_API_KEY (and optionally DANBOORU_USERNAME) from .env or the environment.
"""
import os
import sys
import time
from pathlib import Path
try:
import requests
except ImportError:
print("ERROR: 'requests' is not installed. Run: pip install requests")
sys.exit(1)
# ---------------------------------------------------------------------------
# Load .env
# ---------------------------------------------------------------------------
def load_env() -> dict[str, str]:
env: dict[str, str] = {}
env_path = Path(__file__).parent.parent / ".env"
if env_path.exists():
for line in env_path.read_text().splitlines():
line = line.strip()
if line and not line.startswith("#") and "=" in line:
k, _, v = line.partition("=")
env[k.strip()] = v.strip()
return env
# ---------------------------------------------------------------------------
# API helpers
# ---------------------------------------------------------------------------
BASE_URL = "https://danbooru.donmai.us"
# Danbooru tag categories
CATEGORY_NAMES = {
0: "general",
1: "artist",
3: "copyright",
4: "character",
5: "meta",
}
def make_session(api_key: str | None = None, username: str | None = None) -> requests.Session:
"""Create a requests Session.
Danbooru public endpoints (tag listing, searching) do not require
authentication. Auth is only needed for account-specific actions.
    When provided, credentials are sent as HTTP Basic auth in the form
    (login, api_key), where login is the Danbooru account name (not an
    email address or numeric user ID).
"""
session = requests.Session()
if api_key and username:
session.auth = (username, api_key)
session.headers.update({"User-Agent": "danbooru-mcp-test/0.1"})
return session
def get_tags_page(
session: requests.Session,
page: int = 1,
limit: int = 20,
search_name: str | None = None,
search_category: int | None = None,
order: str = "count", # "count" | "name" | "date"
) -> list[dict]:
"""Fetch one page of tags from the Danbooru API."""
params: dict = {
"page": page,
"limit": limit,
"search[order]": order,
}
if search_name:
params["search[name_matches]"] = search_name
if search_category is not None:
params["search[category]"] = search_category
resp = session.get(f"{BASE_URL}/tags.json", params=params, timeout=15)
resp.raise_for_status()
return resp.json()
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_basic_fetch(session: requests.Session) -> None:
print("\n[1] Basic fetch — top 5 tags by post count")
tags = get_tags_page(session, page=1, limit=5, order="count")
assert isinstance(tags, list), f"Expected list, got {type(tags)}"
assert len(tags) > 0, "No tags returned"
for tag in tags:
cat = CATEGORY_NAMES.get(tag.get("category", -1), "unknown")
print(f" [{cat:12s}] {tag['name']:40s} posts={tag['post_count']:>8,}")
print(" PASS ✓")
def test_fields_present(session: requests.Session) -> None:
print("\n[2] Field presence check")
tags = get_tags_page(session, page=1, limit=1, order="count")
tag = tags[0]
required = {"id", "name", "post_count", "category", "is_deprecated", "words"}
missing = required - set(tag.keys())
assert not missing, f"Missing fields: {missing}"
print(f" Fields present: {sorted(tag.keys())}")
print(f" Sample tag: name={tag['name']!r} category={CATEGORY_NAMES.get(tag['category'])} deprecated={tag['is_deprecated']}")
print(" PASS ✓")
def test_pagination(session: requests.Session) -> None:
print("\n[3] Pagination — page 1 vs page 2 should differ")
p1 = get_tags_page(session, page=1, limit=5, order="count")
time.sleep(0.5)
p2 = get_tags_page(session, page=2, limit=5, order="count")
names_p1 = {t["name"] for t in p1}
names_p2 = {t["name"] for t in p2}
overlap = names_p1 & names_p2
assert not overlap, f"Pages 1 and 2 share tags: {overlap}"
print(f" Page 1: {sorted(names_p1)}")
print(f" Page 2: {sorted(names_p2)}")
print(" PASS ✓")
def test_category_filter(session: requests.Session) -> None:
print("\n[4] Category filter — fetch only 'character' tags (category=4)")
tags = get_tags_page(session, page=1, limit=5, search_category=4, order="count")
for tag in tags:
assert tag["category"] == 4, f"Expected category 4, got {tag['category']} for {tag['name']}"
print(f" {tag['name']:40s} posts={tag['post_count']:>8,}")
print(" PASS ✓")
def test_name_search(session: requests.Session) -> None:
print("\n[5] Name search — tags matching 'blue_hair*'")
tags = get_tags_page(session, page=1, limit=5, search_name="blue_hair*", order="count")
assert len(tags) > 0, "No results for blue_hair*"
for tag in tags:
cat = CATEGORY_NAMES.get(tag.get("category", -1), "unknown")
print(f" [{cat:12s}] {tag['name']:40s} posts={tag['post_count']:>8,}")
print(" PASS ✓")
def test_well_known_tags(session: requests.Session) -> None:
print("\n[6] Well-known tags — '1girl', 'blue_hair', 'sword' should exist")
for tag_name in ("1girl", "blue_hair", "sword"):
tags = get_tags_page(session, page=1, limit=1, search_name=tag_name, order="count")
found = [t for t in tags if t["name"] == tag_name]
assert found, f"Tag '{tag_name}' not found in API response"
t = found[0]
cat = CATEGORY_NAMES.get(t.get("category", -1), "unknown")
print(f" {tag_name:20s} category={cat:12s} posts={t['post_count']:>8,}")
print(" PASS ✓")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
env = load_env()
api_key = env.get("DANBOORU_API_KEY") or os.environ.get("DANBOORU_API_KEY")
username = env.get("DANBOORU_USERNAME") or os.environ.get("DANBOORU_USERNAME")
if api_key:
print(f"API key loaded: {api_key[:8]}")
else:
print("No API key found — using unauthenticated access (public endpoints only)")
# Danbooru public tag endpoints don't require auth.
# Pass username + api_key only when both are available.
session = make_session(
api_key=api_key if (api_key and username) else None,
username=username,
)
tests = [
test_basic_fetch,
test_fields_present,
test_pagination,
test_category_filter,
test_name_search,
test_well_known_tags,
]
passed = 0
failed = 0
for test_fn in tests:
try:
test_fn(session)
passed += 1
except Exception as exc:
print(f" FAIL ✗ {exc}")
failed += 1
time.sleep(0.3) # be polite to the API
print(f"\n{'='*50}")
print(f"Results: {passed} passed, {failed} failed")
if failed:
sys.exit(1)
if __name__ == "__main__":
main()
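
One footnote on the auth comment in make_session: setting session.auth = (username, api_key) makes requests emit a standard HTTP Basic Authorization header. A minimal sketch of what actually goes on the wire, with placeholder credentials:

import base64

login, api_key = "some_login", "some_api_key"  # placeholders, not real credentials
token = base64.b64encode(f"{login}:{api_key}".encode()).decode()
print(f"Authorization: Basic {token}")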