Files
swiper/database.py
2025-06-25 22:26:49 +01:00

530 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Database helper functions for SWIPER.
All DB reads / writes are centralised here so the rest of the
codebase never needs to know SQL details.
"""
from __future__ import annotations

import json
import os
import re
import sqlite3
import time
from typing import Any, Dict, List

from PIL import Image

from config import DB_PATH, IMAGE_DIRS, NSFW_KEYWORDS, find_image_file

# Optional progress bar
try:
    from tqdm import tqdm  # type: ignore
except ImportError:  # pragma: no cover
    tqdm = None
# ---------------------------------------------------------------------------
# Core helpers
# ---------------------------------------------------------------------------
def _get_conn() -> sqlite3.Connection:
    """Open and return a fresh connection to the database at ``DB_PATH``.

    The connection uses :class:`sqlite3.Row` as its row factory, so result
    rows can be read by column name (``row["path"]``) as well as by position.
    The caller owns the returned connection and must close it.
    """
    connection = sqlite3.connect(DB_PATH)
    # Name-based column access is relied on throughout this module.
    connection.row_factory = sqlite3.Row
    return connection
# ---------------------------------------------------------------------------
# Schema setup & sync helpers
# ---------------------------------------------------------------------------
def _create_image_metadata_table(cur: sqlite3.Cursor, table: str, *, if_not_exists: bool = False) -> None:
    """Create the canonical ``image_metadata`` schema under the name *table*.

    Shared by first-time setup and the rebuild migration so fresh and
    migrated databases end up with exactly the same columns and constraints.
    *table* is always an internal constant, never user input.
    """
    guard = "IF NOT EXISTS " if if_not_exists else ""
    cur.execute(
        f"""
        CREATE TABLE {guard}{table} (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            path TEXT NOT NULL UNIQUE,
            resolution_x INTEGER NOT NULL,
            resolution_y INTEGER NOT NULL,
            name TEXT NOT NULL,
            orientation TEXT NOT NULL,
            creation_date INTEGER NOT NULL,
            prompt_data TEXT,
            actioned TEXT DEFAULT NULL,
            nsfw INTEGER NOT NULL DEFAULT 0
        )
        """
    )


def _table_columns(cur: sqlite3.Cursor, table: str) -> Dict[str, str]:
    """Return ``{column_name: declared_type}`` for *table* via PRAGMA table_info."""
    cur.execute(f"PRAGMA table_info({table})")
    return {row["name"]: row["type"] for row in cur.fetchall()}


def _rebuild_image_metadata(conn: sqlite3.Connection, old_columns: Dict[str, str]) -> None:
    """Rebuild ``image_metadata`` so that ``actioned`` is a TEXT column.

    Uses SQLite's recommended create-copy-drop-rename order: the new table is
    built under a temporary name and renamed into place only after the old
    table is dropped. Renaming the *old* table away first would make
    SQLite >= 3.26 rewrite the ``prompt_details`` foreign key to reference the
    table that is about to be dropped.

    Old ``actioned`` values are discarded (they were stored with the wrong
    type); ``nsfw`` is carried over when the old table has that column.
    Assumes ``PRAGMA foreign_keys`` is off, which is SQLite's default and is
    not changed by ``_get_conn``.
    """
    cur = conn.cursor()
    print("Migrating 'actioned' column to TEXT type...")
    # DDL does not open an implicit transaction in the sqlite3 module, so
    # start one explicitly to make the whole rebuild all-or-nothing.
    if not conn.in_transaction:
        cur.execute("BEGIN")
    # A previous crashed attempt could have left the scratch table behind.
    cur.execute("DROP TABLE IF EXISTS image_metadata_new")
    _create_image_metadata_table(cur, "image_metadata_new")
    nsfw_expr = "COALESCE(nsfw, 0)" if "nsfw" in old_columns else "0"
    cur.execute(
        f"""
        INSERT INTO image_metadata_new (id, path, resolution_x, resolution_y, name, orientation, creation_date, prompt_data, nsfw)
        SELECT id, path, resolution_x, resolution_y, name, orientation, creation_date, prompt_data, {nsfw_expr}
        FROM image_metadata
        """
    )
    cur.execute("DROP TABLE image_metadata")
    cur.execute("ALTER TABLE image_metadata_new RENAME TO image_metadata")
    conn.commit()
    print("Migration complete.")


def init_db() -> None:
    """Create missing tables and apply any pending schema migrations.

    Safe to call on every start-up: tables use ``CREATE TABLE IF NOT EXISTS``
    and each migration runs only when ``PRAGMA table_info`` shows the schema
    is out of date. The connection is always closed; uncommitted work is
    discarded if anything raises.
    """
    conn = _get_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS image_selections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_path TEXT NOT NULL UNIQUE,
                action TEXT NOT NULL,
                timestamp INTEGER NOT NULL
            )
            """
        )
        _create_image_metadata_table(cur, "image_metadata", if_not_exists=True)
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS prompt_details (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_path TEXT NOT NULL UNIQUE,
                model_name TEXT,
                positive_prompt TEXT,
                negative_prompt TEXT,
                sampler TEXT,
                steps INTEGER,
                cfg_scale REAL,
                seed INTEGER,
                clip_skip INTEGER,
                loras TEXT,
                textual_inversions TEXT,
                other_parameters TEXT,
                FOREIGN KEY (image_path) REFERENCES image_metadata (path)
            )
            """
        )

        # Older prompt_details tables predate the loras / TI columns.
        pd_columns = _table_columns(cur, "prompt_details")
        if "loras" not in pd_columns:
            cur.execute("ALTER TABLE prompt_details ADD COLUMN loras TEXT")
        if "textual_inversions" not in pd_columns:
            cur.execute("ALTER TABLE prompt_details ADD COLUMN textual_inversions TEXT")

        columns = _table_columns(cur, "image_metadata")
        if "actioned" not in columns:
            cur.execute("ALTER TABLE image_metadata ADD COLUMN actioned TEXT DEFAULT NULL")
        elif (columns["actioned"] or "").upper() != "TEXT":
            # Wrong type (e.g. INTEGER): SQLite cannot ALTER a column type.
            _rebuild_image_metadata(conn, columns)
            # Re-read so the nsfw check below sees the rebuilt schema.
            columns = _table_columns(cur, "image_metadata")

        # Installations older than the nsfw flag only need the column added.
        if "nsfw" not in columns:
            cur.execute("ALTER TABLE image_metadata ADD COLUMN nsfw INTEGER NOT NULL DEFAULT 0")

        conn.commit()
    finally:
        conn.close()
# The negative prompt runs until the first known parameter label (or the end).
_NEG_PROMPT_RE = re.compile(
    r"Negative prompt:\s*(.*?)\s*(Steps:|Sampler:|Schedule type:|CFG scale:|Seed:|Size:|Model hash:|Model:|Lora hashes:|TI:|Version:|$)",
    re.S,
)
# "Key: value" pairs as they appear in the comma-separated parameter section.
_KEY_VALUE_RE = re.compile(r"(\w[\w ]*?):\s*([^,\n]+)")
# Hash lists, optionally quoted; captures are bounded to a single line.
_LORA_HASHES_RE = re.compile(r"Lora hashes:\s*\"?([^\"\n]+)\"?", re.I)
# \b stops "ti:" inside words such as "confetti:" from matching; the optional
# " hashes" also accepts the "TI hashes:" label some writers emit.
_TI_RE = re.compile(r"\bTI(?: hashes)?:\s*\"?([^\"\n]+)\"?", re.I)
_LORA_TAG_RE = re.compile(r"<lora:([^>]+)>", re.I)
_TI_TOKEN_RE = re.compile(r"<([\w-]{3,32})>")  # crude heuristic for TI tokens

# Lower-cased parameter label -> (prompt_details column, value converter).
# Matching is case-insensitive so both "Clip skip" and "CLIP skip" map to clip_skip.
_STD_PROMPT_PARAMS: Dict[str, tuple] = {
    "model": ("model_name", str),
    "sampler": ("sampler", str),
    "steps": ("steps", int),
    "cfg scale": ("cfg_scale", float),
    "seed": ("seed", int),
    "clip skip": ("clip_skip", int),
}


def _convert_prompt_param(value: str, convert: type) -> Any:
    """Apply *convert* to *value*, returning None instead of raising on bad input."""
    try:
        return convert(value)
    except ValueError:
        return None


def _parse_prompt_details(image_path: str, prompt_data: str) -> Dict[str, Any]:
    """Parse generation-parameter text into a prompt_details row dict.

    Expected layout (the one the patterns above target):
    ``<positive prompt>`` then ``Negative prompt: ...`` then a comma-separated
    ``Steps: 20, Sampler: ..., ...`` parameter section. Key/value pairs are read
    only from that parameter section, so prompt text such as "style: anime" is
    not mistaken for a parameter. Parsing is best-effort: on an unexpected
    error the message is printed and whatever was parsed so far is returned.
    """
    details: Dict[str, Any] = {
        "image_path": image_path,
        "positive_prompt": "",
        "negative_prompt": "",
        "model_name": None,
        "sampler": None,
        "steps": None,
        "cfg_scale": None,
        "seed": None,
        "clip_skip": None,
        "loras": None,
        "textual_inversions": None,
        "other_parameters": "{}",
    }
    text = prompt_data
    try:
        neg_match = _NEG_PROMPT_RE.search(text)
        if neg_match:
            details["negative_prompt"] = neg_match.group(1).strip()
            details["positive_prompt"] = text[: neg_match.start()].strip().rstrip(",")
            # Parameters start at the label that terminated the negative prompt.
            params_text = text[neg_match.start(2):]
        else:
            steps_at = text.find("Steps:")
            details["positive_prompt"] = (text[:steps_at] if steps_at != -1 else text).strip()
            # Without a recognisable section, fall back to scanning everything.
            params_text = text[steps_at:] if steps_at != -1 else text

        # Single pass: known keys fill dedicated columns, the rest go to JSON.
        other_params: Dict[str, str] = {}
        for key, raw_value in _KEY_VALUE_RE.findall(params_text):
            value = raw_value.strip().rstrip(",")
            known = _STD_PROMPT_PARAMS.get(key.lower())
            if known is None:
                other_params[key] = value
            else:
                column, convert = known
                details[column] = _convert_prompt_param(value, convert)
        details["other_parameters"] = json.dumps(other_params)

        lora_match = _LORA_HASHES_RE.search(params_text)
        if lora_match:
            details["loras"] = lora_match.group(1).strip()
        # Fallback: inline <lora:...> tags anywhere in the text.
        if not details["loras"]:
            tag_matches = _LORA_TAG_RE.findall(text)
            if tag_matches:
                details["loras"] = ", ".join(tag_matches)

        ti_match = _TI_RE.search(params_text)
        if ti_match:
            details["textual_inversions"] = ti_match.group(1).strip()
        # Fallback: <token> patterns in the positive prompt.
        if not details["textual_inversions"]:
            ti_tokens = _TI_TOKEN_RE.findall(details["positive_prompt"])
            if ti_tokens:
                details["textual_inversions"] = ", ".join(sorted(set(ti_tokens)))
        if details["textual_inversions"]:
            details["textual_inversions"] = details["textual_inversions"].strip().strip('"').rstrip(",").strip()
    except Exception as e:
        print(f"Error parsing prompt for {image_path}: {e}")
    return details


def parse_and_store_prompt_details(image_path: str, prompt_data: str) -> None:
    """Parse *prompt_data* and upsert the result into ``prompt_details``.

    Does nothing when *prompt_data* is empty. Parsing never raises (errors are
    printed and a partial row is stored); the row for *image_path* is replaced
    if it already exists. The connection is opened only for the write and is
    always closed.
    """
    if not prompt_data:
        return
    details = _parse_prompt_details(image_path, prompt_data)
    conn = _get_conn()
    try:
        conn.execute(
            """
            INSERT OR REPLACE INTO prompt_details (image_path, model_name, positive_prompt, negative_prompt, sampler, steps, cfg_scale, seed, clip_skip, loras, textual_inversions, other_parameters)
            VALUES (:image_path, :model_name, :positive_prompt, :negative_prompt, :sampler, :steps, :cfg_scale, :seed, :clip_skip, :loras, :textual_inversions, :other_parameters)
            """,
            details,
        )
        conn.commit()
    finally:
        conn.close()
def _scan_disk_images() -> set[str]:
    """Return ``"<resolution>/<filename>"`` for every image under IMAGE_DIRS.

    Only the first directory level below each base dir is treated as a
    resolution folder. A base dir that cannot be listed is reported and
    skipped rather than aborting the whole sync.
    """
    found: set[str] = set()
    for base in IMAGE_DIRS:
        try:
            entries = os.listdir(base)
        except OSError as exc:
            print(f"Cannot list image directory {base}: {exc}")
            continue
        for res in entries:
            res_dir = os.path.join(base, res)
            if not os.path.isdir(res_dir):
                continue
            for file in os.listdir(res_dir):
                if file.lower().endswith((".png", ".jpg", ".jpeg")):
                    found.add(f"{res}/{file}")
    return found


def _read_image_row(rel_path: str, abs_path: str) -> tuple:
    """Build the image_metadata row tuple for one image file.

    Returns ``(path, width, height, name, orientation, mtime, prompt, nsfw)``.
    The prompt is read from PNG text chunks only; nsfw is 1 when any
    NSFW_KEYWORDS entry occurs (case-insensitively) in the prompt.
    """
    filename = rel_path.split("/", 1)[1]
    with Image.open(abs_path) as img:
        w, h = img.size
        orient = "square" if w == h else "landscape" if w > h else "portrait"
        prompt = None
        if img.format == "PNG":
            prompt = img.info.get("parameters") or img.info.get("Parameters")
    ts = int(os.path.getmtime(abs_path))
    nsfw_flag = 0
    if prompt:
        lower_prompt = prompt.lower()
        nsfw_flag = 1 if any(k.lower() in lower_prompt for k in NSFW_KEYWORDS) else 0
    return (rel_path, w, h, filename, orient, ts, prompt, nsfw_flag)


def sync_image_database() -> None:
    """Scan the image folders and add every unknown image to image_metadata.

    Images already present (by relative path) are left untouched. New rows
    are inserted in one batch, then prompt details are parsed for each new
    image that carries a prompt. Unreadable files are reported and skipped.
    Without tqdm, progress is printed every 100 images and at the end.
    """
    print("Syncing image database…", flush=True)
    conn = _get_conn()
    try:
        known = {row[0] for row in conn.execute("SELECT path FROM image_metadata")}
        new_images = sorted(_scan_disk_images() - known)
        if not new_images:
            print("Database already up to date.")
            return

        total_new_images = len(new_images)
        image_iter = tqdm(new_images, desc="Syncing images") if tqdm else new_images
        rows = []
        # Count every image (found, missing or unreadable) so progress is accurate.
        for processed_count, rel_path in enumerate(image_iter, start=1):
            abs_path = find_image_file(rel_path)
            if abs_path:
                try:
                    rows.append(_read_image_row(rel_path, abs_path))
                except Exception as exc:
                    print(f"Failed reading {abs_path}: {exc}")
            if not tqdm and (processed_count % 100 == 0 or processed_count == total_new_images):
                percentage = (processed_count / total_new_images) * 100
                print(
                    f"Processed {processed_count} / {total_new_images} images ({percentage:.2f}%)",
                    flush=True,
                )

        if rows:
            conn.executemany(
                """
                INSERT INTO image_metadata (path, resolution_x, resolution_y, name, orientation, creation_date, prompt_data, nsfw)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                rows,
            )
            conn.commit()
            print(f"Inserted {len(rows)} new images.")
            # row layout: (path, w, h, name, orientation, ts, prompt, nsfw_flag)
            for row in rows:
                rel_path, prompt = row[0], row[6]
                if prompt:
                    parse_and_store_prompt_details(rel_path, prompt)
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# Selection helpers
# ---------------------------------------------------------------------------
def add_selection(image_path: str, action: str) -> None:
    """Record *action* for *image_path* and mirror it into image_metadata.

    The selection row is upserted (one row per image_path; the latest action
    and timestamp win), and ``image_metadata.actioned`` for the same path is
    set to the action name. Both writes are committed together, and the
    connection is closed even if a statement fails.
    """
    timestamp = int(time.time())
    conn = _get_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO image_selections (image_path, action, timestamp)
            VALUES (?, ?, ?)
            ON CONFLICT(image_path) DO UPDATE SET
                action = excluded.action,
                timestamp = excluded.timestamp
            """,
            (image_path, action, timestamp),
        )
        cur.execute(
            "UPDATE image_metadata SET actioned = ? WHERE path = ?", (action, image_path)
        )
        conn.commit()
    finally:
        conn.close()
def _selection_rel_path(image_path: str) -> str:
    """Strip leading slashes and one ``images/`` URL prefix from *image_path*.

    A prefix check is used, not ``str.lstrip``, because ``lstrip`` removes a
    *set of characters* and would also eat the start of folder names such as
    ``sdxl/``.
    """
    rel_path = (image_path or "").lstrip("/")
    prefix = "images/"
    if rel_path.startswith(prefix):
        rel_path = rel_path[len(prefix):]
    return rel_path


def _selection_orientation(rel_path: str) -> str:
    """Derive orientation from the image file, or "unknown" if it can't be read."""
    try:
        abs_path = find_image_file(rel_path)
        if not abs_path:
            return "unknown"
        with Image.open(abs_path) as img:
            w, h = img.size
    except Exception:
        return "unknown"
    return "square" if w == h else "landscape" if w > h else "portrait"


