Initial commit

This commit is contained in:
Aodhan Collins
2026-03-02 23:29:58 +00:00
commit 08c6e14616
12 changed files with 2121 additions and 0 deletions

445
src/server.py Normal file
View File

@@ -0,0 +1,445 @@
#!/usr/bin/env python3
"""
Danbooru Tag Validator — MCP Server
Exposes three tools for LLMs to work with Danbooru tags:
search_tags prefix / full-text search, returns rich tag objects
validate_tags exact-match check, returns valid/invalid split with metadata
suggest_tags autocomplete-style suggestions for a partial input
The SQLite database must be pre-built with scripts/scrape_tags.py.
Logging:
All log output goes to stderr (stdout is reserved for the MCP JSON-RPC protocol).
Log level is controlled by the LOG_LEVEL environment variable:
DEBUG, INFO (default), WARNING, ERROR, CRITICAL
Log format is controlled by LOG_FORMAT:
"text" (default) — human-readable with timestamps
"json" — structured JSON, one object per line (for log aggregators)
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import sys
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from mcp.server.fastmcp import FastMCP
# ---------------------------------------------------------------------------
# Logging setup — must go to stderr (stdout is the MCP transport)
# ---------------------------------------------------------------------------
# Logging configuration, read once at import time (see module docstring):
# LOG_LEVEL selects verbosity (DEBUG/INFO/WARNING/ERROR/CRITICAL, default INFO),
# LOG_FORMAT selects "text" (human-readable) or "json" (one object per line).
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
LOG_FORMAT = os.environ.get("LOG_FORMAT", "text").lower()
class _JsonFormatter(logging.Formatter):
"""Emit one JSON object per log record to stderr."""
def format(self, record: logging.LogRecord) -> str: # noqa: A003
obj: dict = {
"ts": self.formatTime(record, "%Y-%m-%dT%H:%M:%S"),
"level": record.levelname,
"logger": record.name,
"msg": record.getMessage(),
}
if record.exc_info:
obj["exc"] = self.formatException(record.exc_info)
# Forward any extra keyword args as top-level fields
for k, v in record.__dict__.items():
if k not in {
"name", "msg", "args", "levelname", "levelno", "pathname",
"filename", "module", "exc_info", "exc_text", "stack_info",
"lineno", "funcName", "created", "msecs", "relativeCreated",
"thread", "threadName", "processName", "process", "message",
"taskName",
}:
obj[k] = v
return json.dumps(obj, default=str)
def _configure_logging() -> logging.Logger:
    """Attach a stderr handler to the root logger and return the app logger.

    All log output must go to stderr because stdout carries the MCP
    JSON-RPC transport. Level and format come from the LOG_LEVEL and
    LOG_FORMAT environment variables (read at import time).

    Returns:
        The "danbooru_mcp" logger used throughout this module.
    """
    handler = logging.StreamHandler(sys.stderr)
    handler.setLevel(LOG_LEVEL)
    if LOG_FORMAT == "json":
        handler.setFormatter(_JsonFormatter())
    else:
        handler.setFormatter(
            logging.Formatter(
                # BUG FIX: fmt previously read "%(name)s%(message)s" with no
                # separator, producing run-together lines such as
                # "danbooru_mcpsearch_tags called".
                fmt="%(asctime)s [%(levelname)-8s] %(name)s: %(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        )
    root = logging.getLogger()
    root.setLevel(LOG_LEVEL)
    root.addHandler(handler)
    return logging.getLogger("danbooru_mcp")


log = _configure_logging()
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Directory containing this file; the default database lives at
# <repo_root>/db/tags.db relative to src/.
_HERE = Path(__file__).parent
_DEFAULT_DB = _HERE.parent / "db" / "tags.db"
# The DANBOORU_TAGS_DB environment variable overrides the database location.
DB_PATH = Path(os.environ.get("DANBOORU_TAGS_DB", str(_DEFAULT_DB)))
# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------
def _check_db() -> None:
    """Raise FileNotFoundError if the tags database has not been built yet."""
    if DB_PATH.exists():
        return
    log.error("Database file not found", extra={"db_path": str(DB_PATH)})
    raise FileNotFoundError(
        f"Tags database not found at {DB_PATH}. "
        "Run `python scripts/scrape_tags.py` first to build it."
    )
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
    """Yield a read-only SQLite connection to the tags DB, closing it on exit.

    The connection is opened fresh per call; open/close timing is logged
    at DEBUG level for diagnostics.
    """
    _check_db()
    started = time.perf_counter()
    connection = sqlite3.connect(str(DB_PATH), check_same_thread=False)
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA query_only = ON")
    connection.execute("PRAGMA cache_size = -64000")  # 64 MB page cache
    log.debug("DB connection opened", extra={"db_path": str(DB_PATH)})
    try:
        yield connection
    finally:
        connection.close()
        elapsed_ms = (time.perf_counter() - started) * 1000
        log.debug("DB connection closed", extra={"elapsed_ms": round(elapsed_ms, 2)})
# ---------------------------------------------------------------------------
# MCP Server
# ---------------------------------------------------------------------------
# The FastMCP server instance; `instructions` is surfaced to connected LLM
# clients to guide how they should use the three tools below.
mcp = FastMCP(
    "danbooru-tags",
    instructions=(
        "Use this server to validate, search, and suggest Danbooru tags "
        "for Stable Diffusion / Illustrious prompts. "
        "Always call validate_tags before finalising a prompt to confirm "
        "every tag is a real, non-deprecated Danbooru tag. "
        "Tags with higher post_count are more commonly used and well-supported."
    ),
)
# Startup breadcrumb: records the resolved configuration so a misconfigured
# DB path is immediately visible in the logs (db_exists=False).
log.info(
    "MCP server initialised",
    extra={
        "db_path": str(DB_PATH),
        "db_exists": DB_PATH.exists(),
        "log_level": LOG_LEVEL,
        "log_format": LOG_FORMAT,
    },
)
# ---------------------------------------------------------------------------
# Tool: search_tags
# ---------------------------------------------------------------------------
@mcp.tool()
def search_tags(query: str, limit: int = 20, category: str | None = None) -> list[dict]:
    """Search for Danbooru tags matching a query string.

    Uses FTS5 full-text and prefix search. Results are ordered by FTS5
    relevance, then by post count (most-used first).

    FTS5 query syntax is supported:
      - Prefix (default): "blue_ha" matches "blue_hair", "blue_hat", ...
      - Explicit prefix: "blue_ha*"
      - Phrase: '"long hair"'
      - Boolean: "hair AND blue"

    Args:
        query: The search string. A trailing '*' wildcard is added
            automatically unless the query already ends with one.
            An empty or whitespace-only query returns no results.
        limit: Maximum results to return (default 20, max 200).
        category: Optional category filter. One of:
            "general", "artist", "copyright", "character", "meta"

    Returns:
        List of tag objects, each with:
            name (str): the exact Danbooru tag string
            post_count (int): number of posts using this tag
            category (str): "general" | "artist" | "copyright" | "character" | "meta"
            is_deprecated (bool): whether the tag has been deprecated on Danbooru
    """
    t0 = time.perf_counter()
    limit = min(max(1, limit), 200)  # clamp to [1, 200]
    log.info(
        "search_tags called",
        extra={"query": query, "limit": limit, "category": category},
    )
    fts_query = query.strip()
    # BUG FIX: an empty query previously reached FTS5 as MATCH '', which
    # raises sqlite3.OperationalError (fts5 syntax error). Return early
    # instead, matching suggest_tags' handling of empty input.
    if not fts_query:
        log.debug("search_tags: empty query, returning early")
        return []
    if not fts_query.endswith("*"):
        fts_query = fts_query + "*"
    # Danbooru numeric category codes (2 is unused on Danbooru).
    CATEGORY_MAP = {
        "general": 0, "artist": 1, "copyright": 3, "character": 4, "meta": 5
    }
    category_filter = ""
    params: list = [fts_query]
    if category and category.lower() in CATEGORY_MAP:
        category_filter = "AND t.category = ?"
        params.append(CATEGORY_MAP[category.lower()])
    params.append(limit)
    try:
        with _get_conn() as conn:
            rows = conn.execute(
                f"""
                SELECT t.name, t.post_count, t.category_name,
                       CAST(t.is_deprecated AS INTEGER) AS is_deprecated
                FROM tags_fts f
                JOIN tags t ON t.id = f.rowid
                WHERE tags_fts MATCH ?
                {category_filter}
                ORDER BY rank, t.post_count DESC
                LIMIT ?
                """,
                params,
            ).fetchall()
        results = [
            {
                "name": row["name"],
                "post_count": row["post_count"],
                "category": row["category_name"],
                "is_deprecated": bool(row["is_deprecated"]),
            }
            for row in rows
        ]
        elapsed_ms = (time.perf_counter() - t0) * 1000
        log.info(
            "search_tags completed",
            extra={
                "query": query,
                "fts_query": fts_query,
                "category": category,
                "results": len(results),
                "elapsed_ms": round(elapsed_ms, 2),
            },
        )
        return results
    except Exception:
        log.exception("search_tags failed", extra={"query": query})
        raise
# ---------------------------------------------------------------------------
# Tool: validate_tags
# ---------------------------------------------------------------------------
@mcp.tool()
def validate_tags(tags: list[str]) -> dict:
    """Validate a list of Danbooru tags, returning valid and invalid sets.

    Performs exact-match lookup against the full Danbooru tag database.
    Also flags deprecated tags — they technically exist but should be
    replaced with their canonical equivalents.
    Use this before submitting a prompt to Stable Diffusion.

    Args:
        tags: A list of tag strings to validate (e.g. ["blue_hair", "1girl"]).
            Entries are stripped of surrounding whitespace and deduplicated;
            empty entries are ignored.

    Returns:
        A dict with three keys:
            "valid": tags that exist and are not deprecated
            "deprecated": tags that exist but are deprecated (should be replaced)
            "invalid": tags that were not found (misspelled or invented)
    """
    t0 = time.perf_counter()
    log.info(
        "validate_tags called",
        extra={"tag_count": len(tags), "tags_sample": tags[:5]},
    )
    if not tags:
        log.debug("validate_tags: empty input, returning early")
        return {"valid": [], "deprecated": [], "invalid": []}
    # Deduplicate while preserving first-seen order (dicts keep insertion order).
    seen: dict[str, None] = {}
    for t in tags:
        seen[t.strip()] = None
    unique_tags = [t for t in seen if t]
    # BUG FIX: input containing only empty/whitespace strings passed the
    # `if not tags` guard but left unique_tags empty, producing invalid SQL
    # ("WHERE name IN ()"). Return early instead.
    if not unique_tags:
        log.debug("validate_tags: no non-empty tags after stripping")
        return {"valid": [], "deprecated": [], "invalid": []}
    placeholders = ",".join("?" * len(unique_tags))
    try:
        with _get_conn() as conn:
            rows = conn.execute(
                f"""
                SELECT name, is_deprecated
                FROM tags
                WHERE name IN ({placeholders})
                """,
                unique_tags,
            ).fetchall()
        # name -> is_deprecated for every tag that exists in the DB
        found: dict[str, bool] = {
            row["name"]: bool(row["is_deprecated"]) for row in rows
        }
        valid = [t for t in unique_tags if t in found and not found[t]]
        deprecated = [t for t in unique_tags if t in found and found[t]]
        invalid = [t for t in unique_tags if t not in found]
        elapsed_ms = (time.perf_counter() - t0) * 1000
        log.info(
            "validate_tags completed",
            extra={
                "total": len(unique_tags),
                "valid": len(valid),
                "deprecated": len(deprecated),
                "invalid": len(invalid),
                "invalid_tags": invalid[:10],  # log first 10 invalid for debugging
                "elapsed_ms": round(elapsed_ms, 2),
            },
        )
        return {"valid": valid, "deprecated": deprecated, "invalid": invalid}
    except Exception:
        log.exception("validate_tags failed", extra={"tags_sample": tags[:5]})
        raise
# ---------------------------------------------------------------------------
# Tool: suggest_tags
# ---------------------------------------------------------------------------
@mcp.tool()
def suggest_tags(partial: str, limit: int = 10, category: str | None = None) -> list[dict]:
    """Suggest existing Danbooru tags that complete a partial input.

    Runs a prefix search against the FTS5 index to find the closest
    existing Danbooru tags, ordered by post count (most-used first).
    Useful when the LLM is unsure of the exact spelling or wants to
    explore available tags for a concept. Deprecated tags are excluded.

    Args:
        partial: A partial tag string (e.g. "blue_ha" -> "blue_hair").
        limit: Maximum suggestions to return (default 10, max 50).
        category: Optional category filter. One of:
            "general", "artist", "copyright", "character", "meta"

    Returns:
        List of tag objects (same shape as search_tags), sorted by
        post_count descending. Deprecated tags are excluded.
    """
    started = time.perf_counter()
    limit = min(max(1, limit), 50)  # clamp to [1, 50]
    log.info(
        "suggest_tags called",
        extra={"partial": partial, "limit": limit, "category": category},
    )
    term = partial.strip()
    if not term:
        log.debug("suggest_tags: empty partial, returning early")
        return []
    # FTS5 prefix search: ensure exactly one trailing wildcard.
    fts_query = term if term.endswith("*") else term + "*"
    # Danbooru numeric category codes (2 is unused on Danbooru).
    categories = {
        "general": 0, "artist": 1, "copyright": 3, "character": 4, "meta": 5
    }
    category_filter = ""
    sql_params: list = [fts_query]
    code = categories.get(category.lower()) if category else None
    if code is not None:
        category_filter = "AND t.category = ?"
        sql_params.append(code)
    sql_params.append(limit)
    try:
        with _get_conn() as conn:
            cursor = conn.execute(
                f"""
                SELECT t.name, t.post_count, t.category_name,
                       CAST(t.is_deprecated AS INTEGER) AS is_deprecated
                FROM tags_fts f
                JOIN tags t ON t.id = f.rowid
                WHERE tags_fts MATCH ?
                AND t.is_deprecated = 0
                {category_filter}
                ORDER BY t.post_count DESC
                LIMIT ?
                """,
                sql_params,
            )
            rows = cursor.fetchall()
        suggestions = []
        for row in rows:
            suggestions.append(
                {
                    "name": row["name"],
                    "post_count": row["post_count"],
                    "category": row["category_name"],
                    # always False: deprecated tags are filtered out in SQL
                    "is_deprecated": False,
                }
            )
        elapsed_ms = (time.perf_counter() - started) * 1000
        log.info(
            "suggest_tags completed",
            extra={
                "partial": partial,
                "fts_query": fts_query,
                "category": category,
                "results": len(suggestions),
                "elapsed_ms": round(elapsed_ms, 2),
            },
        )
        return suggestions
    except Exception:
        log.exception("suggest_tags failed", extra={"partial": partial})
        raise
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Run the MCP server over stdio (blocks until the client disconnects)."""
    log.info("Starting MCP server (stdio transport)")
    mcp.run()


if __name__ == "__main__":
    main()