Initial commit
This commit is contained in:
152
scripts/import_tags.py
Normal file
152
scripts/import_tags.py
Normal file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
One-time script to import data/all_tags.csv into db/tags.db (SQLite).
|
||||
|
||||
Creates:
|
||||
- `tags` table with a UNIQUE index on `name`
|
||||
- `tags_fts` FTS5 virtual table for fast prefix/full-text searches
|
||||
|
||||
Usage:
|
||||
python scripts/import_tags.py [--csv data/all_tags.csv] [--db db/tags.db]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the tag import script."""
    script_dir = os.path.dirname(__file__)
    default_csv = os.path.join(script_dir, "..", "data", "all_tags.csv")
    default_db = os.path.join(script_dir, "..", "db", "tags.db")

    parser = argparse.ArgumentParser(description="Import Danbooru tags CSV into SQLite")
    parser.add_argument(
        "--csv",
        default=default_csv,
        help="Path to the tags CSV file (default: data/all_tags.csv)",
    )
    parser.add_argument(
        "--db",
        default=default_db,
        help="Path for the output SQLite database (default: db/tags.db)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=10_000,
        help="Number of rows to insert per transaction (default: 10000)",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def create_schema(conn: sqlite3.Connection) -> None:
    """Create the `tags` table and its FTS5 companion if they do not exist."""
    ddl = """
    PRAGMA journal_mode = WAL;
    PRAGMA synchronous = NORMAL;

    CREATE TABLE IF NOT EXISTS tags (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT NOT NULL UNIQUE
    );

    -- FTS5 virtual table for fast prefix and full-text search.
    -- content= links it to the `tags` table so the index is kept lean.
    CREATE VIRTUAL TABLE IF NOT EXISTS tags_fts USING fts5(
        name,
        content='tags',
        content_rowid='id',
        tokenize='unicode61 remove_diacritics 1'
    );
    """
    conn.executescript(ddl)
    conn.commit()
|
||||
|
||||
|
||||
def import_csv(conn: sqlite3.Connection, csv_path: str, batch_size: int) -> int:
    """Insert tags from the CSV at *csv_path* and return the number of rows
    actually inserted.

    The CSV has one tag per line with values like ``some_tag,`` — a trailing
    comma and surrounding whitespace are stripped.  A leading ``tag``/``tags``
    header row is skipped when present; any other first row is treated as data.
    Duplicate names are skipped via ``INSERT OR IGNORE``.

    Bug fix: the previous version counted batch lengths, so rows that were
    *ignored* as duplicates were still counted as "inserted".  We now measure
    ``conn.total_changes`` so the returned total reflects real inserts.

    Args:
        conn: Open SQLite connection; the ``tags`` table must already exist.
        csv_path: Path to the source CSV file.
        batch_size: Rows per transaction (larger is faster, holds the write
            lock longer).

    Returns:
        Number of new rows inserted into ``tags``.
    """
    changes_before = conn.total_changes
    batch: list[tuple[str]] = []

    def flush() -> None:
        # One transaction per batch keeps memory bounded and commits cheap.
        conn.executemany("INSERT OR IGNORE INTO tags (name) VALUES (?)", batch)
        conn.commit()

    def clean(value: str) -> str:
        # Values look like "some_tag," — drop the trailing comma and spaces.
        return value.rstrip(",").strip()

    with open(csv_path, newline="", encoding="utf-8") as fh:
        reader = csv.reader(fh)

        # Skip an optional header row; otherwise re-process the first row as data.
        header = next(reader, None)
        if header and header[0].strip().lower() not in ("tag", "tags"):
            tag = clean(header[0])
            if tag:
                batch.append((tag,))

        for row in reader:
            if not row:
                continue
            tag = clean(row[0])
            if not tag:
                continue
            batch.append((tag,))

            if len(batch) >= batch_size:
                flush()
                batch = []
                inserted = conn.total_changes - changes_before
                print(f"\r  {inserted:,} tags imported…", end="", flush=True)

    if batch:
        flush()

    return conn.total_changes - changes_before
|
||||
|
||||
|
||||
def _flush(conn: sqlite3.Connection, batch: list[tuple[str]]) -> None:
|
||||
conn.executemany(
|
||||
"INSERT OR IGNORE INTO tags (name) VALUES (?)",
|
||||
batch,
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def rebuild_fts(conn: sqlite3.Connection) -> None:
    """Repopulate the external-content FTS5 index from the `tags` table."""
    print("\n Rebuilding FTS5 index…", flush=True)
    # 'rebuild' is the FTS5 special command for external-content tables.
    rebuild_cmd = "INSERT INTO tags_fts(tags_fts) VALUES('rebuild')"
    conn.execute(rebuild_cmd)
    conn.commit()
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: validate paths, prepare the target DB, run the import."""
    args = parse_args()

    csv_path = os.path.abspath(args.csv)
    db_path = os.path.abspath(args.db)

    # Fail fast when the input file is missing.
    if not os.path.isfile(csv_path):
        print(f"ERROR: CSV file not found: {csv_path}", file=sys.stderr)
        sys.exit(1)

    os.makedirs(os.path.dirname(db_path), exist_ok=True)

    print(f"Source : {csv_path}")
    print(f"Target : {db_path}")
    print(f"Batch : {args.batch_size:,} rows per transaction")
    print()

    start = time.perf_counter()

    conn = sqlite3.connect(db_path)
    try:
        create_schema(conn)
        total = import_csv(conn, csv_path, args.batch_size)
        rebuild_fts(conn)
    finally:
        # Always release the DB handle, even if the import blows up.
        conn.close()

    elapsed = time.perf_counter() - start
    print(f"\nDone. {total:,} tags imported in {elapsed:.1f}s → {db_path}")
