#!/usr/bin/env python3
"""Load datasets/inventory.csv into the **inventory** table, replacing any
existing contents.

Usage:

    python load_inventory_to_db.py [CSV_PATH]

If ``CSV_PATH`` is omitted the script defaults to ``datasets/inventory.csv``
relative to the project root.
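
The CSV is expected to be semicolon-delimited with a header row providing the
columns ``char``, ``storage``, ``item`` and ``quantity`` (column order does
not matter to the parser). For illustration only — the names and numbers
below are made up::

    char;storage;item;quantity
    Alendra;bank;Iron Ore;250
    Alendra;backpack;Health Potion;12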

This script is similar in style to the other ETL helpers in ``scripts/``. It
is idempotent: it truncates the ``inventory`` table before bulk-inserting the
new rows.

The database connection details are read from the standard ``db.conf`` file
located at the project root. The file must define at least the following
keys::

    PSQL_HOST
    PSQL_PORT
    PSQL_USER
    PSQL_PASSWORD
    PSQL_DBNAME
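
For illustration, a minimal ``db.conf`` might look like this (placeholder
values, not real credentials)::

    PSQL_HOST=localhost
    PSQL_PORT=5432
    PSQL_USER=etl
    PSQL_PASSWORD=changeme
    PSQL_DBNAME=gamedata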
"""
from __future__ import annotations

import argparse
import asyncio
import csv
import datetime as _dt
import pathlib
import re
from typing import Dict, List, Tuple

import asyncpg

# ---------------------------------------------------------------------------
# Paths & Constants
# ---------------------------------------------------------------------------
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
CONF_PATH = PROJECT_ROOT / "db.conf"
DEFAULT_CSV_PATH = PROJECT_ROOT / "datasets" / "inventory.csv"

# Matches simple KEY=VALUE configuration lines: an upper-case key (letters,
# digits, underscores), then everything after the first ``=`` as the value.
RE_CONF = re.compile(r"^([A-Z0-9_]+)=(.*)$")

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def parse_db_conf(path: pathlib.Path) -> Dict[str, str]:
    """Parse ``db.conf`` (simple KEY=VALUE format) into a dict."""
    if not path.exists():
        raise FileNotFoundError("db.conf not found at project root – required for DB credentials")

    conf: Dict[str, str] = {}
    for line in path.read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if (m := RE_CONF.match(line)):
            key, value = m.group(1), m.group(2).strip().strip("'\"")
            conf[key] = value

    required = {"PSQL_HOST", "PSQL_PORT", "PSQL_USER", "PSQL_PASSWORD", "PSQL_DBNAME"}
    missing = required - conf.keys()
    if missing:
        raise RuntimeError(f"Missing keys in db.conf: {', '.join(sorted(missing))}")

    return conf
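
# For example, a db.conf line ``PSQL_HOST=localhost`` (quotes around the value
# are optional and get stripped) ends up as conf["PSQL_HOST"] == "localhost";
# lines that are blank or start with ``#`` are skipped entirely.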


async def ensure_inventory_table(conn: asyncpg.Connection) -> None:
    """Create the ``inventory`` table if it doesn't already exist.

    The schema mirrors the SQLAlchemy model in ``backend/app/models.py``.
    """
    await conn.execute(
        """
        CREATE TABLE IF NOT EXISTS inventory (
            id SERIAL PRIMARY KEY,
            character_name TEXT NOT NULL,
            storage_type TEXT NOT NULL,
            item_name TEXT NOT NULL,
            quantity INT NOT NULL,
            item_id INT,
            last_updated TIMESTAMPTZ DEFAULT NOW()
        );
        """
    )


async def truncate_inventory(conn: asyncpg.Connection) -> None:
    """Remove all rows from the inventory table before re-inserting."""
    await conn.execute("TRUNCATE TABLE inventory;")


async def fetch_item_ids(conn: asyncpg.Connection, item_names: List[str]) -> Dict[str, int]:
    """Map item names to their IDs in ``all_items``; unknown names are omitted."""
    rows = await conn.fetch(
        "SELECT id, name FROM all_items WHERE name = ANY($1::text[])", item_names
    )
    return {row["name"]: row["id"] for row in rows}
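
# Names with no match in ``all_items`` are simply absent from the mapping; the
# loader below then falls back to ``None`` and stores NULL in
# ``inventory.item_id`` for those rows.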


async def copy_csv_to_db(
    conn: asyncpg.Connection,
    rows: List[Tuple[str, str, str, int | None, int, _dt.datetime]],
) -> None:
    """Bulk copy the parsed CSV rows into the DB using ``copy_records_to_table``."""
    await conn.copy_records_to_table(
        "inventory",
        records=rows,
        columns=[
            "character_name",
            "storage_type",
            "item_name",
            "item_id",
            "quantity",
            "last_updated",
        ],
    )
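
# Each record tuple must line up positionally with ``columns`` above:
# (character_name, storage_type, item_name, item_id, quantity, last_updated).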


# ---------------------------------------------------------------------------
# Main logic
# ---------------------------------------------------------------------------


async def load_inventory(csv_path: pathlib.Path) -> None:
    if not csv_path.exists():
        raise SystemExit(f"CSV file not found: {csv_path}")

    conf = parse_db_conf(CONF_PATH)

    conn = await asyncpg.connect(
        host=conf["PSQL_HOST"],
        port=int(conf["PSQL_PORT"]),
        user=conf["PSQL_USER"],
        password=conf["PSQL_PASSWORD"],
        database=conf["PSQL_DBNAME"],
    )
    try:
        await ensure_inventory_table(conn)
        # TRUNCATE and the later COPY run as separate implicit transactions, so
        # readers may briefly see an empty table during a reload.
        await truncate_inventory(conn)

        # Parse the CSV once into memory so all item IDs can be resolved in a
        # single query before the insert rows are built.
        rows: List[Tuple[str, str, str, int | None, int, _dt.datetime]] = []
        with csv_path.open(newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f, delimiter=";", quotechar='"')
            records = list(reader)

        names = {r["item"].strip() for r in records}
        id_map = await fetch_item_ids(conn, list(names))

        # Timezone-aware timestamp for the TIMESTAMPTZ column; all rows of one
        # load share the same value.
        now = _dt.datetime.now(_dt.timezone.utc)
        for r in records:
            char = r["char"].strip()
            storage = r["storage"].strip()
            item = r["item"].strip()
            qty = int(r["quantity"].strip()) if r["quantity"].strip() else 0
            item_id = id_map.get(item)  # None -> NULL for unknown items
            rows.append((char, storage, item, item_id, qty, now))

        await copy_csv_to_db(conn, rows)
        print(f"Inserted {len(rows)} inventory rows.")
    finally:
        await conn.close()


async def main_async(csv_arg: str | None) -> None:
    csv_path = pathlib.Path(csv_arg).expanduser().resolve() if csv_arg else DEFAULT_CSV_PATH
    await load_inventory(csv_path)


def main() -> None:
    p = argparse.ArgumentParser(description="Load inventory CSV into DB")
    p.add_argument("csv", nargs="?", help="Path to CSV; defaults to datasets/inventory.csv")
    args = p.parse_args()

    asyncio.run(main_async(args.csv))


if __name__ == "__main__":
    main()