Files
swiper/app.py
2025-06-21 21:44:52 +01:00

757 lines
28 KiB
Python

from http.server import HTTPServer, BaseHTTPRequestHandler
import os
import json
import random
import mimetypes
import urllib.parse
import sqlite3
import time
import datetime
import zipfile
import io
from PIL import Image
# Root of the image library, laid out as <IMAGE_DIR>/<resolution>/<file>.
# Set the SWIPER_IMAGE_DIR environment variable to use another location.
IMAGE_DIR = os.environ.get(
    "SWIPER_IMAGE_DIR", "/mnt/secret-items/sd-outputs/Sorted/Images"
)
# SQLite database file; by default image_selections.db next to this script.
# Set the SWIPER_DB_PATH environment variable to use another file.
DB_PATH = os.environ.get(
    "SWIPER_DB_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "image_selections.db"),
)
# NOTE: The database is no longer deleted on each run.
# If schema changes are needed, run a one-time migration script instead.
def init_db(db_path=None):
    """Create the image_selections and image_metadata tables if they are missing.

    Existing tables are left untouched (``CREATE TABLE IF NOT EXISTS``), so
    recorded selections survive restarts; schema changes need a migration.

    Args:
        db_path: SQLite file to initialise; defaults to DB_PATH.
    """
    if db_path is None:
        db_path = DB_PATH
    conn = sqlite3.connect(db_path)
    try:
        # One row per image the user has acted on; image_path is unique so an
        # image can only hold one current action.
        conn.execute('''
            CREATE TABLE IF NOT EXISTS image_selections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_path TEXT NOT NULL UNIQUE,
                action TEXT NOT NULL,
                timestamp INTEGER NOT NULL
            )
        ''')
        # One row per image file found on disk; filled by sync_image_database().
        conn.execute('''
            CREATE TABLE IF NOT EXISTS image_metadata (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                path TEXT NOT NULL UNIQUE,
                resolution_x INTEGER NOT NULL,
                resolution_y INTEGER NOT NULL,
                name TEXT NOT NULL,
                orientation TEXT NOT NULL,
                creation_date INTEGER NOT NULL,
                prompt_data TEXT
            )
        ''')
        conn.commit()
    finally:
        # Close even if a statement fails, instead of leaking the handle.
        conn.close()
    print(f"Database initialized at {db_path}")
def add_selection(image_path, action, db_path=None):
    """Record *action* for *image_path*, overwriting any earlier choice.

    If the image already has a row, its action and timestamp are updated in
    place so the row keeps its ``id``. The history UI addresses selections by
    id; ``REPLACE INTO`` would delete the row and re-insert it under a new id.

    Args:
        image_path: image path relative to IMAGE_DIR (the image_metadata.path form).
        action: action string to store.
        db_path: SQLite file to use; defaults to DB_PATH.
    """
    if db_path is None:
        db_path = DB_PATH
    now = int(time.time())
    conn = sqlite3.connect(db_path)
    try:
        # Update-then-insert rather than an UPSERT clause, so this also works
        # with SQLite builds older than 3.24. Both statements run in the
        # same implicit transaction, committed below.
        cursor = conn.execute(
            'UPDATE image_selections SET action = ?, timestamp = ? WHERE image_path = ?',
            (action, now, image_path),
        )
        if cursor.rowcount == 0:
            conn.execute(
                'INSERT INTO image_selections (image_path, action, timestamp) VALUES (?, ?, ?)',
                (image_path, action, now),
            )
        conn.commit()
    finally:
        conn.close()
