Files
swiper/update_nsfw_flags.py
2025-06-25 04:21:13 +01:00

134 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""Back-fill the `nsfw` flag and prompt details for existing records in `image_metadata`.
Run this once after adding / changing the keyword list in ``config.NSFW_KEYWORDS``.
Usage::
python update_nsfw_flags.py
The script will:
1. Ensure the ``nsfw`` column exists (creating it if necessary).
2. Clear stale flags, so keywords removed from the list no longer mark any image.
3. Scan every row of ``image_metadata``.
4. Set ``nsfw`` to **1** if any keyword is present in the image's positive prompt
(as stored in ``prompt_details``; case-insensitive), otherwise **0**.
5. Report how many rows were updated.
It is safe to re-run: each run recomputes every flag from the current keyword list.
In addition, this script runs the prompt parser for every image that has `prompt_data` so that the `prompt_details` table (including the new *loras* and *textual inversions* columns) is fully populated.
"""
from __future__ import annotations
import sqlite3
import sys
from typing import List
from config import DB_PATH, NSFW_KEYWORDS
import database as db
# Optional progress bar support
try:
from tqdm import tqdm # type: ignore
except ImportError: # graceful fallback if tqdm not installed
tqdm = None
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
def _ensure_column_exists(cur: sqlite3.Cursor) -> None:
"""Create the ``nsfw`` column if an old DB is missing it."""
cur.execute("PRAGMA table_info(image_metadata)")
cols = {row[1] for row in cur.fetchall()}
if "nsfw" not in cols:
print("Adding missing 'nsfw' column to image_metadata …", flush=True)
cur.execute("ALTER TABLE image_metadata ADD COLUMN nsfw INTEGER DEFAULT 0")
def _determine_nsfw(prompt: str | None, keywords: List[str]) -> int:
if not prompt:
return 1
lower_prompt = prompt.lower()
return int(any(kw.lower() in lower_prompt for kw in keywords))
# ---------------------------------------------------------------------------
# Main routine
# ---------------------------------------------------------------------------
def _iter_with_progress(rows: list, desc: str):
    """Yield each item of *rows* while reporting progress under the label *desc*.

    Uses a tqdm bar when tqdm is importable; otherwise prints a
    ``"<desc>: <done>/<total> done…"`` line after every 1000th item.
    """
    if tqdm:
        yield from tqdm(rows, desc=desc, unit="img")  # type: ignore[misc]
        return
    total = len(rows)
    for done, item in enumerate(rows, start=1):
        yield item
        # The generator resumes only after the caller has finished processing
        # `item`, so the message reports completed work.
        if done % 1000 == 0:
            print(f"{desc}: {done}/{total} done…", flush=True)


def main() -> None:
    """Re-parse stored prompts and recompute the ``nsfw`` flag for every image.

    Steps:
    1. ``db.init_db()`` so all tables/columns exist (including the extra
       ``prompt_details`` columns).
    2. Add the ``nsfw`` column to old databases that lack it.
    3. Call ``db.parse_and_store_prompt_details`` for every row with non-empty
       ``prompt_data``.
    4. For every ``image_metadata`` row, compute the flag from the positive
       prompt stored in ``prompt_details`` (empty when there is none). Write it
       only when it differs from the stored value.

    The connection is always closed. If an error (or Ctrl-C) interrupts the
    tagging pass, its uncommitted flag updates are discarded.
    """
    # Ensure all tables/columns exist (especially prompt_details extra cols)
    db.init_db()
    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        _ensure_column_exists(cur)

        # --------------------------------------------------------------
        # Run the prompt parser for every row having prompt_data
        # (ensures prompt_details is populated)
        # --------------------------------------------------------------
        cur.execute("SELECT path, prompt_data FROM image_metadata WHERE prompt_data IS NOT NULL")
        prompt_rows = cur.fetchall()
        parsed = 0
        for img_path, prompt in _iter_with_progress(prompt_rows, "Parsing prompts"):
            if prompt:  # NULLs are excluded by the query; also skip empty strings
                db.parse_and_store_prompt_details(img_path, prompt)
                parsed += 1

        # --------------------------------------------------------------
        # Determine NSFW based on POSITIVE PROMPT only
        # --------------------------------------------------------------
        # There is deliberately no blanket "SET nsfw = 0" first. The LEFT JOIN
        # yields every image_metadata row, so each flag is recomputed anyway:
        # marks left by removed keywords are still cleared, and only rows whose
        # flag actually changes are written and counted.
        # NOTE(review): assumes at most one prompt_details row per image_path;
        # duplicates would repeat meta.id. TODO confirm the schema.
        cur.execute("""
            SELECT meta.id, COALESCE(pd.positive_prompt, '') AS positive_prompt, meta.nsfw
            FROM image_metadata meta
            LEFT JOIN prompt_details pd ON meta.path = pd.image_path
        """)
        rows = cur.fetchall()
        updated = 0
        for img_id, positive_prompt, current_flag in _iter_with_progress(rows, "NSFW tagging"):
            desired_flag = _determine_nsfw(positive_prompt, NSFW_KEYWORDS)
            # A NULL stored flag (None) also differs, so it gets normalized.
            if desired_flag != current_flag:
                cur.execute("UPDATE image_metadata SET nsfw = ? WHERE id = ?", (desired_flag, img_id))
                updated += 1
        conn.commit()
    finally:
        conn.close()
    print(f"NSFW flags updated for {updated} images.")
    print(f"Prompt details parsed for {parsed} images.")
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: stop quietly, with no traceback, using exit status 130
        # (the shell convention 128 + SIGINT).
        sys.exit(130)