|
||||
|
||||
|
||||
# Script entry point: python scripts/import_tags.py [--csv ...] [--db ...]
if __name__ == "__main__":
    main()
|
||||
359
scripts/scrape_tags.py
Normal file
359
scripts/scrape_tags.py
Normal file
@@ -0,0 +1,359 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Scrape Danbooru tags from the public API, sorted by post count (descending).
|
||||
|
||||
All tags with ≥10 posts fit within approximately the first 300 pages
|
||||
(1000 tags/page, sorted by post_count DESC). The scraper stops automatically
|
||||
as soon as an entire page consists only of tags below --min-posts.
|
||||
|
||||
This approach is fast (~3–5 minutes), clean, and requires no complex cursor
|
||||
or ID-based pagination — just standard page-offset requests.
|
||||
|
||||
The scrape is resumable: if interrupted, re-run and it will continue from
|
||||
the last completed page.
|
||||
|
||||
Usage:
|
||||
python scripts/scrape_tags.py [--db db/tags.db]
|
||||
|
||||
Environment (from .env or shell):
|
||||
DANBOORU_USER Danbooru login name
|
||||
DANBOORU_API_KEY Danbooru API key
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
try:
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
except ImportError:
|
||||
print("ERROR: 'requests' not installed. Run: pip install requests")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

BASE_URL = "https://danbooru.donmai.us"
PAGE_LIMIT = 1000  # tags requested per page (module docstring: 1000 tags/page)
DEFAULT_DB = Path(__file__).parent.parent / "db" / "tags.db"
REQUEST_DELAY = 0.25  # seconds between requests per worker
MIN_POST_COUNT = 10  # default --min-posts stop threshold
MAX_PAGES = 500  # safety cap (all ≥10-post tags are < 300 pages)

