Initial commit
This commit is contained in:
445
src/server.py
Normal file
445
src/server.py
Normal file
@@ -0,0 +1,445 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Danbooru Tag Validator — MCP Server
|
||||
|
||||
Exposes three tools for LLMs to work with Danbooru tags:
|
||||
|
||||
search_tags – prefix / full-text search, returns rich tag objects
|
||||
validate_tags – exact-match check, returns valid/invalid split with metadata
|
||||
suggest_tags – autocomplete-style suggestions for a partial input
|
||||
|
||||
The SQLite database must be pre-built with scripts/scrape_tags.py.
|
||||
|
||||
Logging:
|
||||
All log output goes to stderr (stdout is reserved for the MCP JSON-RPC protocol).
|
||||
Log level is controlled by the LOG_LEVEL environment variable:
|
||||
DEBUG, INFO (default), WARNING, ERROR, CRITICAL
|
||||
Log format is controlled by LOG_FORMAT:
|
||||
"text" (default) — human-readable with timestamps
|
||||
"json" — structured JSON, one object per line (for log aggregators)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging setup — must go to stderr (stdout is the MCP transport)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Logging configuration read from the environment.
# Values are stripped of surrounding whitespace, and an empty value is treated
# as unset: `LOG_LEVEL=` (e.g. from an unfilled env template) would otherwise
# reach logging's setLevel() as "" and raise ValueError during import.
LOG_LEVEL = os.environ.get("LOG_LEVEL", "").strip().upper() or "INFO"
LOG_FORMAT = os.environ.get("LOG_FORMAT", "").strip().lower() or "text"
|
||||
|
||||
|
||||
class _JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    Every object carries ``ts``, ``level``, ``logger`` and ``msg``. ``exc`` and
    ``stack`` are added when the record has exception or stack information.
    Any other attribute found on the record (i.e. values passed through
    ``extra=``) is forwarded as a top-level field; values that are not JSON
    serialisable are stringified.
    """

    # Standard LogRecord attributes that must not be forwarded as "extra"
    # fields. "asctime" and "message" are only present once some formatter has
    # processed the record, but they must still be filtered so that output
    # does not depend on which handler ran first.
    _RESERVED_ATTRS = frozenset({
        "name", "msg", "args", "levelname", "levelno", "pathname",
        "filename", "module", "exc_info", "exc_text", "stack_info",
        "lineno", "funcName", "created", "msecs", "relativeCreated",
        "thread", "threadName", "processName", "process", "message",
        "asctime", "taskName",
    })

    def format(self, record: logging.LogRecord) -> str:  # noqa: A003
        """Return *record* serialised as one JSON object."""
        obj: dict = {
            "ts": self.formatTime(record, "%Y-%m-%dT%H:%M:%S"),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            obj["exc"] = self.formatException(record.exc_info)
        if record.stack_info:
            # Only set when the caller passed stack_info=True.
            obj["stack"] = self.formatStack(record.stack_info)
        # Forward any extra keyword args as top-level fields
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                obj[key] = value
        return json.dumps(obj, default=str)
|
||||
|
||||
|
||||
def _configure_logging() -> logging.Logger:
    """Attach a stderr handler to the root logger and return the server logger.

    The handler and root logger levels come from ``LOG_LEVEL``; the formatter
    is the JSON formatter when ``LOG_FORMAT`` is ``"json"`` and a timestamped
    text formatter otherwise. If ``LOG_LEVEL`` is not a level name known to the
    logging module, INFO is used and a warning is emitted instead of letting
    the ``ValueError`` abort module import.

    Returns:
        The ``danbooru_mcp`` logger.
    """
    # stderr only: stdout carries the MCP JSON-RPC stream.
    handler = logging.StreamHandler(sys.stderr)

    level = LOG_LEVEL
    level_rejected = False
    try:
        handler.setLevel(level)
    except (ValueError, TypeError):
        # setLevel() rejects unknown level names; fall back rather than crash.
        level = logging.INFO
        level_rejected = True
        handler.setLevel(level)

    if LOG_FORMAT == "json":
        handler.setFormatter(_JsonFormatter())
    else:
        handler.setFormatter(
            logging.Formatter(
                fmt="%(asctime)s [%(levelname)-8s] %(name)s — %(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        )

    root = logging.getLogger()
    root.setLevel(level)
    root.addHandler(handler)

    logger = logging.getLogger("danbooru_mcp")
    if level_rejected:
        logger.warning("Unknown LOG_LEVEL %r; falling back to INFO", LOG_LEVEL)
    return logger


log = _configure_logging()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Directory containing this module; the default database is ../db/tags.db
# relative to it.
_HERE = Path(__file__).parent
_DEFAULT_DB = _HERE.parent / "db" / "tags.db"
# DANBOORU_TAGS_DB overrides the default location. An empty value is treated
# as unset (Path("") would otherwise mean the current directory), and a
# leading "~" is expanded to the user's home directory.
DB_PATH = Path(os.environ.get("DANBOORU_TAGS_DB") or _DEFAULT_DB).expanduser()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_db() -> None:
    """Ensure the tags database file is present.

    Raises:
        FileNotFoundError: if ``DB_PATH`` does not point at an existing file.
    """
    # is_file() rather than exists(): a directory at DB_PATH is not a usable
    # database and should produce the same actionable error.
    if DB_PATH.is_file():
        return

    log.error("Database file not found", extra={"db_path": str(DB_PATH)})
    raise FileNotFoundError(
        f"Tags database not found at {DB_PATH}. "
        "Run `python scripts/scrape_tags.py` first to build it."
    )
|
||||
|
||||
|
||||
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
    """Yield a query-only SQLite connection to the tags database.

    The connection uses ``sqlite3.Row`` rows and is always closed on exit,
    including when connection setup itself fails.

    Raises:
        FileNotFoundError: if the database file is missing (from ``_check_db``).
        sqlite3.Error: if the file cannot be opened or configured.
    """
    _check_db()
    t0 = time.perf_counter()
    conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
    try:
        # Setup runs inside the try block so that a failing PRAGMA (for
        # example, when the file is not a valid database) still closes the
        # connection instead of leaking it.
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA query_only = ON")
        conn.execute("PRAGMA cache_size = -64000")  # 64 MB page cache
        log.debug("DB connection opened", extra={"db_path": str(DB_PATH)})
        yield conn
    finally:
        conn.close()
        elapsed_ms = (time.perf_counter() - t0) * 1000
        log.debug("DB connection closed", extra={"elapsed_ms": round(elapsed_ms, 2)})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP Server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Usage guidance passed to FastMCP as the server's `instructions`.
_SERVER_INSTRUCTIONS = (
    "Use this server to validate, search, and suggest Danbooru tags "
    "for Stable Diffusion / Illustrious prompts. "
    "Always call validate_tags before finalising a prompt to confirm "
    "every tag is a real, non-deprecated Danbooru tag. "
    "Tags with higher post_count are more commonly used and well-supported."
)

mcp = FastMCP("danbooru-tags", instructions=_SERVER_INSTRUCTIONS)

# Record the effective configuration once at import time; db_exists makes a
# missing database visible in the logs before the first tool call fails.
_startup_fields = {
    "db_path": str(DB_PATH),
    "db_exists": DB_PATH.exists(),
    "log_level": LOG_LEVEL,
    "log_format": LOG_FORMAT,
}
log.info("MCP server initialised", extra=_startup_fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: search_tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@mcp.tool()
def search_tags(query: str, limit: int = 20, category: str | None = None) -> list[dict]:
    """Search for Danbooru tags matching a query string.

    Uses FTS5 full-text and prefix search. Results are ordered by FTS5
    relevance, then by post count (most-used first).

    FTS5 query syntax is supported:
      - Prefix (default):  "blue_ha"  matches "blue_hair", "blue_hat", …
      - Explicit prefix:   "blue_ha*"
      - Phrase:            '"long hair"'
      - Boolean:           "hair AND blue"

    If FTS5 rejects the query (for example, tag text containing characters
    outside its bareword set), the search is retried once with the whole input
    quoted as a phrase prefix, so raw tag text can still be searched.

    Args:
        query:    The search string. A trailing '*' wildcard is added
                  automatically unless the query already ends with one.
                  An empty or whitespace-only query returns no results.
        limit:    Maximum results to return (default 20, max 200).
        category: Optional category filter. One of:
                  "general", "artist", "copyright", "character", "meta".
                  Unrecognised values are ignored (no filter is applied).

    Returns:
        List of tag objects, each with:
          name          (str)  – the exact Danbooru tag string
          post_count    (int)  – number of posts using this tag
          category      (str)  – "general" | "artist" | "copyright" | "character" | "meta"
          is_deprecated (bool) – whether the tag has been deprecated on Danbooru
    """
    t0 = time.perf_counter()
    limit = min(max(1, limit), 200)

    log.info(
        "search_tags called",
        extra={"query": query, "limit": limit, "category": category},
    )

    text = query.strip()
    if not text:
        # An empty MATCH expression is rejected by FTS5; mirror suggest_tags
        # and return nothing instead of surfacing a database error.
        log.debug("search_tags: empty query, returning early")
        return []

    fts_query = text if text.endswith("*") else text + "*"
    # Fallback form: the input as one FTS5 string literal (embedded double
    # quotes doubled) with a prefix marker.
    quoted_query = '"' + text.rstrip("*").replace('"', '""') + '"*'

    CATEGORY_MAP = {
        "general": 0, "artist": 1, "copyright": 3, "character": 4, "meta": 5
    }
    category_filter = ""
    filter_params: list = []

    if category and category.lower() in CATEGORY_MAP:
        category_filter = "AND t.category = ?"
        filter_params.append(CATEGORY_MAP[category.lower()])

    # Only the fixed filter fragment is interpolated; all values are bound.
    sql = f"""
        SELECT t.name, t.post_count, t.category_name,
               CAST(t.is_deprecated AS INTEGER) AS is_deprecated
        FROM tags_fts f
        JOIN tags t ON t.id = f.rowid
        WHERE tags_fts MATCH ?
        {category_filter}
        ORDER BY rank, t.post_count DESC
        LIMIT ?
    """

    try:
        with _get_conn() as conn:
            try:
                rows = conn.execute(
                    sql, [fts_query, *filter_params, limit]
                ).fetchall()
            except sqlite3.OperationalError:
                if quoted_query == fts_query:
                    raise
                log.debug(
                    "search_tags: query rejected, retrying as quoted phrase",
                    extra={"query": query, "fts_query": quoted_query},
                )
                fts_query = quoted_query
                rows = conn.execute(
                    sql, [fts_query, *filter_params, limit]
                ).fetchall()

        results = [
            {
                "name": row["name"],
                "post_count": row["post_count"],
                "category": row["category_name"],
                "is_deprecated": bool(row["is_deprecated"]),
            }
            for row in rows
        ]

        elapsed_ms = (time.perf_counter() - t0) * 1000
        log.info(
            "search_tags completed",
            extra={
                "query": query,
                "fts_query": fts_query,
                "category": category,
                "results": len(results),
                "elapsed_ms": round(elapsed_ms, 2),
            },
        )
        return results

    except Exception:
        log.exception("search_tags failed", extra={"query": query})
        raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: validate_tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@mcp.tool()
def validate_tags(tags: list[str]) -> dict:
    """Validate a list of Danbooru tags, returning valid and invalid sets.

    Performs exact-match lookup against the full Danbooru tag database.
    Also flags deprecated tags — they technically exist but should be
    replaced with their canonical equivalents.

    Use this before submitting a prompt to Stable Diffusion.

    Input tags are stripped of surrounding whitespace, blank entries are
    dropped, and duplicates are collapsed (first occurrence order is kept).

    Args:
        tags: A list of tag strings to validate (e.g. ["blue_hair", "1girl"]).

    Returns:
        A dict with three keys:
          "valid"      – tags that exist and are not deprecated
          "deprecated" – tags that exist but are deprecated (should be replaced)
          "invalid"    – tags that were not found (misspelled or invented)
    """
    t0 = time.perf_counter()

    log.info(
        "validate_tags called",
        extra={"tag_count": len(tags), "tags_sample": tags[:5]},
    )

    # Deduplicate, preserve order (dict keys keep insertion order)
    unique_tags = [t for t in dict.fromkeys(raw.strip() for raw in tags) if t]

    if not unique_tags:
        log.debug("validate_tags: empty input, returning early")
        return {"valid": [], "deprecated": [], "invalid": []}

    # SQLite caps the number of bound parameters per statement (999 by default
    # in older builds), so large inputs are looked up in batches.
    batch_size = 500

    try:
        found: dict[str, bool] = {}
        with _get_conn() as conn:
            for start in range(0, len(unique_tags), batch_size):
                batch = unique_tags[start:start + batch_size]
                placeholders = ",".join("?" * len(batch))
                rows = conn.execute(
                    f"""
                    SELECT name, is_deprecated
                    FROM tags
                    WHERE name IN ({placeholders})
                    """,
                    batch,
                ).fetchall()
                found.update(
                    (row["name"], bool(row["is_deprecated"])) for row in rows
                )

        valid = [t for t in unique_tags if t in found and not found[t]]
        deprecated = [t for t in unique_tags if t in found and found[t]]
        invalid = [t for t in unique_tags if t not in found]

        elapsed_ms = (time.perf_counter() - t0) * 1000
        log.info(
            "validate_tags completed",
            extra={
                "total": len(unique_tags),
                "valid": len(valid),
                "deprecated": len(deprecated),
                "invalid": len(invalid),
                "invalid_tags": invalid[:10],  # log first 10 invalid for debugging
                "elapsed_ms": round(elapsed_ms, 2),
            },
        )

        return {"valid": valid, "deprecated": deprecated, "invalid": invalid}

    except Exception:
        log.exception("validate_tags failed", extra={"tags_sample": tags[:5]})
        raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: suggest_tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@mcp.tool()
def suggest_tags(partial: str, limit: int = 10, category: str | None = None) -> list[dict]:
    """Get tag suggestions for a partial or approximate tag input.

    Runs a prefix search against the FTS5 index to find the closest
    existing Danbooru tags, ordered by post count (most-used first).
    Useful when the LLM is unsure of the exact spelling or wants to
    explore available tags for a concept.

    Deprecated tags are excluded from suggestions by default.

    If FTS5 rejects the input as a query expression (for example, partial tag
    text containing characters outside its bareword set), the lookup is
    retried once with the input quoted as a phrase prefix.

    Args:
        partial:  A partial tag string (e.g. "blue_ha" → "blue_hair").
                  An empty or whitespace-only value returns no suggestions.
        limit:    Maximum suggestions to return (default 10, max 50).
        category: Optional category filter. One of:
                  "general", "artist", "copyright", "character", "meta".
                  Unrecognised values are ignored (no filter is applied).

    Returns:
        List of tag objects (same shape as search_tags), sorted by
        post_count descending. Deprecated tags are excluded.
    """
    t0 = time.perf_counter()
    limit = min(max(1, limit), 50)

    log.info(
        "suggest_tags called",
        extra={"partial": partial, "limit": limit, "category": category},
    )

    text = partial.strip()
    if not text:
        log.debug("suggest_tags: empty partial, returning early")
        return []

    fts_query = text if text.endswith("*") else text + "*"
    # Fallback form: the input as one FTS5 string literal (embedded double
    # quotes doubled) with a prefix marker.
    quoted_query = '"' + text.rstrip("*").replace('"', '""') + '"*'

    CATEGORY_MAP = {
        "general": 0, "artist": 1, "copyright": 3, "character": 4, "meta": 5
    }
    category_filter = ""
    filter_params: list = []

    if category and category.lower() in CATEGORY_MAP:
        category_filter = "AND t.category = ?"
        filter_params.append(CATEGORY_MAP[category.lower()])

    # Only the fixed filter fragment is interpolated; all values are bound.
    # is_deprecated is not selected: the WHERE clause already restricts the
    # result to non-deprecated tags.
    sql = f"""
        SELECT t.name, t.post_count, t.category_name
        FROM tags_fts f
        JOIN tags t ON t.id = f.rowid
        WHERE tags_fts MATCH ?
        AND t.is_deprecated = 0
        {category_filter}
        ORDER BY t.post_count DESC
        LIMIT ?
    """

    try:
        with _get_conn() as conn:
            try:
                rows = conn.execute(
                    sql, [fts_query, *filter_params, limit]
                ).fetchall()
            except sqlite3.OperationalError:
                if quoted_query == fts_query:
                    raise
                log.debug(
                    "suggest_tags: query rejected, retrying as quoted phrase",
                    extra={"partial": partial, "fts_query": quoted_query},
                )
                fts_query = quoted_query
                rows = conn.execute(
                    sql, [fts_query, *filter_params, limit]
                ).fetchall()

        results = [
            {
                "name": row["name"],
                "post_count": row["post_count"],
                "category": row["category_name"],
                "is_deprecated": False,
            }
            for row in rows
        ]

        elapsed_ms = (time.perf_counter() - t0) * 1000
        log.info(
            "suggest_tags completed",
            extra={
                "partial": partial,
                "fts_query": fts_query,
                "category": category,
                "results": len(results),
                "elapsed_ms": round(elapsed_ms, 2),
            },
        )
        return results

    except Exception:
        log.exception("suggest_tags failed", extra={"partial": partial})
        raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
    """Log startup and run the MCP server until the transport closes."""
    log.info("Starting MCP server (stdio transport)")
    # "stdio" is FastMCP's default transport; spelled out to match the log line.
    mcp.run(transport="stdio")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user