def get_selections() -> List[Dict[str, Any]]:
    """Return all selections, newest first, joined with metadata and prompt details.

    Each item is a dict keyed by the selected column names. ``other_parameters``
    is decoded from JSON (``{}`` if it is malformed). ``resolution`` is derived
    from the first path segment, and ``orientation`` is filled from the image
    file when the metadata row is missing; either falls back to "unknown".
    """
    conn = _get_conn()
    try:
        rows = conn.execute(
            """
            SELECT
                sel.id, sel.image_path, sel.action, sel.timestamp,
                meta.resolution_x, meta.resolution_y, meta.orientation,
                meta.nsfw,
                meta.creation_date, meta.prompt_data, meta.name,
                pd.model_name, pd.positive_prompt, pd.negative_prompt,
                pd.sampler, pd.steps, pd.cfg_scale, pd.seed, pd.clip_skip,
                pd.loras, pd.textual_inversions,
                pd.other_parameters
            FROM image_selections sel
            LEFT JOIN image_metadata meta ON sel.image_path = meta.path
            LEFT JOIN prompt_details pd ON sel.image_path = pd.image_path
            ORDER BY sel.timestamp DESC
            """
        ).fetchall()
    finally:
        conn.close()

    results: List[Dict[str, Any]] = []
    for row in rows:
        item: Dict[str, Any] = {k: row[k] for k in row.keys()}

        other_params_str = item.get("other_parameters")
        if isinstance(other_params_str, str):
            try:
                item["other_parameters"] = json.loads(other_params_str)
            except json.JSONDecodeError:
                item["other_parameters"] = {}

        # Back-compat: derive resolution / orientation when not stored.
        rel_path = _selection_rel_path(item.get("image_path"))
        if not item.get("resolution"):
            item["resolution"] = rel_path.split("/")[0] if rel_path else "unknown"
        if not item.get("orientation"):
            item["orientation"] = _selection_orientation(rel_path)
        results.append(item)
    return results