# Danbooru numeric tag categories → human-readable names.
# NOTE(review): category 2 is absent here — missing keys fall back to
# "general" at the call sites; confirm 2 is genuinely unused by the API.
CATEGORY_NAMES: dict[int, str] = {
    0: "general",
    1: "artist",
    3: "copyright",
    4: "character",
    5: "meta",
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# .env loader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_env() -> dict[str, str]:
|
||||
env: dict[str, str] = {}
|
||||
for candidate in [
|
||||
Path(__file__).parent.parent / ".env",
|
||||
Path.home() / ".env",
|
||||
]:
|
||||
if candidate.exists():
|
||||
for line in candidate.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, _, v = line.partition("=")
|
||||
env.setdefault(k.strip(), v.strip())
|
||||
break
|
||||
return env
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CREATE_SQL = """
|
||||
PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS tags (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
post_count INTEGER NOT NULL DEFAULT 0,
|
||||
category INTEGER NOT NULL DEFAULT 0,
|
||||
category_name TEXT NOT NULL DEFAULT 'general',
|
||||
is_deprecated INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tags_name ON tags (name);
|
||||
CREATE INDEX IF NOT EXISTS idx_tags_post_count ON tags (post_count DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_tags_category ON tags (category);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS tags_fts USING fts5(
|
||||
name,
|
||||
content='tags',
|
||||
content_rowid='id',
|
||||
tokenize='unicode61 remove_diacritics 1'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS completed_pages (
|
||||
page INTEGER PRIMARY KEY
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
def init_db(db_path: Path) -> sqlite3.Connection:
    """Open (creating as needed) the SQLite DB at *db_path* and apply the schema."""
    db_path.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): check_same_thread=False, although all DB access appears
    # to happen on the main thread — confirm before relying on it.
    connection = sqlite3.connect(str(db_path), check_same_thread=False)
    connection.executescript(CREATE_SQL)
    connection.commit()
    return connection
|
||||
|
||||
|
||||
def get_completed_pages(conn: sqlite3.Connection) -> set[int]:
    """Return the set of page numbers already scraped (used for resuming)."""
    return {page for (page,) in conn.execute("SELECT page FROM completed_pages")}
|
||||
|
||||
|
||||
def mark_page_done(conn: sqlite3.Connection, page: int) -> None:
    """Record *page* as fully processed so a resumed run can skip it."""
    sql = "INSERT OR IGNORE INTO completed_pages (page) VALUES (?)"
    conn.execute(sql, (page,))
    conn.commit()
|
||||
|
||||
|
||||
def upsert_tags(conn: sqlite3.Connection, tags: list[dict], min_post_count: int) -> int:
    """Insert or refresh API tag dicts meeting *min_post_count*; return rows written.

    Rows conflict on `id`: existing tags get their count/category/deprecation
    refreshed while keeping the same primary key.
    """
    rows: list[tuple] = []
    for tag in tags:
        count = tag.get("post_count", 0)
        if count < min_post_count:
            continue  # below threshold — not worth storing
        category = tag.get("category", 0)
        rows.append(
            (
                tag["id"],
                tag["name"],
                count,
                category,
                CATEGORY_NAMES.get(category, "general"),
                1 if tag.get("is_deprecated") else 0,
            )
        )
    if not rows:
        return 0
    conn.executemany(
        """
        INSERT INTO tags (id, name, post_count, category, category_name, is_deprecated)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT(id) DO UPDATE SET
            post_count = excluded.post_count,
            category = excluded.category,
            category_name = excluded.category_name,
            is_deprecated = excluded.is_deprecated
        """,
        rows,
    )
    conn.commit()
    return len(rows)
|
||||
|
||||
|
||||
def rebuild_fts(conn: sqlite3.Connection) -> None:
    """Repopulate the FTS5 index from the `tags` content table."""
    print("Rebuilding FTS5 index…", flush=True)
    # 'rebuild' is the FTS5 special command for external-content tables.
    rebuild_cmd = "INSERT INTO tags_fts(tags_fts) VALUES('rebuild')"
    conn.execute(rebuild_cmd)
    conn.commit()
    print("FTS5 index built.", flush=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_session(api_key: str | None, username: str | None) -> requests.Session:
    """Build a requests session with GET retries; attach auth when both creds exist."""
    session = requests.Session()
    session.headers.update({"User-Agent": "danbooru-mcp/0.1"})
    if api_key and username:
        # Danbooru HTTP Basic auth is (login, api_key).
        session.auth = (username, api_key)
    retry_policy = Retry(
        total=6,
        backoff_factor=2.0,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["GET"],
    )
    session.mount("https://", HTTPAdapter(max_retries=retry_policy))
    return session
|
||||
|
||||
|
||||
def fetch_page(session: requests.Session, page: int) -> list[dict]:
    """Fetch one page of tags (ordered by post count) from the Danbooru API."""
    query = {
        "limit": PAGE_LIMIT,
        "search[order]": "count",
        "page": page,
    }
    response = session.get(f"{BASE_URL}/tags.json", params=query, timeout=30)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Define and parse the scraper's command-line interface."""
    cli = argparse.ArgumentParser(
        description="Scrape Danbooru tags (sorted by post count) into SQLite"
    )
    cli.add_argument("--db", default=str(DEFAULT_DB), help="Output SQLite DB path")
    cli.add_argument(
        "--min-posts",
        type=int,
        default=MIN_POST_COUNT,
        help=f"Stop when a page has no tags above this threshold (default: {MIN_POST_COUNT})",
    )
    cli.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Parallel HTTP workers (default: 4)",
    )
    cli.add_argument("--no-resume", action="store_true", help="Start from scratch")
    cli.add_argument("--no-fts", action="store_true", help="Skip FTS5 rebuild")
    return cli.parse_args()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _do_page(session: requests.Session, page: int) -> tuple[int, list[dict]]:
    """Thread-pool task: rate-limit, fetch one page, and return (page, tags)."""
    time.sleep(REQUEST_DELAY)  # crude per-worker rate limiting
    return page, fetch_page(session, page)