def get_selections(db_path=None, image_dir=None):
    """Return every recorded selection, newest first, joined with its metadata.

    Each item holds the selected columns (``id``, ``image_path``, ``action``,
    ``timestamp``, ``resolution_x``, ``resolution_y``, ``orientation``,
    ``creation_date``, ``prompt_data``, ``name``). Metadata columns are None
    when the image has no image_metadata row. Two keys are guaranteed:

    * ``resolution`` -- the first path segment (the resolution directory),
      after stripping any leading ``/images/`` URL prefix;
    * ``orientation`` -- from metadata; otherwise measured by opening the
      image; ``"unknown"`` if that fails.

    Args:
        db_path: SQLite file to read; defaults to DB_PATH.
        image_dir: root used to open images lacking a stored orientation;
            defaults to IMAGE_DIR.

    Returns:
        A list of dicts, or an empty list if the query fails (so the client
        never hangs waiting on an error).
    """
    print("DEBUG: get_selections() called")
    if db_path is None:
        db_path = DB_PATH
    if image_dir is None:
        image_dir = IMAGE_DIR
    try:
        conn = sqlite3.connect(db_path)
        try:
            conn.row_factory = sqlite3.Row  # column access by name
            rows = conn.execute('''
                SELECT sel.id, sel.image_path, sel.action, sel.timestamp,
                       meta.resolution_x, meta.resolution_y, meta.orientation,
                       meta.creation_date, meta.prompt_data, meta.name
                FROM image_selections sel
                LEFT JOIN image_metadata meta ON sel.image_path = meta.path
                ORDER BY sel.timestamp DESC
            ''').fetchall()
        finally:
            # Previously the connection stayed open whenever the query raised.
            conn.close()
    except Exception as e:
        print(f"DEBUG ERROR in get_selections(): {str(e)}")
        return []
    print(f"DEBUG: Fetched {len(rows)} rows from database")

    results = []
    for row in rows:
        item = dict(row)
        # Older rows may carry the '/images/' URL prefix; strip it before
        # treating the value as a path relative to image_dir.
        relative_path = item['image_path']
        if relative_path.startswith('/images/'):
            relative_path = relative_path[len('/images/'):]
        # The query selects no 'resolution' column, so derive it from the
        # directory the image lives in, e.g. "2048x2048/foo.png" -> "2048x2048".
        if not item.get('resolution'):
            item['resolution'] = relative_path.split('/')[0]
        if not item.get('orientation'):
            # No metadata row: measure the image on disk instead.
            try:
                with Image.open(os.path.join(image_dir, relative_path)) as img:
                    width, height = img.size
                if height > width:
                    item['orientation'] = "portrait"
                elif width > height:
                    item['orientation'] = "landscape"
                else:
                    item['orientation'] = "square"
            except Exception as e:
                print(f"DEBUG ERROR determining missing orientation: {str(e)}")
                item['orientation'] = "unknown"
        results.append(item)
    print(f"DEBUG: Converted {len(results)} rows to dictionaries")
    print(f"DEBUG: First result (if any): {results[0] if results else 'None'}")
    return results
def _scan_image_dir(image_dir):
    """Return ``"<subdir>/<file>"`` paths of .png/.jpg/.jpeg files one level below *image_dir*."""
    found = set()
    for res in os.listdir(image_dir):
        res_dir = os.path.join(image_dir, res)
        if not os.path.isdir(res_dir):
            continue  # only sub-directories hold images; skip stray files
        for img_name in os.listdir(res_dir):
            if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                found.add(f"{res}/{img_name}")
    return found


def _read_image_metadata(image_dir, image_path):
    """Open one image and return its image_metadata row as a tuple.

    The tuple order matches the INSERT in sync_image_database():
    (path, width, height, file name, orientation, mtime, prompt text or None).
    Raises whatever Image.open / os.path.getmtime raise for unreadable files.
    """
    img_name = image_path.split('/', 1)[1]
    full_path = os.path.join(image_dir, image_path)
    with Image.open(full_path) as img:
        width, height = img.size
        if width == height:
            orientation = 'square'
        else:
            orientation = 'landscape' if width > height else 'portrait'
        # Prompt info is read from PNG text chunks only.
        prompt_text = None
        if img.format == 'PNG':
            prompt_text = img.info.get('parameters') or img.info.get('Parameters')
    creation_ts = int(os.path.getmtime(full_path))
    return (image_path, width, height, img_name, orientation, creation_ts, prompt_text)