def update_selection(selection_id: int, action: str) -> bool:
    """Change the action of selection *selection_id* and mirror it into metadata.

    Returns False when no selection has that id. Otherwise the selection's
    action and timestamp are updated, and ``image_metadata.actioned`` for its
    image path is set to the same action. The result is True when the
    selection row was changed. The connection is always closed.
    """
    conn = _get_conn()
    try:
        cur = conn.cursor()
        cur.execute("SELECT image_path FROM image_selections WHERE id = ?", (selection_id,))
        row = cur.fetchone()
        if row is None:
            return False
        image_path = row["image_path"]

        cur.execute(
            "UPDATE image_selections SET action = ?, timestamp = ? WHERE id = ?",
            (action, int(time.time()), selection_id),
        )
        changed = cur.rowcount > 0
        if changed:
            cur.execute(
                "UPDATE image_metadata SET actioned = ? WHERE path = ?", (action, image_path)
            )
        conn.commit()
        return changed
    finally:
        conn.close()
def delete_selection(selection_id: int) -> None:
    """Delete selection *selection_id* and clear the metadata ``actioned`` value.

    An unknown id is silently ignored (nothing is changed). The delete and the
    metadata reset are committed together, and the connection is always
    closed.
    """
    conn = _get_conn()
    try:
        cur = conn.cursor()
        cur.execute("SELECT image_path FROM image_selections WHERE id = ?", (selection_id,))
        row = cur.fetchone()
        if row is None:
            return
        image_path = row["image_path"]

        cur.execute("DELETE FROM image_selections WHERE id = ?", (selection_id,))
        cur.execute("UPDATE image_metadata SET actioned = NULL WHERE path = ?", (image_path,))
        conn.commit()
    finally:
        conn.close()
def reset_database() -> None:
    """Clear every selection and reset ``actioned`` on all image metadata.

    Only selection state is wiped: image_metadata rows and prompt_details are
    kept. Both statements are committed together, and the connection is
    closed even if one of them fails.
    """
    conn = _get_conn()
    try:
        conn.execute("DELETE FROM image_selections")
        conn.execute("UPDATE image_metadata SET actioned = NULL")
        conn.commit()
    finally:
        conn.close()