|
||||
|
||||
|
||||
def main() -> None:
    """Run the scrape: credentials, resume bookkeeping, parallel page fetches,
    stop detection, and (optionally) the FTS5 rebuild."""
    args = parse_args()
    db_path = Path(args.db).resolve()

    # Credentials: .env values take priority over the process environment.
    env = _load_env()
    api_key = env.get("DANBOORU_API_KEY") or os.environ.get("DANBOORU_API_KEY")
    username = env.get("DANBOORU_USER") or os.environ.get("DANBOORU_USER")
    if not username:
        # Fallback spelling (DANBOORU_USERNAME) also used by the test script.
        username = env.get("DANBOORU_USERNAME") or os.environ.get("DANBOORU_USERNAME")

    print(f"Database : {db_path}")
    print(f"Min posts : {args.min_posts} (skip pages where all tags are below this)")
    print(f"Workers : {args.workers}")
    print(f"Auth : {'yes (' + username + ')' if (api_key and username) else 'none (public API)'}")
    print()

    conn = init_db(db_path)
    # Auth only when both pieces are present; otherwise fall back to public API.
    session = make_session(api_key if (api_key and username) else None, username)

    if args.no_resume:
        # Wipe previously scraped data and progress markers.
        print("Resetting…")
        conn.execute("DELETE FROM tags")
        conn.execute("DELETE FROM completed_pages")
        conn.commit()
        done_pages: set[int] = set()
    else:
        done_pages = get_completed_pages(conn)
        existing = conn.execute("SELECT COUNT(*) FROM tags").fetchone()[0]
        if done_pages:
            print(f"Resuming — {len(done_pages)} pages done ({existing:,} tags stored)")
        else:
            print(f"Starting fresh ({existing:,} tags in DB)")

    print()

    # Running totals and stop state for the progress display.
    total_tags = conn.execute("SELECT COUNT(*) FROM tags").fetchone()[0]
    pages_done = 0
    t0 = time.perf_counter()
    stop_flag = False

    REPORT_EVERY = 10  # keep (rather than overwrite) a progress line every N pages

    print(f"{'Page':>6} {'Done':>6} {'Min posts':>10} {'Tags':>10} {'Rate':>7} {'Elapsed':>8}")
    print("-" * 60)

    # We submit pages in batches of `workers`, process results in page order,
    # and stop as soon as we get a page where all tags are below min_posts.
    page = 1
    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        while not stop_flag and page <= MAX_PAGES:
            # Submit a window of pages
            batch_pages = []
            for _ in range(args.workers):
                while page <= MAX_PAGES and page in done_pages:
                    page += 1  # skip pages finished on an earlier run
                if page > MAX_PAGES:
                    break
                batch_pages.append(page)
                page += 1

            if not batch_pages:
                break

            futures = {
                pool.submit(_do_page, session, p): p
                for p in batch_pages
            }

            # Collect results in page order
            results: dict[int, list[dict]] = {}
            for fut in as_completed(futures):
                pg, tags = fut.result()
                results[pg] = tags

            for pg in sorted(results.keys()):
                tags = results[pg]

                if not tags:
                    # Empty page means we ran past the last tag.
                    print(f"\nPage {pg}: empty response. Stopping.")
                    stop_flag = True
                    break

                max_in_page = max(t.get("post_count", 0) for t in tags)
                min_in_page = min(t.get("post_count", 0) for t in tags)

                if max_in_page < args.min_posts:
                    # Pages are sorted by post_count DESC, so no later page qualifies.
                    print(f"\nPage {pg}: all tags have <{args.min_posts} posts (min={min_in_page}). Stopping.")
                    stop_flag = True
                    break

                stored = upsert_tags(conn, tags, min_post_count=args.min_posts)
                mark_page_done(conn, pg)
                total_tags += stored
                pages_done += 1

                elapsed = time.perf_counter() - t0
                rate = pages_done / elapsed if elapsed > 0 else 0

                line = (
                    f"{pg:>6} {pages_done:>6} {min_in_page:>10,} "
                    f"{total_tags:>10,} {rate:>5.1f}/s {elapsed/60:>6.1f}m"
                )
                # Persist the line every REPORT_EVERY pages; otherwise overwrite in place.
                if pages_done % REPORT_EVERY == 0:
                    print(line, flush=True)
                else:
                    print(f"\r{line}", end="", flush=True)

    elapsed_total = time.perf_counter() - t0
    print(f"\n{'='*60}")
    print(f"Scraping complete:")
    print(f" Tags stored : {total_tags:,}")
    print(f" Pages done : {pages_done:,}")
    print(f" Time : {elapsed_total/60:.1f} minutes")
    print()

    if not args.no_fts:
        rebuild_fts(conn)

    conn.close()
    print(f"Database saved to {db_path}")
