Basically wrote the whole thing.
This commit is contained in:
133
update_nsfw_flags.py
Normal file
133
update_nsfw_flags.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Back-fill the `nsfw` flag and prompt details for existing records in `image_metadata`.
|
||||
|
||||
Run this once after adding / changing the keyword list in ``config.NSFW_KEYWORDS``.
|
||||
|
||||
Usage::
|
||||
|
||||
python update_nsfw_flags.py
|
||||
|
||||
The script will:
|
||||
1. Ensure the ``nsfw`` column exists (creating it if necessary).
|
||||
2. Reset all flags to 0.
|
||||
3. Scan every row of ``image_metadata``.
|
||||
4. Set ``nsfw`` to **1** if any keyword is present in the image's ``prompt_data``
|
||||
(case-insensitive), otherwise **0**.
|
||||
5. Report how many rows were updated.
|
||||
|
||||
It is safe to re-run – unchanged rows are left untouched.
|
||||
|
||||
In addition this script will run the prompt parser for every image that has `prompt_data` so that the `prompt_details` table (including new *loras* and *textual inversions* columns) is fully populated.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from config import DB_PATH, NSFW_KEYWORDS
|
||||
import database as db
|
||||
|
||||
# Optional progress bar support
|
||||
try:
|
||||
from tqdm import tqdm # type: ignore
|
||||
except ImportError: # graceful fallback if tqdm not installed
|
||||
tqdm = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_column_exists(cur: sqlite3.Cursor) -> None:
|
||||
"""Create the ``nsfw`` column if an old DB is missing it."""
|
||||
cur.execute("PRAGMA table_info(image_metadata)")
|
||||
cols = {row[1] for row in cur.fetchall()}
|
||||
if "nsfw" not in cols:
|
||||
print("Adding missing 'nsfw' column to image_metadata …", flush=True)
|
||||
cur.execute("ALTER TABLE image_metadata ADD COLUMN nsfw INTEGER DEFAULT 0")
|
||||
|
||||
|
||||
def _determine_nsfw(prompt: str | None, keywords: List[str]) -> int:
|
||||
if not prompt:
|
||||
return 1
|
||||
lower_prompt = prompt.lower()
|
||||
return int(any(kw.lower() in lower_prompt for kw in keywords))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main routine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main() -> None:
|
||||
# Ensure all tables/columns exist (especially prompt_details extra cols)
|
||||
db.init_db()
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cur = conn.cursor()
|
||||
|
||||
_ensure_column_exists(cur)
|
||||
|
||||
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Run prompt parser for every row having prompt_data (ensures prompt_details populated)
|
||||
# ------------------------------------------------------------------
|
||||
cur.execute("SELECT path, prompt_data FROM image_metadata WHERE prompt_data IS NOT NULL")
|
||||
prompt_rows = cur.fetchall()
|
||||
total_prompts = len(prompt_rows)
|
||||
parsed = 0
|
||||
|
||||
if tqdm:
|
||||
iterable_prompts = tqdm(prompt_rows, desc="Parsing prompts", unit="img") # type: ignore[arg-type]
|
||||
else:
|
||||
iterable_prompts = prompt_rows
|
||||
|
||||
for idx, (img_path, prompt) in enumerate(iterable_prompts): # type: ignore[assignment]
|
||||
if prompt:
|
||||
db.parse_and_store_prompt_details(img_path, prompt)
|
||||
parsed += 1
|
||||
if not tqdm and (idx + 1) % 1000 == 0:
|
||||
print(f"Parsing prompts: {idx + 1}/{total_prompts} done…", flush=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Determine NSFW based on POSITIVE PROMPT only
|
||||
# ------------------------------------------------------------------
|
||||
# First reset all flags to 0 so removed keywords clear previous NSFW marks
|
||||
cur.execute("UPDATE image_metadata SET nsfw = 0")
|
||||
|
||||
cur.execute("""
|
||||
SELECT meta.id, COALESCE(pd.positive_prompt, '') AS positive_prompt, meta.nsfw
|
||||
FROM image_metadata meta
|
||||
LEFT JOIN prompt_details pd ON meta.path = pd.image_path
|
||||
""")
|
||||
rows = cur.fetchall()
|
||||
total_rows = len(rows)
|
||||
|
||||
updated = 0
|
||||
|
||||
if tqdm:
|
||||
iterable_rows = tqdm(rows, desc="NSFW tagging", unit="img") # type: ignore[arg-type]
|
||||
else:
|
||||
iterable_rows = rows
|
||||
|
||||
for idx, (img_id, positive_prompt, current_flag) in enumerate(iterable_rows): # type: ignore[assignment]
|
||||
desired_flag = _determine_nsfw(positive_prompt, NSFW_KEYWORDS)
|
||||
if desired_flag != current_flag:
|
||||
cur.execute("UPDATE image_metadata SET nsfw = ? WHERE id = ?", (desired_flag, img_id))
|
||||
updated += 1
|
||||
if not tqdm and (idx + 1) % 1000 == 0:
|
||||
print(f"NSFW tagging: {idx + 1}/{total_rows} done…", flush=True)
|
||||
|
||||
conn.commit()
|
||||
|
||||
conn.close()
|
||||
|
||||
print(f"NSFW flags updated for {updated} images.")
|
||||
print(f"Prompt details parsed for {parsed} images.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(130)
|
||||
Reference in New Issue
Block a user