def sync_image_database(db_path=None, image_dir=None):
    """Scan the image directory and add any new images to the metadata table.

    Images are expected at ``<image_dir>/<resolution>/<file>`` and are stored
    under the relative path ``"<resolution>/<file>"``. Files that cannot be
    read are reported and skipped; they are retried on the next sync. Rows
    for files that have since disappeared from disk are left in place.

    Args:
        db_path: SQLite file to update; defaults to DB_PATH.
        image_dir: root directory to scan; defaults to IMAGE_DIR.
    """
    print("Syncing image database...")
    if db_path is None:
        db_path = DB_PATH
    if image_dir is None:
        image_dir = IMAGE_DIR
    conn = sqlite3.connect(db_path)
    try:
        db_images = {row[0] for row in conn.execute("SELECT path FROM image_metadata")}
        print(f"Found {len(db_images)} images in the database.")
        disk_images = _scan_image_dir(image_dir)
        print(f"Found {len(disk_images)} images on disk.")

        new_images = disk_images - db_images
        print(f"Found {len(new_images)} new images to add to the database.")
        if not new_images:
            print("Database is already up-to-date.")
            return

        images_to_add = []
        total_new_images = len(new_images)
        for processed_count, image_path in enumerate(sorted(new_images), start=1):
            try:
                images_to_add.append(_read_image_metadata(image_dir, image_path))
            except Exception as e:
                print(f"Could not process image {os.path.join(image_dir, image_path)}: {e}")
            # Count attempts, not successes, so the final 100% line always prints.
            if processed_count % 100 == 0 or processed_count == total_new_images:
                percentage = (processed_count / total_new_images) * 100
                print(f"Processed {processed_count} of {total_new_images} images ({percentage:.2f}%)...", flush=True)

        if images_to_add:
            conn.executemany('''
                INSERT INTO image_metadata (path, resolution_x, resolution_y, name, orientation, creation_date, prompt_data)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            ''', images_to_add)
            conn.commit()
            print(f"Successfully added {len(images_to_add)} new images to the database.")
    finally:
        conn.close()
def update_selection(selection_id, action, db_path=None):
    """Set a new *action* on one selection row and refresh its timestamp.

    Args:
        selection_id: ``id`` of the image_selections row to change.
        action: action string to store.
        db_path: SQLite file to use; defaults to DB_PATH.

    Returns:
        True if a row with that id existed and was updated, else False.
    """
    if db_path is None:
        db_path = DB_PATH
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.execute(
            'UPDATE image_selections SET action = ?, timestamp = ? WHERE id = ?',
            (action, int(time.time()), selection_id),
        )
        rows_affected = cursor.rowcount
        conn.commit()
    finally:
        # Close even if the UPDATE fails (e.g. an unbindable id value).
        conn.close()
    return rows_affected > 0
def delete_selection(selection_id, db_path=None):
    """Delete the selection row with the given id.

    Args:
        selection_id: ``id`` of the image_selections row to remove.
        db_path: SQLite file to use; defaults to DB_PATH.

    Returns:
        True if a row was deleted, False if no row had that id.
    """
    if db_path is None:
        db_path = DB_PATH
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.execute('DELETE FROM image_selections WHERE id = ?', (selection_id,))
        rows_affected = cursor.rowcount
        conn.commit()
    finally:
        # Close even if the DELETE fails.
        conn.close()
    return rows_affected > 0
def reset_database(db_path=None):
    """Delete every row from image_selections; image_metadata is not touched.

    Args:
        db_path: SQLite file to use; defaults to DB_PATH.

    Returns:
        The number of selection rows deleted.
    """
    print("DEBUG: Resetting database - deleting all selections")
    if db_path is None:
        db_path = DB_PATH
    conn = sqlite3.connect(db_path)
    try:
        rows_affected = conn.execute('DELETE FROM image_selections').rowcount
        conn.commit()
    finally:
        # Close even if the DELETE fails.
        conn.close()
    print(f"DEBUG: Reset database - deleted {rows_affected} selections")
    return rows_affected