|
||||
|
||||
|
||||
# Script entry point: python scripts/scrape_tags.py [--db ...]
if __name__ == "__main__":
    main()
|
||||
221
scripts/test_danbooru_api.py
Normal file
221
scripts/test_danbooru_api.py
Normal file
@@ -0,0 +1,221 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for the Danbooru API.
|
||||
|
||||
Verifies:
|
||||
1. Authentication with the API key works
|
||||
2. Tag listing endpoint returns expected fields
|
||||
3. Pagination works (multiple pages)
|
||||
4. Tag search / filtering by category works
|
||||
|
||||
Usage:
|
||||
python scripts/test_danbooru_api.py
|
||||
|
||||
Reads DANBOORU_API_KEY from .env or environment.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
print("ERROR: 'requests' is not installed. Run: pip install requests")
|
||||
sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load .env
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_env() -> dict[str, str]:
    """Parse KEY=VALUE lines from the project-root .env file, if present."""
    values: dict[str, str] = {}
    dotenv = Path(__file__).parent.parent / ".env"
    if not dotenv.exists():
        return values
    for raw in dotenv.read_text().splitlines():
        entry = raw.strip()
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        key, _, val = entry.partition("=")
        values[key.strip()] = val.strip()
    return values
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BASE_URL = "https://danbooru.donmai.us"

# Danbooru tag categories (numeric API values → names).
# NOTE(review): category 2 is absent here — lookups fall back to "unknown"
# at the call sites; confirm 2 is genuinely unused by the API.
CATEGORY_NAMES = {
    0: "general",
    1: "artist",
    3: "copyright",
    4: "character",
    5: "meta",
}
|
||||
|
||||
|
||||
def make_session(api_key: str | None = None, username: str | None = None) -> requests.Session:
    """Create a requests Session.

    Danbooru public endpoints (tag listing, searching) do not require
    authentication. Auth is only needed for account-specific actions.
    When provided, credentials must be (login, api_key) — NOT (user, api_key).
    """
    session = requests.Session()
    session.headers.update({"User-Agent": "danbooru-mcp-test/0.1"})
    if api_key and username:
        session.auth = (username, api_key)
    return session
|
||||
|
||||
|
||||
def get_tags_page(
    session: requests.Session,
    page: int = 1,
    limit: int = 20,
    search_name: str | None = None,
    search_category: int | None = None,
    order: str = "count",  # "count" | "name" | "date"
) -> list[dict]:
    """Fetch one page of tags from the Danbooru API."""
    query: dict = {"page": page, "limit": limit, "search[order]": order}
    # Optional filters: wildcard name match and numeric category.
    if search_name:
        query["search[name_matches]"] = search_name
    if search_category is not None:
        query["search[category]"] = search_category

    response = session.get(f"{BASE_URL}/tags.json", params=query, timeout=15)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_basic_fetch(session: requests.Session) -> None:
    """The top-5 fetch must return a non-empty list of tag dicts."""
    print("\n[1] Basic fetch — top 5 tags by post count")
    tags = get_tags_page(session, page=1, limit=5, order="count")

    assert isinstance(tags, list), f"Expected list, got {type(tags)}"
    assert len(tags) > 0, "No tags returned"

    for tag in tags:
        cat = CATEGORY_NAMES.get(tag.get("category", -1), "unknown")
        print(f" [{cat:12s}] {tag['name']:40s} posts={tag['post_count']:>8,}")
    print(" PASS ✓")
