Basically wrote the whole thing.

This commit is contained in:
Aodhan
2025-06-25 04:21:13 +01:00
parent 1ff4a6f6d7
commit c5391a957d
216 changed files with 168676 additions and 1303 deletions

133
update_nsfw_flags.py Normal file
View File

@@ -0,0 +1,133 @@
#!/usr/bin/env python3
"""Back-fill the `nsfw` flag and prompt details for existing records in `image_metadata`.
Run this once after adding / changing the keyword list in ``config.NSFW_KEYWORDS``.
Usage::
python update_nsfw_flags.py
The script will:
1. Ensure the ``nsfw`` column exists (creating it if necessary).
2. Reset all flags to 0.
3. Scan every row of ``image_metadata``.
4. Set ``nsfw`` to **1** if any keyword is present in the image's positive prompt
(``prompt_details.positive_prompt``, case-insensitive), otherwise **0**.
5. Report how many rows were updated.
It is safe to re-run; unchanged rows are left untouched.
In addition this script will run the prompt parser for every image that has `prompt_data` so that the `prompt_details` table (including new *loras* and *textual inversions* columns) is fully populated.
"""
from __future__ import annotations
import sqlite3
import sys
from typing import List
from config import DB_PATH, NSFW_KEYWORDS
import database as db
# Optional progress bar support
try:
from tqdm import tqdm # type: ignore
except ImportError: # graceful fallback if tqdm not installed
tqdm = None
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
def _ensure_column_exists(cur: sqlite3.Cursor) -> None:
    """Add the ``nsfw`` column to ``image_metadata`` when the table lacks it.

    Args:
        cur: Cursor on the database that holds ``image_metadata``.
    """
    cur.execute("PRAGMA table_info(image_metadata)")
    # Each ``table_info`` row describes one column; index 1 is the column name.
    has_nsfw = any(info[1] == "nsfw" for info in cur.fetchall())
    if has_nsfw:
        return
    print("Adding missing 'nsfw' column to image_metadata …", flush=True)
    cur.execute("ALTER TABLE image_metadata ADD COLUMN nsfw INTEGER DEFAULT 0")
def _determine_nsfw(prompt: str | None, keywords: List[str]) -> int:
    """Return 1 when *prompt* contains any of *keywords*, otherwise 0.

    Matching is a case-insensitive substring test.

    Args:
        prompt: Prompt text to inspect; ``None`` or ``""`` means "no prompt".
        keywords: Keywords that mark a prompt as NSFW.

    Returns:
        ``1`` if at least one non-empty keyword occurs in the prompt, else ``0``.
    """
    # A missing or empty prompt cannot contain a keyword, so it is never NSFW.
    if not prompt:
        return 0
    lower_prompt = prompt.lower()
    # Skip empty keywords: ``"" in s`` is always true and would flag everything.
    return int(any(kw and kw.lower() in lower_prompt for kw in keywords))
# ---------------------------------------------------------------------------
# Main routine
# ---------------------------------------------------------------------------
def main() -> None:
    """Back-fill prompt details and NSFW flags for every row of ``image_metadata``.

    Steps:
        1. Initialise the schema via ``db.init_db()`` and make sure the ``nsfw``
           column exists.
        2. Run ``db.parse_and_store_prompt_details`` for every row that has a
           non-empty ``prompt_data``.
        3. Reset every ``nsfw`` flag to 0, then set it to 1 for rows whose
           positive prompt matches ``NSFW_KEYWORDS``.
        4. Commit and print a summary.

    The connection is always closed, even when a step raises or the user
    interrupts the run; in that case the changes made on this connection that
    were not yet committed are not saved.
    """
    # Ensure all tables/columns exist (especially prompt_details extra cols)
    db.init_db()
    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        _ensure_column_exists(cur)

        # ------------------------------------------------------------------
        # Run prompt parser for every row having prompt_data
        # (ensures prompt_details populated)
        # ------------------------------------------------------------------
        cur.execute("SELECT path, prompt_data FROM image_metadata WHERE prompt_data IS NOT NULL")
        prompt_rows = cur.fetchall()
        total_prompts = len(prompt_rows)
        parsed = 0
        if tqdm:
            iterable_prompts = tqdm(prompt_rows, desc="Parsing prompts", unit="img")  # type: ignore[arg-type]
        else:
            iterable_prompts = prompt_rows
        for idx, (img_path, prompt) in enumerate(iterable_prompts):  # type: ignore[assignment]
            # The query excludes NULL but not empty strings, hence the extra check.
            if prompt:
                db.parse_and_store_prompt_details(img_path, prompt)
                parsed += 1
            # Without tqdm, fall back to a plain progress line every 1000 rows.
            if not tqdm and (idx + 1) % 1000 == 0:
                print(f"Parsing prompts: {idx + 1}/{total_prompts} done…", flush=True)

        # ------------------------------------------------------------------
        # Determine NSFW based on POSITIVE PROMPT only
        # ------------------------------------------------------------------
        # First reset all flags to 0 so removed keywords clear previous NSFW marks
        cur.execute("UPDATE image_metadata SET nsfw = 0")
        # Rows without prompt details get an empty positive prompt via COALESCE.
        cur.execute("""
            SELECT meta.id, COALESCE(pd.positive_prompt, '') AS positive_prompt, meta.nsfw
            FROM image_metadata meta
            LEFT JOIN prompt_details pd ON meta.path = pd.image_path
        """)
        rows = cur.fetchall()
        total_rows = len(rows)
        updated = 0
        if tqdm:
            iterable_rows = tqdm(rows, desc="NSFW tagging", unit="img")  # type: ignore[arg-type]
        else:
            iterable_rows = rows
        for idx, (img_id, positive_prompt, current_flag) in enumerate(iterable_rows):  # type: ignore[assignment]
            desired_flag = _determine_nsfw(positive_prompt, NSFW_KEYWORDS)
            # NOTE: because of the reset above, current_flag is 0 for every row
            # here, so ``updated`` counts the rows that end up flagged as 1.
            if desired_flag != current_flag:
                cur.execute("UPDATE image_metadata SET nsfw = ? WHERE id = ?", (desired_flag, img_id))
                updated += 1
            if not tqdm and (idx + 1) % 1000 == 0:
                print(f"NSFW tagging: {idx + 1}/{total_rows} done…", flush=True)
        conn.commit()
    finally:
        # Release the connection on every exit path (success, error, Ctrl-C).
        conn.close()
    print(f"NSFW flags updated for {updated} images.")
    print(f"Prompt details parsed for {parsed} images.")
if __name__ == "__main__":
    # Script entry point: run the back-fill and exit quietly on Ctrl-C.
    try:
        main()
    except KeyboardInterrupt:
        # 130 is the conventional shell exit status for Ctrl-C (128 + SIGINT).
        raise SystemExit(130)