class ImageSwipeHandler(BaseHTTPRequestHandler):
    """Serve the swipe UI, its JSON API, image files and zip downloads.

    GET routes:
        /, /history, /styles.css, /script.js -- front-end files
        /random-image      -- JSON for one random image with no selection yet
                              (optional ``?orientation=`` filter)
        /selections        -- JSON list of all recorded selections
        /resolutions       -- JSON list of the sub-directories of IMAGE_DIR
        /images/<path>     -- raw image bytes from IMAGE_DIR
        /download-selected -- zip of the images named by ``?paths=``
        anything else      -- a static file next to app.py, if it exists
    POST routes:
        /selection, /record-selection, /update-selection,
        /delete-selection, /reset-database, /download-selected

    File paths taken from the request are confined to their base directory,
    so ``..`` segments or absolute paths cannot read files elsewhere.
    """

    # Directory holding index.html, history.html, styles.css and script.js.
    # NOTE(review): with the default DB_PATH the SQLite file lives here too,
    # so the static fallback can serve it; move the front-end files to a
    # subdirectory if that matters.
    STATIC_DIR = os.path.dirname(os.path.abspath(__file__))

    # ------------------------------------------------------------------
    # Path helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_join(base_dir, relative_path):
        """Join *relative_path* onto *base_dir* without escaping it.

        Returns the absolute joined path, or None if the result would lie
        outside *base_dir* (``..`` segments or an absolute *relative_path*).
        The check is lexical, so symlinks inside *base_dir* keep working.
        """
        base = os.path.abspath(base_dir)
        candidate = os.path.abspath(os.path.join(base, relative_path))
        try:
            inside = os.path.commonpath([base, candidate]) == base
        except ValueError:
            # e.g. paths on different Windows drives
            inside = False
        return candidate if inside else None

    @staticmethod
    def _normalize_image_path(image_path):
        """Strip a leading '/images/' URL prefix so the path matches image_metadata.path."""
        prefix = '/images/'
        if image_path.startswith(prefix):
            return image_path[len(prefix):]
        return image_path

    @staticmethod
    def _read_file(full_path):
        """Return the bytes of *full_path*, or None if it is None or unreadable."""
        if full_path is None:
            return None
        try:
            with open(full_path, 'rb') as fh:
                return fh.read()
        except (OSError, ValueError):
            # Missing file, a directory, no permission, or an embedded NUL byte.
            return None

    # ------------------------------------------------------------------
    # Response / request helpers
    # ------------------------------------------------------------------

    def _set_cors_headers(self):
        """Add permissive CORS headers (any origin; GET, POST, OPTIONS)."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')

    def _send_bytes(self, content, content_type, status=200, extra_headers=None):
        """Send a complete response: status line, headers (CORS, length) and body."""
        self.send_response(status)
        self.send_header('Content-type', content_type)
        self._set_cors_headers()
        for name, value in (extra_headers or {}).items():
            self.send_header(name, value)
        self.send_header('Content-length', str(len(content)))
        self.end_headers()
        self.wfile.write(content)

    def _send_json(self, payload, status=200):
        """Serialise *payload* as JSON and send it with *status*."""
        self._send_bytes(json.dumps(payload).encode(), 'application/json', status)

    def _read_json_body(self):
        """Read the request body and parse it as a JSON object.

        Returns the decoded dict. On a missing or invalid Content-Length, a
        body that is not valid JSON, or JSON that is not an object, sends a
        400 response and returns None; the caller should then just return.
        """
        try:
            length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            length = -1
        if length <= 0:
            self.send_error(400, "Missing request body")
            return None
        try:
            # json.loads accepts bytes; JSONDecodeError and UnicodeDecodeError
            # are both ValueError subclasses.
            data = json.loads(self.rfile.read(length))
        except ValueError:
            self.send_error(400, "Request body is not valid JSON")
            return None
        if not isinstance(data, dict):
            self.send_error(400, "Request body must be a JSON object")
            return None
        return data

    # ------------------------------------------------------------------
    # HTTP verbs
    # ------------------------------------------------------------------

    def do_GET(self):
        """Dispatch a GET request by URL path (query string ignored for routing)."""
        path = urllib.parse.urlparse(self.path).path
        if path == '/':
            self.serve_file('index.html', 'text/html')
        elif path == '/history':
            self.serve_file('history.html', 'text/html')
        elif path == '/styles.css':
            self.serve_file('styles.css', 'text/css')
        elif path == '/script.js':
            self.serve_file('script.js', 'application/javascript')
        elif path == '/random-image':
            self.serve_random_image()
        elif path == '/selections':
            self.serve_selections()
        elif path == '/resolutions':
            self.serve_resolutions()
        elif path.startswith('/images/'):
            self.serve_image(path[len('/images/'):])
        elif path == '/favicon.ico':
            # Silently ignore favicon requests
            self.send_response(204)
            self.end_headers()
        elif path.startswith('/download-selected'):
            self.handle_download_selected()
        else:
            # Try to serve a static file next to app.py; serve_file confines
            # the lookup to that directory and answers 404 on any failure.
            self.serve_file(path[1:] if path.startswith('/') else path)

    def do_POST(self):
        """Dispatch a POST request by URL path (query string ignored for routing)."""
        path = urllib.parse.urlparse(self.path).path
        handlers = {
            '/selection': self.handle_selection,
            '/record-selection': self.handle_record_selection,
            '/update-selection': self.handle_update_selection,
            '/delete-selection': self.handle_delete_selection,
            '/reset-database': self.handle_reset_database,
        }
        handler = handlers.get(path)
        if handler is not None:
            handler()
        elif path.startswith('/download-selected'):
            self.handle_download_selected()
        else:
            self.send_error(404, "Not found")

    def do_OPTIONS(self):
        """Answer CORS preflight requests."""
        self.send_response(200)
        self._set_cors_headers()
        self.end_headers()

    # ------------------------------------------------------------------
    # GET endpoints
    # ------------------------------------------------------------------

    def serve_file(self, file_path, content_type=None):
        """Send *file_path* (relative to STATIC_DIR), or 404 if unavailable.

        When *content_type* is omitted it is guessed from the file extension,
        falling back to application/octet-stream.
        """
        content = self._read_file(self._safe_join(self.STATIC_DIR, file_path))
        if content is None:
            # Keep request-derived text out of the status line (it must be
            # latin-1 and must not contain CR/LF); send_error escapes 'explain'.
            self.send_error(404, "File not found", f"File not found: {file_path}")
            return
        if not content_type:
            content_type = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
        self._send_bytes(content, content_type)

    def serve_image(self, image_path):
        """Send the image at URL-encoded *image_path* under IMAGE_DIR, or 404."""
        image_path = urllib.parse.unquote(image_path)
        content = self._read_file(self._safe_join(IMAGE_DIR, image_path))
        if content is None:
            self.send_error(404, "Image not found", f"Image not found: {image_path}")
            return
        content_type = mimetypes.guess_type(image_path)[0] or 'application/octet-stream'
        self._send_bytes(content, content_type)

    def serve_random_image(self):
        """Send JSON describing one random image that has no selection yet.

        ``?orientation=<value>`` restricts the choice to images whose stored
        orientation equals <value>; ``all`` (the default) disables the filter.
        If nothing matches, the JSON object contains only a ``message``.
        """
        try:
            query = urllib.parse.parse_qs(urllib.parse.urlparse(self.path).query)
            orientation_filter = query.get('orientation', ['all'])[0]
            sql = """
                SELECT meta.path, meta.resolution_x, meta.resolution_y, meta.name,
                       meta.orientation, meta.creation_date, meta.prompt_data
                FROM image_metadata meta
                LEFT JOIN image_selections sel ON meta.path = sel.image_path
                WHERE sel.image_path IS NULL
            """
            params = ()
            if orientation_filter != 'all':
                sql += " AND meta.orientation = ?"
                params = (orientation_filter,)
            # Let SQLite pick the row, so only one record (and one possibly
            # long prompt) is fetched instead of every unactioned image.
            sql += " ORDER BY RANDOM() LIMIT 1"
            conn = sqlite3.connect(DB_PATH)
            try:
                row = conn.execute(sql, params).fetchone()
            finally:
                conn.close()

            if row is None:
                print("DEBUG: No matching unactioned images found.")
                response = {'message': 'No more images available for this filter.'}
            else:
                (image_path, resolution_x, resolution_y, image_name,
                 orientation, creation_ts, prompt_data) = row
                print(f"DEBUG: Serving image: {image_path}")
                response = {
                    'path': f"/images/{image_path}",
                    'resolution_x': resolution_x,
                    'resolution_y': resolution_y,
                    'resolution': f"{resolution_x}x{resolution_y}",
                    'filename': image_name,
                    'creation_date': datetime.datetime.fromtimestamp(creation_ts).strftime('%Y-%m-%d %H:%M:%S'),
                    'prompt_data': prompt_data,
                    'orientation': orientation,
                }
        except Exception as e:
            print(f"FATAL ERROR in serve_random_image: {e}")
            self.send_error(500, "Error serving random image", str(e))
            return
        self._send_json(response)

    def serve_resolutions(self):
        """Send JSON ``{"resolutions": [...]}``: the sub-directory names of IMAGE_DIR, sorted."""
        try:
            resolutions = sorted(
                entry for entry in os.listdir(IMAGE_DIR)
                if os.path.isdir(os.path.join(IMAGE_DIR, entry))
            )
        except OSError as e:
            self.send_error(500, "Error serving resolutions", str(e))
            return
        self._send_json({'resolutions': resolutions})

    def serve_selections(self):
        """Send JSON ``{"selections": [...]}`` with every recorded selection."""
        print("DEBUG: serve_selections() called")
        try:
            selections = get_selections()
        except Exception as e:
            print(f"DEBUG ERROR in serve_selections(): {str(e)}")
            self.send_error(500, "Error serving selections", str(e))
            return
        print(f"DEBUG: Response has {len(selections)} selections")
        try:
            response_json = json.dumps({'selections': selections})
            print(f"DEBUG: JSON serialization successful, length: {len(response_json)}")
        except (TypeError, ValueError) as json_err:
            # e.g. a BLOB value in a row; send a simpler response instead.
            print(f"DEBUG ERROR in JSON serialization: {str(json_err)}")
            response_json = json.dumps({'selections': [], 'error': 'JSON serialization error'})
        self._send_bytes(response_json.encode(), 'application/json')
        print("DEBUG: Response sent successfully")

    # ------------------------------------------------------------------
    # POST endpoints
    # ------------------------------------------------------------------

    def handle_selection(self):
        """Handle legacy /selection POST with image_path and action"""
        data = self._read_json_body()
        if data is None:
            return
        image_path = data.get('image_path')
        action = data.get('action')
        if not (isinstance(image_path, str) and image_path and isinstance(action, str) and action):
            self.send_error(400, "Missing required fields")
            return
        try:
            # Store the bare relative path so it joins with image_metadata.path;
            # otherwise the image would be offered again by /random-image.
            add_selection(self._normalize_image_path(image_path), action)
        except Exception as e:
            print(f"ERROR in handle_selection: {e}")
            self.send_error(500, "Server error", str(e))
            return
        self._send_json({'status': 'success'})

    def handle_record_selection(self):
        """Record ``action`` for ``path`` (an '/images/...' URL or a relative path)."""
        data = self._read_json_body()
        if data is None:
            return
        raw_path = data.get('path', '')
        action = data.get('action', '')
        if not (isinstance(raw_path, str) and isinstance(action, str)):
            self.send_error(400, "Missing required fields")
            return
        image_path = self._normalize_image_path(raw_path)
        if not image_path or not action:
            self.send_error(400, "Missing required fields")
            return
        try:
            add_selection(image_path, action)
        except Exception as e:
            self.send_error(500, "Error recording selection", str(e))
            return
        self._send_json({
            'success': True,
            'message': f"Selection recorded: {action} for {image_path}",
        })

    def handle_update_selection(self):
        """Change the action of the selection whose ``id`` is given in the body."""
        data = self._read_json_body()
        if data is None:
            return
        selection_id = data.get('id')
        action = data.get('action', '')
        if not selection_id or not action:
            self.send_error(400, "Missing required fields")
            return
        try:
            success = update_selection(selection_id, action)
        except Exception as e:
            self.send_error(500, "Error updating selection", str(e))
            return
        if not success:
            self.send_error(404, "Selection not found", f"Selection with ID {selection_id} not found")
            return
        self._send_json({
            'success': True,
            'message': f"Selection updated: ID {selection_id} to {action}",
        })

    def handle_delete_selection(self):
        """Delete the selection whose ``id`` is given in the body."""
        data = self._read_json_body()
        if data is None:
            return
        selection_id = data.get('id')
        if not selection_id:
            self.send_error(400, "Missing selection ID")
            return
        try:
            success = delete_selection(selection_id)
        except Exception as e:
            print(f"DEBUG ERROR in handle_delete_selection(): {str(e)}")
            self.send_error(500, "Error deleting selection", str(e))
            return
        if not success:
            self.send_error(404, "Selection not found", f"Selection with ID {selection_id} not found")
            return
        self._send_json({
            'success': True,
            'message': f"Selection deleted: ID {selection_id}",
        })

    def handle_reset_database(self):
        """Delete all selections and report the outcome as JSON."""
        try:
            reset_database()
        except Exception as e:
            print(f"DEBUG ERROR in handle_reset_database: {str(e)}")
            self._send_json({'success': False, 'message': f'Error: {str(e)}'}, status=500)
            return
        self._send_json({'success': True, 'message': 'Database reset successfully'})

    def handle_download_selected(self):
        """Send a zip of the images named by the ``paths`` query parameter(s)."""
        query = urllib.parse.parse_qs(urllib.parse.urlparse(self.path).query)
        image_paths = query.get('paths', [])
        if not image_paths:
            self._send_json({'success': False, 'message': 'No image paths provided'}, status=400)
            return
        try:
            archive = self._build_zip(image_paths)
        except Exception as e:
            print(f"DEBUG ERROR in handle_download_selected: {str(e)}")
            self._send_json({'success': False, 'message': f'Error: {str(e)}'}, status=500)
            return
        self._send_bytes(
            archive,
            'application/zip',
            extra_headers={'Content-Disposition': 'attachment; filename="selected_images.zip"'},
        )

    @classmethod
    def _build_zip(cls, image_paths, image_dir=None):
        """Return the bytes of a zip archive holding the existing files among *image_paths*.

        Each path may carry the '/images/' URL prefix and is resolved inside
        *image_dir* (default IMAGE_DIR). Missing files and paths that would
        escape the directory are skipped. Entries are stored flat under their
        base name; a repeated name gets a ' (2)', ' (3)', ... suffix so no
        entry overwrites another on extraction.
        """
        if image_dir is None:
            image_dir = IMAGE_DIR
        buffer = io.BytesIO()
        used_names = set()
        # ZIP64 stays enabled (the default) so large selections do not fail.
        with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
            for raw_path in image_paths:
                full_path = cls._safe_join(image_dir, cls._normalize_image_path(raw_path))
                if full_path is None or not os.path.isfile(full_path):
                    continue
                arcname = os.path.basename(full_path)
                stem, ext = os.path.splitext(arcname)
                copy_no = 2
                while arcname in used_names:
                    arcname = f"{stem} ({copy_no}){ext}"
                    copy_no += 1
                used_names.add(arcname)
                archive.write(full_path, arcname)
        return buffer.getvalue()
def run(server_class=HTTPServer, handler_class=ImageSwipeHandler, port=8000):
    """Initialise storage, sync image metadata, then serve HTTP until interrupted.

    The server binds to all interfaces ('') on *port*. Ctrl-C stops it and
    closes the listening socket instead of exiting with a traceback.
    """
    init_db()
    # Make sure the image directory exists so the sync scan can list it.
    os.makedirs(IMAGE_DIR, exist_ok=True)
    sync_image_database()
    httpd = server_class(('', port), handler_class)
    print(f"Starting server on port {port}...")
    print(f"Image directory: {IMAGE_DIR}")
    print(f"Database: {DB_PATH}")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nShutting down server...")
    finally:
        httpd.server_close()


if __name__ == "__main__":
    run()