|
||||
|
||||
|
||||
def test_fields_present(session: requests.Session) -> None:
    """The first tag must carry every field the scraper depends on."""
    print("\n[2] Field presence check")
    tag = get_tags_page(session, page=1, limit=1, order="count")[0]
    required = {"id", "name", "post_count", "category", "is_deprecated", "words"}
    missing = required.difference(tag)
    assert not missing, f"Missing fields: {missing}"
    print(f" Fields present: {sorted(tag.keys())}")
    print(f" Sample tag: name={tag['name']!r} category={CATEGORY_NAMES.get(tag['category'])} deprecated={tag['is_deprecated']}")
    print(" PASS ✓")
|
||||
|
||||
|
||||
def test_pagination(session: requests.Session) -> None:
    """Pages 1 and 2 must contain disjoint tag names."""
    print("\n[3] Pagination — page 1 vs page 2 should differ")
    p1 = get_tags_page(session, page=1, limit=5, order="count")
    time.sleep(0.5)  # small pause between API calls
    p2 = get_tags_page(session, page=2, limit=5, order="count")

    names_p1 = set(t["name"] for t in p1)
    names_p2 = set(t["name"] for t in p2)
    overlap = names_p1.intersection(names_p2)
    assert not overlap, f"Pages 1 and 2 share tags: {overlap}"

    print(f" Page 1: {sorted(names_p1)}")
    print(f" Page 2: {sorted(names_p2)}")
    print(" PASS ✓")
|
||||
|
||||
|
||||
def test_category_filter(session: requests.Session) -> None:
    """Filtering on category=4 must return only character tags."""
    print("\n[4] Category filter — fetch only 'character' tags (category=4)")
    for tag in get_tags_page(session, page=1, limit=5, search_category=4, order="count"):
        assert tag["category"] == 4, f"Expected category 4, got {tag['category']} for {tag['name']}"
        print(f" {tag['name']:40s} posts={tag['post_count']:>8,}")
    print(" PASS ✓")
|
||||
|
||||
|
||||
def test_name_search(session: requests.Session) -> None:
    """A wildcard name search must return at least one match."""
    print("\n[5] Name search — tags matching 'blue_hair*'")
    tags = get_tags_page(session, page=1, limit=5, search_name="blue_hair*", order="count")
    assert len(tags) > 0, "No results for blue_hair*"
    for tag in tags:
        cat = CATEGORY_NAMES.get(tag.get("category", -1), "unknown")
        print(f" [{cat:12s}] {tag['name']:40s} posts={tag['post_count']:>8,}")
    print(" PASS ✓")
|
||||
|
||||
|
||||
def test_well_known_tags(session: requests.Session) -> None:
    """A few famous tags must exist with their exact names."""
    print("\n[6] Well-known tags — '1girl', 'blue_hair', 'sword' should exist")
    for tag_name in ("1girl", "blue_hair", "sword"):
        tags = get_tags_page(session, page=1, limit=1, search_name=tag_name, order="count")
        # Require an exact-name match, not just a wildcard hit.
        t = next((hit for hit in tags if hit["name"] == tag_name), None)
        assert t is not None, f"Tag '{tag_name}' not found in API response"
        cat = CATEGORY_NAMES.get(t.get("category", -1), "unknown")
        print(f" {tag_name:20s} category={cat:12s} posts={t['post_count']:>8,}")
    print(" PASS ✓")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main() -> None:
    """Run every API check in sequence and exit non-zero if any failed."""
    env = load_env()
    api_key = env.get("DANBOORU_API_KEY") or os.environ.get("DANBOORU_API_KEY")
    username = env.get("DANBOORU_USERNAME") or os.environ.get("DANBOORU_USERNAME")

    if api_key:
        print(f"API key loaded: {api_key[:8]}…")
    else:
        print("No API key found — using unauthenticated access (public endpoints only)")

    # Danbooru public tag endpoints don't require auth.
    # Pass username + api_key only when both are available.
    session = make_session(
        api_key=api_key if (api_key and username) else None,
        username=username,
    )

    passed = 0
    failed = 0
    for test_fn in (
        test_basic_fetch,
        test_fields_present,
        test_pagination,
        test_category_filter,
        test_name_search,
        test_well_known_tags,
    ):
        try:
            test_fn(session)
        except Exception as exc:
            print(f" FAIL ✗ {exc}")
            failed += 1
        else:
            passed += 1
        time.sleep(0.3)  # be polite to the API

    print(f"\n{'='*50}")
    print(f"Results: {passed} passed, {failed} failed")
    if failed:
        sys.exit(1)
|
||||
|
||||
|
||||
# Script entry point: python scripts/test_danbooru_api.py
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user