#!/usr/bin/env python3 """Back-fill the `nsfw` flag and prompt details for existing records in `image_metadata`. Run this once after adding / changing the keyword list in ``config.NSFW_KEYWORDS``. Usage:: python update_nsfw_flags.py The script will: 1. Ensure the ``nsfw`` column exists (creating it if necessary). 2. Reset all flags to 0. 3. Scan every row of ``image_metadata``. 4. Set ``nsfw`` to **1** if any keyword is present in the image's ``prompt_data`` (case-insensitive), otherwise **0**. 5. Report how many rows were updated. It is safe to re-run – unchanged rows are left untouched. In addition this script will run the prompt parser for every image that has `prompt_data` so that the `prompt_details` table (including new *loras* and *textual inversions* columns) is fully populated. """ from __future__ import annotations import sqlite3 import sys from typing import List from config import DB_PATH, NSFW_KEYWORDS import database as db # Optional progress bar support try: from tqdm import tqdm # type: ignore except ImportError: # graceful fallback if tqdm not installed tqdm = None # --------------------------------------------------------------------------- # Helper functions # --------------------------------------------------------------------------- def _ensure_column_exists(cur: sqlite3.Cursor) -> None: """Create the ``nsfw`` column if an old DB is missing it.""" cur.execute("PRAGMA table_info(image_metadata)") cols = {row[1] for row in cur.fetchall()} if "nsfw" not in cols: print("Adding missing 'nsfw' column to image_metadata …", flush=True) cur.execute("ALTER TABLE image_metadata ADD COLUMN nsfw INTEGER DEFAULT 0") def _determine_nsfw(prompt: str | None, keywords: List[str]) -> int: if not prompt: return 1 lower_prompt = prompt.lower() return int(any(kw.lower() in lower_prompt for kw in keywords)) # --------------------------------------------------------------------------- # Main routine # 
# ---------------------------------------------------------------------------

def main() -> None:
    """Re-parse stored prompts, then recompute the ``nsfw`` flag of every image.

    Steps, in order:

    1. ``db.init_db()`` and ``_ensure_column_exists`` bring the schema up to date.
    2. ``db.parse_and_store_prompt_details`` is called for every row of
       ``image_metadata`` that has a non-empty ``prompt_data``.
    3. All ``nsfw`` flags are reset to 0, then each image whose positive prompt
       matches ``NSFW_KEYWORDS`` is set to 1.
    4. The changes are committed and a summary is printed.

    The connection is always closed, even when a step raises; in that case
    nothing is committed by this function.
    """
    # Ensure all tables/columns exist (especially prompt_details extra cols)
    db.init_db()

    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        _ensure_column_exists(cur)

        # ------------------------------------------------------------------
        # Run prompt parser for every row having prompt_data
        # (ensures prompt_details populated)
        # ------------------------------------------------------------------
        cur.execute("SELECT path, prompt_data FROM image_metadata WHERE prompt_data IS NOT NULL")
        prompt_rows = cur.fetchall()
        total_prompts = len(prompt_rows)
        parsed = 0

        if tqdm:
            iterable_prompts = tqdm(prompt_rows, desc="Parsing prompts", unit="img")  # type: ignore[arg-type]
        else:
            iterable_prompts = prompt_rows

        for done, (img_path, prompt) in enumerate(iterable_prompts, start=1):  # type: ignore[assignment]
            # The query already excludes NULL; this also skips empty strings.
            if prompt:
                db.parse_and_store_prompt_details(img_path, prompt)
                parsed += 1
            # Plain-text progress only when no progress bar is available.
            if not tqdm and done % 1000 == 0:
                print(f"Parsing prompts: {done}/{total_prompts} done…", flush=True)

        # ------------------------------------------------------------------
        # Determine NSFW based on POSITIVE PROMPT only
        # ------------------------------------------------------------------
        # First reset all flags to 0 so removed keywords clear previous NSFW marks
        cur.execute("UPDATE image_metadata SET nsfw = 0")

        # LEFT JOIN keeps images without a prompt_details row; COALESCE turns
        # their missing positive prompt into an empty string.
        cur.execute("""
            SELECT meta.id, COALESCE(pd.positive_prompt, '') AS positive_prompt, meta.nsfw
            FROM image_metadata meta
            LEFT JOIN prompt_details pd ON meta.path = pd.image_path
        """)
        rows = cur.fetchall()
        total_rows = len(rows)
        updated = 0

        if tqdm:
            iterable_rows = tqdm(rows, desc="NSFW tagging", unit="img")  # type: ignore[arg-type]
        else:
            iterable_rows = rows

        for done, (img_id, positive_prompt, current_flag) in enumerate(iterable_rows, start=1):  # type: ignore[assignment]
            desired_flag = _determine_nsfw(positive_prompt, NSFW_KEYWORDS)
            # Only write rows whose flag actually differs from the stored one.
            if desired_flag != current_flag:
                cur.execute(
                    "UPDATE image_metadata SET nsfw = ? WHERE id = ?",
                    (desired_flag, img_id),
                )
                updated += 1
            if not tqdm and done % 1000 == 0:
                print(f"NSFW tagging: {done}/{total_rows} done…", flush=True)

        conn.commit()
    finally:
        # Release the database handle on every exit path (success, error,
        # or Ctrl-C) instead of only after a successful run.
        conn.close()

    print(f"NSFW flags updated for {updated} images.")
    print(f"Prompt details parsed for {parsed} images.")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # 130 = 128 + SIGINT, the conventional exit status for Ctrl-C.
        sys.exit(130)