134 lines
4.6 KiB
Python
134 lines
4.6 KiB
Python
#!/usr/bin/env python3
|
||
"""Back-fill the `nsfw` flag and prompt details for existing records in `image_metadata`.
|
||
|
||
Run this after adding, removing, or changing keywords in ``config.NSFW_KEYWORDS``.
|
||
|
||
Usage::
|
||
|
||
python update_nsfw_flags.py
|
||
|
||
The script will:
|
||
1. Ensure the ``nsfw`` column exists (creating it if necessary).
|
||
2. Clear stale flags so that keywords removed from the list no longer mark images.
|
||
3. Scan every row of ``image_metadata``.
|
||
4. Set ``nsfw`` to **1** if the image's parsed positive prompt
   (``prompt_details.positive_prompt``) contains any keyword as a case-insensitive
   substring, otherwise **0**.
|
||
5. Report how many rows were updated.
|
||
|
||
It is safe to re-run: the resulting flags depend only on the current keyword list and the stored prompts.
|
||
|
||
Before tagging, the script also runs the prompt parser for every image with non-empty ``prompt_data`` so that the ``prompt_details`` table (including the new *loras* and *textual inversions* columns) is fully populated; NSFW detection reads the positive prompt from that table.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import sqlite3
|
||
import sys
|
||
from typing import List
|
||
|
||
from config import DB_PATH, NSFW_KEYWORDS
|
||
import database as db
|
||
|
||
# Optional progress bar support
|
||
try:
|
||
from tqdm import tqdm # type: ignore
|
||
except ImportError: # graceful fallback if tqdm not installed
|
||
tqdm = None
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helper functions
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _ensure_column_exists(cur: sqlite3.Cursor) -> None:
|
||
"""Create the ``nsfw`` column if an old DB is missing it."""
|
||
cur.execute("PRAGMA table_info(image_metadata)")
|
||
cols = {row[1] for row in cur.fetchall()}
|
||
if "nsfw" not in cols:
|
||
print("Adding missing 'nsfw' column to image_metadata …", flush=True)
|
||
cur.execute("ALTER TABLE image_metadata ADD COLUMN nsfw INTEGER DEFAULT 0")
|
||
|
||
|
||
def _determine_nsfw(prompt: str | None, keywords: List[str]) -> int:
|
||
if not prompt:
|
||
return 1
|
||
lower_prompt = prompt.lower()
|
||
return int(any(kw.lower() in lower_prompt for kw in keywords))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main routine
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _progress(items: list, desc: str):
    """Wrap *items* in a tqdm progress bar when tqdm is installed, else return them unchanged."""
    if tqdm:
        return tqdm(items, desc=desc, unit="img")  # type: ignore[arg-type]
    return items


def _report_fallback_progress(label: str, done: int, total: int) -> None:
    """Print a plain progress line every 1000 items when tqdm is unavailable."""
    if not tqdm and done % 1000 == 0:
        print(f"{label}: {done}/{total} done…", flush=True)


def _parse_all_prompts(cur: sqlite3.Cursor) -> int:
    """Run the prompt parser for every image with non-empty ``prompt_data``.

    This keeps ``prompt_details`` (including the loras / textual inversions
    columns) fully populated before NSFW tagging reads from it.

    Returns:
        The number of images passed to ``db.parse_and_store_prompt_details``.
    """
    cur.execute("SELECT path, prompt_data FROM image_metadata WHERE prompt_data IS NOT NULL")
    rows = cur.fetchall()
    total = len(rows)
    parsed = 0
    for done, (img_path, prompt) in enumerate(_progress(rows, "Parsing prompts"), start=1):
        if prompt:
            db.parse_and_store_prompt_details(img_path, prompt)
            parsed += 1
        _report_fallback_progress("Parsing prompts", done, total)
    return parsed


def _update_nsfw_flags(cur: sqlite3.Cursor) -> int:
    """Recompute ``image_metadata.nsfw`` from each image's parsed positive prompt.

    The flag is computed for every row, so flags set by keywords that were
    since removed are cleared. Only rows whose stored flag differs from the
    newly computed one are written. The writes are left uncommitted for the
    caller to commit.

    Returns:
        The number of images whose flag actually changed.
    """
    # LEFT JOIN keeps images without prompt details; their prompt becomes ''.
    cur.execute("""
        SELECT meta.id, COALESCE(pd.positive_prompt, '') AS positive_prompt, meta.nsfw
        FROM image_metadata meta
        LEFT JOIN prompt_details pd ON meta.path = pd.image_path
    """)
    rows = cur.fetchall()
    total = len(rows)

    desired: dict[int, int] = {}
    stored: dict[int, int | None] = {}
    for done, (img_id, positive_prompt, current_flag) in enumerate(
        _progress(rows, "NSFW tagging"), start=1
    ):
        # An image may join to several prompt_details rows; it is NSFW if any match.
        desired[img_id] = desired.get(img_id, 0) | _determine_nsfw(positive_prompt, NSFW_KEYWORDS)
        stored[img_id] = current_flag
        _report_fallback_progress("NSFW tagging", done, total)

    # Compare against the flag as stored *before* this run, so the count
    # reflects real changes and unchanged rows are not rewritten.
    changes = [(flag, img_id) for img_id, flag in desired.items() if flag != stored[img_id]]
    cur.executemany("UPDATE image_metadata SET nsfw = ? WHERE id = ?", changes)
    return len(changes)


def main() -> None:
    """Back-fill prompt details and recompute NSFW flags for every image.

    Steps:
      1. ``db.init_db()`` to make sure all tables/columns exist
         (especially the extra ``prompt_details`` columns).
      2. Add ``image_metadata.nsfw`` if an old DB lacks it.
      3. Run the prompt parser for every image with ``prompt_data``.
      4. Recompute each image's ``nsfw`` flag from its positive prompt,
         writing only rows whose flag changes.

    The flag updates are committed once at the end. The connection is closed
    even if a step raises, in which case uncommitted flag updates are discarded.
    """
    db.init_db()

    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        _ensure_column_exists(cur)
        parsed = _parse_all_prompts(cur)
        updated = _update_nsfw_flags(cur)
        conn.commit()
    finally:
        conn.close()

    print(f"NSFW flags updated for {updated} images.")
    print(f"Prompt details parsed for {parsed} images.")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point. Ctrl-C exits with status 130 (128 + SIGINT),
    # the conventional "interrupted" status, instead of printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        raise SystemExit(130)
|