Browse Source

enforce strict typing

master
Silberengel 2 weeks ago
parent
commit
7f5379a655
  1. 9
      pyproject.toml
  2. 10
      src/imwald/core/accounts_store.py
  3. 2
      src/imwald/core/author_html.py
  4. 45
      src/imwald/core/database.py
  5. 7
      src/imwald/core/kind0_profile.py
  6. 14
      src/imwald/core/md_render.py
  7. 7
      src/imwald/core/nostr_crypto.py
  8. 15
      src/imwald/core/nostr_engine.py
  9. 17
      src/imwald/core/nostr_nip96_upload.py
  10. 7
      src/imwald/core/nostr_publish.py
  11. 6
      src/imwald/core/nostr_types.py
  12. 5
      src/imwald/core/ranker.py
  13. 33
      src/imwald/core/relay_list.py
  14. 51
      src/imwald/core/relay_manager.py
  15. 12
      src/imwald/ui/composer_dialog.py
  16. 18
      src/imwald/ui/db_admin_page.py
  17. 27
      src/imwald/ui/feed_page.py
  18. 18
      src/imwald/ui/main_window.py
  19. 4
      src/imwald/ui/notifications_page.py
  20. 38
      src/imwald/ui/onboarding_wizard.py
  21. 2
      src/imwald/ui/search_page.py
  22. 19
      tests/test_kind30000_lists.py
  23. 5
      typings/quickjs.pyi

9
pyproject.toml

@ -47,8 +47,7 @@ pythonVersion = "3.11"
# So third-party stubs (e.g. Pillow → ``PIL``) resolve when using ``.venv`` at the repo root. # So third-party stubs (e.g. Pillow → ``PIL``) resolve when using ``.venv`` at the repo root.
venvPath = "." venvPath = "."
venv = ".venv" venv = ".venv"
# Desktop app + sqlite/Qt stubs surface a lot of ``Any``; keep checks useful without IDE noise. # Strict static typing; partial stubs live under ``typings/`` (``stubPath``).
typeCheckingMode = "standard" typeCheckingMode = "strict"
reportMissingTypeStubs = "none" reportMissingTypeStubs = "error"
reportAny = "none" stubPath = "typings"
reportExplicitAny = "none"

10
src/imwald/core/accounts_store.py

@ -3,9 +3,9 @@
from __future__ import annotations from __future__ import annotations
import json import json
from dataclasses import asdict, dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, cast
from coincurve import PrivateKey from coincurve import PrivateKey
@ -52,7 +52,11 @@ def load_accounts(path: Path | None = None) -> list[StoredAccount]:
data = json.loads(p.read_text(encoding="utf-8")) data = json.loads(p.read_text(encoding="utf-8"))
if not isinstance(data, list): if not isinstance(data, list):
return [] return []
return [StoredAccount.from_json(x) for x in data if isinstance(x, dict)] out_acct: list[StoredAccount] = []
for x in cast(list[object], data):
if isinstance(x, dict):
out_acct.append(StoredAccount.from_json(cast(dict[str, Any], x)))
return out_acct
def save_accounts(accounts: list[StoredAccount], path: Path | None = None) -> None: def save_accounts(accounts: list[StoredAccount], path: Path | None = None) -> None:

2
src/imwald/core/author_html.py

@ -8,7 +8,7 @@ from imwald.core.kind0_profile import display_name_from_profile
def safe_http_url(u: str | None) -> str | None: def safe_http_url(u: str | None) -> str | None:
if not u or not isinstance(u, str): if not u:
return None return None
u = u.strip() u = u.strip()
if u.startswith("https://") or u.startswith("http://"): if u.startswith("https://") or u.startswith("http://"):

45
src/imwald/core/database.py

@ -7,7 +7,8 @@ import sqlite3
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Generator, Iterable, TypedDict, cast from collections.abc import Generator, Iterable
from typing import Any, TypedDict, cast
SCHEMA_VERSION = 2 SCHEMA_VERSION = 2
@ -164,7 +165,7 @@ CREATE INDEX IF NOT EXISTS idx_feed_views_event ON feed_views(event_id);
class Database: class Database:
def __init__(self, path: Path) -> None: def __init__(self, path: Path) -> None:
self.path = path self.path: Path = path
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
self._conn: sqlite3.Connection | None = None self._conn: sqlite3.Connection | None = None
@ -233,7 +234,12 @@ class Database:
) -> None: ) -> None:
"""Insert or replace event; expand tags into tags table.""" """Insert or replace event; expand tags into tags table."""
eid = ev["id"] eid = ev["id"]
tags = ev.get("tags") or [] raw_tags = ev.get("tags")
tags: list[list[str]] = (
cast(list[list[str]], raw_tags)
if isinstance(raw_tags, list)
else []
)
tags_json = json.dumps(tags, ensure_ascii=False) tags_json = json.dumps(tags, ensure_ascii=False)
raw = json.dumps(ev, ensure_ascii=False) raw = json.dumps(ev, ensure_ascii=False)
with self.write_lock() as c: with self.write_lock() as c:
@ -314,7 +320,7 @@ class Database:
try: try:
ev = json.loads(raw) ev = json.loads(raw)
if isinstance(ev, dict): if isinstance(ev, dict):
return ev return cast(dict[str, Any], ev)
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
return { return {
@ -324,7 +330,7 @@ class Database:
"kind": row["kind"], "kind": row["kind"],
"content": row["content"] or "", "content": row["content"] or "",
"sig": row["sig"], "sig": row["sig"],
"tags": json.loads(row["tags_json"] or "[]"), "tags": cast(list[list[str]], json.loads(row["tags_json"] or "[]")),
} }
def get_event(self, event_id: str) -> StoredEventRow | None: def get_event(self, event_id: str) -> StoredEventRow | None:
@ -411,7 +417,7 @@ class Database:
"kind": row["kind"], "kind": row["kind"],
"content": row["content"], "content": row["content"],
"sig": row["sig"], "sig": row["sig"],
"tags": json.loads(row["tags_json"] or "[]"), "tags": cast(list[list[str]], json.loads(row["tags_json"] or "[]")),
"source_relay": row["source_relay"], "source_relay": row["source_relay"],
} }
) )
@ -453,20 +459,25 @@ class Database:
try: try:
data = json.loads(content) data = json.loads(content)
if isinstance(data, list): if isinstance(data, list):
for x in data: for x in cast(list[object], data):
if isinstance(x, str) and len(x) == 64: if isinstance(x, str) and len(x) == 64:
out.add(x.lower()) out.add(x.lower())
elif isinstance(x, dict) and "pubkey" in x: elif isinstance(x, dict) and "pubkey" in x:
pk = str(x["pubkey"]) xd = cast(dict[str, object], x)
pk = str(xd.get("pubkey", ""))
if len(pk) == 64: if len(pk) == 64:
out.add(pk.lower()) out.add(pk.lower())
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
try: try:
tags = json.loads(row["tags_json"] or "[]") tags_raw = json.loads(row["tags_json"] or "[]")
for t in tags: if isinstance(tags_raw, list):
if t and t[0] == "p" and len(t) > 1 and len(t[1]) == 64: for t_obj in cast(list[object], tags_raw):
out.add(str(t[1]).lower()) if not isinstance(t_obj, list) or not t_obj:
continue
row = cast(list[object], t_obj)
if str(row[0]) == "p" and len(row) > 1 and len(str(row[1])) == 64:
out.add(str(row[1]).lower())
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
return out return out
@ -512,7 +523,7 @@ class Database:
"kind": r["kind"], "kind": r["kind"],
"content": r["content"], "content": r["content"],
"sig": r["sig"], "sig": r["sig"],
"tags": json.loads(r["tags_json"] or "[]"), "tags": cast(list[list[str]], json.loads(r["tags_json"] or "[]")),
} }
for r in cur for r in cur
] ]
@ -527,7 +538,7 @@ class Database:
""", """,
(q, q, q, limit), (q, q, q, limit),
) )
rows = [] rows: list[dict[str, Any]] = []
for row in cur: for row in cur:
rows.append( rows.append(
{ {
@ -537,7 +548,7 @@ class Database:
"kind": row["kind"], "kind": row["kind"],
"content": row["content"], "content": row["content"],
"sig": row["sig"], "sig": row["sig"],
"tags": json.loads(row["tags_json"] or "[]"), "tags": cast(list[list[str]], json.loads(row["tags_json"] or "[]")),
} }
) )
return rows return rows
@ -558,7 +569,7 @@ class Database:
"kind": r["kind"], "kind": r["kind"],
"content": r["content"], "content": r["content"],
"sig": r["sig"], "sig": r["sig"],
"tags": json.loads(r["tags_json"] or "[]"), "tags": cast(list[list[str]], json.loads(r["tags_json"] or "[]")),
} }
for r in cur for r in cur
] ]
@ -599,7 +610,7 @@ class Database:
def get_latest_kind0_profiles(self, pubkeys: Iterable[str]) -> dict[str, Kind0ProfileSummary]: def get_latest_kind0_profiles(self, pubkeys: Iterable[str]) -> dict[str, Kind0ProfileSummary]:
"""Most recent kind-0 ``content`` per pubkey (lowercase hex keys).""" """Most recent kind-0 ``content`` per pubkey (lowercase hex keys)."""
pks = [p.lower() for p in pubkeys if isinstance(p, str) and len(p) == 64] pks = [p.lower() for p in pubkeys if len(p) == 64]
if not pks: if not pks:
return {} return {}
placeholders = ",".join("?" * len(pks)) placeholders = ",".join("?" * len(pks))

7
src/imwald/core/kind0_profile.py

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any from typing import cast
def parse_kind0_profile(content: str) -> dict[str, str | None]: def parse_kind0_profile(content: str) -> dict[str, str | None]:
@ -17,11 +17,12 @@ def parse_kind0_profile(content: str) -> dict[str, str | None]:
"banner": None, "banner": None,
} }
try: try:
d: Any = json.loads(content or "") raw = json.loads(content or "")
except json.JSONDecodeError: except json.JSONDecodeError:
return empty return empty
if not isinstance(d, dict): if not isinstance(raw, dict):
return empty return empty
d = cast(dict[str, object], raw)
def pick(*keys: str) -> str | None: def pick(*keys: str) -> str | None:
for k in keys: for k in keys:

14
src/imwald/core/md_render.py

@ -17,6 +17,7 @@ from imwald.core.nostr_entity_render import preprocess_nostr_entities
if TYPE_CHECKING: if TYPE_CHECKING:
from imwald.core.database import Database from imwald.core.database import Database
from quickjs import Context
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -27,7 +28,7 @@ _STANDALONE_IMAGE_URL = re.compile(
) )
_MARKED_PATH = Path(__file__).resolve().parents[1] / "ui" / "assets" / "vendor" / "marked.min.js" _MARKED_PATH = Path(__file__).resolve().parents[1] / "ui" / "assets" / "vendor" / "marked.min.js"
_qjs_ctx = None _qjs_ctx: Context | None = None
_marked_load_failed = False _marked_load_failed = False
_nh3_attrs_merged: dict[str, set[str]] | None = None _nh3_attrs_merged: dict[str, set[str]] | None = None
@ -67,11 +68,10 @@ def _nh3_attributes() -> dict[str, set[str]]:
if _nh3_attrs_merged is None: if _nh3_attrs_merged is None:
raw = cast(MutableMapping[str, set[str]], deepcopy(nh3.ALLOWED_ATTRIBUTES)) raw = cast(MutableMapping[str, set[str]], deepcopy(nh3.ALLOWED_ATTRIBUTES))
for tag in ("span", "div"): for tag in ("span", "div"):
s = raw.get(tag) cur = raw.get(tag)
if s is None: tag_set: set[str] = set(cur) if cur is not None else set()
s = set() raw[tag] = tag_set
raw[tag] = s tag_set.update({"class", "style", "title"})
s.update({"class", "style", "title"})
img_a = raw.get("img") img_a = raw.get("img")
if img_a is not None: if img_a is not None:
img_a.add("style") img_a.add("style")
@ -87,7 +87,7 @@ def _nh3_clean(html: str) -> str:
) )
def _marked_quickjs_ctx(): def _marked_quickjs_ctx() -> Context | None:
"""Singleton QuickJS context with ``marked`` loaded, or None if unavailable.""" """Singleton QuickJS context with ``marked`` loaded, or None if unavailable."""
global _qjs_ctx, _marked_load_failed global _qjs_ctx, _marked_load_failed
if _marked_load_failed: if _marked_load_failed:

7
src/imwald/core/nostr_crypto.py

@ -4,7 +4,7 @@ from __future__ import annotations
import json import json
from hashlib import sha256 from hashlib import sha256
from typing import Any from typing import Any, cast
from coincurve import PrivateKey from coincurve import PrivateKey
from coincurve.keys import PublicKeyXOnly from coincurve.keys import PublicKeyXOnly
@ -43,9 +43,10 @@ def verify_nostr_event(ev: dict[str, Any]) -> bool:
required = ("id", "pubkey", "created_at", "kind", "tags", "content", "sig") required = ("id", "pubkey", "created_at", "kind", "tags", "content", "sig")
if not all(k in ev for k in required): if not all(k in ev for k in required):
return False return False
tags = ev["tags"] tags_raw = ev["tags"]
if not isinstance(tags, list): if not isinstance(tags_raw, list):
return False return False
tags = cast(list[list[str]], tags_raw)
pk_hex = str(ev["pubkey"]).lower() pk_hex = str(ev["pubkey"]).lower()
if len(pk_hex) != 64 or any(c not in "0123456789abcdef" for c in pk_hex): if len(pk_hex) != 64 or any(c not in "0123456789abcdef" for c in pk_hex):
return False return False

15
src/imwald/core/nostr_engine.py

@ -7,7 +7,7 @@ import json
import logging import logging
import threading import threading
import time import time
from typing import Any from typing import Any, cast
from PySide6.QtCore import QObject, Signal from PySide6.QtCore import QObject, Signal
@ -123,14 +123,19 @@ class NostrEngine(QObject):
@staticmethod @staticmethod
def apply_ingest_to_db(db: Database, ev: dict[str, Any], source_relay: str | None = None) -> None: def apply_ingest_to_db(db: Database, ev: dict[str, Any], source_relay: str | None = None) -> None:
if not isinstance(ev, dict) or "id" not in ev: if "id" not in ev:
return return
if not verify_nostr_event(ev): if not verify_nostr_event(ev):
return return
if ev.get("kind") == 5: if ev.get("kind") == 5:
for t in ev.get("tags") or []: raw_tags = ev.get("tags")
if t and t[0] == "e" and len(t) > 1: tag_rows: list[object] = cast(list[object], raw_tags) if isinstance(raw_tags, list) else []
db.tombstone_event(t[1]) for t_obj in tag_rows:
if not isinstance(t_obj, list):
continue
t = cast(list[object], t_obj)
if t and str(t[0]) == "e" and len(t) > 1:
db.tombstone_event(str(t[1]))
db.upsert_event(ev, source_relay=source_relay) db.upsert_event(ev, source_relay=source_relay)
def publish_kind0_and_lists( def publish_kind0_and_lists(

17
src/imwald/core/nostr_nip96_upload.py

@ -10,7 +10,7 @@ import time
import urllib.error import urllib.error
import urllib.request import urllib.request
from hashlib import sha256 from hashlib import sha256
from typing import Any from typing import Any, cast
from imwald.core.nostr_crypto import build_signed_event from imwald.core.nostr_crypto import build_signed_event
@ -126,11 +126,18 @@ def upload_image_nip96_nostr_build(
data: dict[str, Any] = json.loads(raw) data: dict[str, Any] = json.loads(raw)
if data.get("status") != "success": if data.get("status") != "success":
raise RuntimeError(data.get("message") or "nostr.build upload unsuccessful") raise RuntimeError(data.get("message") or "nostr.build upload unsuccessful")
nip94 = data.get("nip94_event") or {} nip94_raw: object = data.get("nip94_event") or {}
tags = nip94.get("tags") or [] if not isinstance(nip94_raw, dict):
if not isinstance(tags, list): nip94_raw = {}
nip94 = cast(dict[str, Any], nip94_raw)
tags_raw: object = nip94.get("tags") or []
if not isinstance(tags_raw, list):
raise RuntimeError("invalid nip94_event.tags in upload response") raise RuntimeError("invalid nip94_event.tags in upload response")
url = next((str(t[1]) for t in tags if isinstance(t, list) and len(t) >= 2 and str(t[0]) == "url"), None) tags: list[list[str]] = []
for item in cast(list[object], tags_raw):
if isinstance(item, list):
tags.append([str(x) for x in cast(list[object], item)])
url = next((row[1] for row in tags if len(row) >= 2 and row[0] == "url"), None)
if not url: if not url:
raise RuntimeError("no url tag in nip94_event response") raise RuntimeError("no url tag in nip94_event response")
return url, tags return url, tags

7
src/imwald/core/nostr_publish.py

@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
from typing import Any from typing import Any, cast
import websockets import websockets
@ -29,9 +29,10 @@ async def publish_to_relays(urls: list[str], event: dict[str, Any], timeout: flo
async with websockets.connect(ws_url, ping_interval=20, open_timeout=timeout) as ws: async with websockets.connect(ws_url, ping_interval=20, open_timeout=timeout) as ws:
await ws.send(json.dumps(["EVENT", event])) await ws.send(json.dumps(["EVENT", event]))
raw = await asyncio.wait_for(ws.recv(), timeout=timeout) raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
msg = json.loads(raw) msg: object = json.loads(raw)
# NIP-01: ["OK", <event_id>, <bool>, <message optional>] # NIP-01: ["OK", <event_id>, <bool>, <message optional>]
ok = isinstance(msg, list) and len(msg) >= 3 and msg[0] == "OK" and msg[2] is True wire = cast(list[object], msg) if isinstance(msg, list) else []
ok = len(wire) >= 3 and wire[0] == "OK" and wire[2] is True
results[url] = ok results[url] = ok
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001
log.info("publish fail %s: %s", url, e) log.info("publish fail %s: %s", url, e)

6
src/imwald/core/nostr_types.py

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Any from typing import Any, cast
@dataclass @dataclass
@ -18,7 +18,7 @@ class NostrEvent:
def from_row(cls, row: dict[str, Any]) -> NostrEvent: def from_row(cls, row: dict[str, Any]) -> NostrEvent:
import json import json
tags = json.loads(row["tags_json"] or "[]") tags = cast(list[list[str]], json.loads(row["tags_json"] or "[]"))
return cls( return cls(
id=row["id"], id=row["id"],
pubkey=row["pubkey"], pubkey=row["pubkey"],

5
src/imwald/core/ranker.py

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any from typing import Any, cast
from imwald.core.relay_policy import is_wisp_trending_relay_url from imwald.core.relay_policy import is_wisp_trending_relay_url
@ -66,7 +66,8 @@ class Ranker:
if sr and is_wisp_trending_relay_url(sr): if sr and is_wisp_trending_relay_url(sr):
score += WEIGHT_TRENDING_RELAY score += WEIGHT_TRENDING_RELAY
why["trending_relay"] = WEIGHT_TRENDING_RELAY why["trending_relay"] = WEIGHT_TRENDING_RELAY
tags = ev.get("tags") or [] raw_tags: object = ev.get("tags") or []
tags: list[list[str]] = cast(list[list[str]], raw_tags) if isinstance(raw_tags, list) else []
if _tags_contain_repost(tags): if _tags_contain_repost(tags):
score += WEIGHT_BOOST score += WEIGHT_BOOST
why["repost_or_quote_hint"] = WEIGHT_BOOST why["repost_or_quote_hint"] = WEIGHT_BOOST

33
src/imwald/core/relay_list.py

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import cast
from imwald.core.database import Database from imwald.core.database import Database
from imwald.core.relay_policy import DEFAULT_READ_RELAYS, DEFAULT_WRITE_RELAYS from imwald.core.relay_policy import DEFAULT_READ_RELAYS, DEFAULT_WRITE_RELAYS
@ -26,7 +26,7 @@ def _dedupe_preserve(urls: list[str]) -> list[str]:
return out return out
def parse_kind10002_tags(tags: Any) -> tuple[list[str], list[str]]: def parse_kind10002_tags(tags: object) -> tuple[list[str], list[str]]:
""" """
Parse NIP-65 `r` tags into (read_urls, write_urls). Parse NIP-65 `r` tags into (read_urls, write_urls).
- ``["r", url]`` or unknown third value both read and write. - ``["r", url]`` or unknown third value both read and write.
@ -35,34 +35,39 @@ def parse_kind10002_tags(tags: Any) -> tuple[list[str], list[str]]:
""" """
if not isinstance(tags, list): if not isinstance(tags, list):
return [], [] return [], []
wire = cast(list[object], tags)
read: list[str] = [] read: list[str] = []
write: list[str] = [] write: list[str] = []
i = 0 i = 0
while i < len(tags): while i < len(wire):
t = tags[i] t = wire[i]
i += 1 i += 1
if not t or not isinstance(t, list) or len(t) < 2: if not isinstance(t, list):
continue continue
if str(t[0]) != "r": row = cast(list[object], t)
if len(row) < 2:
continue continue
url = str(t[1]).strip() if str(row[0]) != "r":
continue
url = str(row[1]).strip()
if not _is_ws_relay_url(url): if not _is_ws_relay_url(url):
continue continue
mode = "both" mode = "both"
if len(t) >= 3: if len(row) >= 3:
m = str(t[2]).lower() m = str(row[2]).lower()
if m == "read": if m == "read":
mode = "read" mode = "read"
elif m == "write": elif m == "write":
mode = "write" mode = "write"
else: else:
if i < len(tags): if i < len(wire):
nxt = tags[i] nxt = wire[i]
if nxt and isinstance(nxt, list) and len(nxt) >= 2: if isinstance(nxt, list) and len(cast(list[object], nxt)) >= 2:
name = str(nxt[0]).lower() nxt_row = cast(list[object], nxt)
val = str(nxt[1]).lower() name = str(nxt_row[0]).lower()
val = str(nxt_row[1]).lower()
if name == "read" and val in ("true", "1", "yes"): if name == "read" and val in ("true", "1", "yes"):
mode = "read" mode = "read"
i += 1 i += 1

51
src/imwald/core/relay_manager.py

@ -10,7 +10,7 @@ import random
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine, cast
import websockets import websockets
from websockets.asyncio.client import ClientConnection from websockets.asyncio.client import ClientConnection
@ -33,8 +33,8 @@ class RelayConn:
last_error: str | None = None last_error: str | None = None
last_connected_at: float | None = None last_connected_at: float | None = None
backoff_until: float = 0.0 backoff_until: float = 0.0
_ws: ClientConnection | None = field(default=None, repr=False) ws: ClientConnection | None = field(default=None, repr=False)
_task: asyncio.Task[None] | None = field(default=None, repr=False) runner_task: asyncio.Task[None] | None = field(default=None, repr=False)
def status_line(self) -> str: def status_line(self) -> str:
err = f" ({self.last_error})" if self.last_error else "" err = f" ({self.last_error})" if self.last_error else ""
@ -80,38 +80,38 @@ class RelayManager:
async def stop(self) -> None: async def stop(self) -> None:
self._shutdown.set() self._shutdown.set()
for r in self._relays.values(): for r in self._relays.values():
if r._task: if r.runner_task:
r._task.cancel() r.runner_task.cancel()
with contextlib.suppress(asyncio.CancelledError): with contextlib.suppress(asyncio.CancelledError):
await r._task await r.runner_task
r._task = None r.runner_task = None
if r._ws: if r.ws:
await r._ws.close() await r.ws.close()
r._ws = None r.ws = None
def request_subscribe(self, relay_url: str, sub_id: str, filters: list[dict[str, Any]]) -> None: def request_subscribe(self, relay_url: str, sub_id: str, filters: list[dict[str, Any]]) -> None:
relay_url = _normalize_ws_url(relay_url) relay_url = _normalize_ws_url(relay_url)
self._subs[f"{relay_url}:{sub_id}"] = {"relay": relay_url, "sub_id": sub_id, "filters": filters} self._subs[f"{relay_url}:{sub_id}"] = {"relay": relay_url, "sub_id": sub_id, "filters": filters}
if relay_url in self._relays and self._relays[relay_url]._ws: if relay_url in self._relays and self._relays[relay_url].ws:
asyncio.create_task(self._send_req(relay_url, sub_id, filters)) asyncio.create_task(self._send_req(relay_url, sub_id, filters))
async def _send_req(self, relay_url: str, sub_id: str, filters: list[dict[str, Any]]) -> None: async def _send_req(self, relay_url: str, sub_id: str, filters: list[dict[str, Any]]) -> None:
r = self._relays.get(relay_url) r = self._relays.get(relay_url)
if not r or not r._ws: if not r or not r.ws:
return return
msg = json.dumps(["REQ", sub_id, *filters]) msg = json.dumps(["REQ", sub_id, *filters])
await r._ws.send(msg) await r.ws.send(msg)
async def _ensure_connected(self, url: str) -> None: async def _ensure_connected(self, url: str) -> None:
r = self._relays[url] r = self._relays[url]
now = time.monotonic() now = time.monotonic()
if r.state == RelayState.BACKOFF and now < r.backoff_until: if r.state == RelayState.BACKOFF and now < r.backoff_until:
return return
if r._ws and r.state == RelayState.CONNECTED: if r.ws and r.state == RelayState.CONNECTED:
return return
if r._task and not r._task.done(): if r.runner_task and not r.runner_task.done():
return return
r._task = asyncio.create_task(self._run_relay(url)) r.runner_task = asyncio.create_task(self._run_relay(url))
async def _run_relay(self, url: str) -> None: async def _run_relay(self, url: str) -> None:
r = self._relays[url] r = self._relays[url]
@ -128,11 +128,11 @@ class RelayManager:
close_timeout=5, close_timeout=5,
max_size=2**22, max_size=2**22,
) as ws: ) as ws:
r._ws = ws r.ws = ws
r.state = RelayState.CONNECTED r.state = RelayState.CONNECTED
r.last_connected_at = time.time() r.last_connected_at = time.time()
# re-send subscriptions for this relay # re-send subscriptions for this relay
for key, sub in self._subs.items(): for _, sub in self._subs.items():
if sub["relay"] == url: if sub["relay"] == url:
await self._send_req(url, sub["sub_id"], sub["filters"]) await self._send_req(url, sub["sub_id"], sub["filters"])
attempt = 0 attempt = 0
@ -140,16 +140,17 @@ class RelayManager:
if self._shutdown.is_set(): if self._shutdown.is_set():
break break
try: try:
msg = json.loads(raw) msg: object = json.loads(raw)
except json.JSONDecodeError: except json.JSONDecodeError:
continue continue
if not isinstance(msg, list) or not msg: if not isinstance(msg, list) or not msg:
continue continue
typ = msg[0] wire = cast(list[object], msg)
if typ == "EVENT" and len(msg) >= 3: typ: object = wire[0]
await self._on_event(url, msg[2]) if typ == "EVENT" and len(wire) >= 3 and isinstance(wire[2], dict):
elif typ == "NOTICE" and len(msg) >= 2 and self._on_notice: await self._on_event(url, cast(dict[str, Any], wire[2]))
await self._on_notice(url, str(msg[1])) elif typ == "NOTICE" and len(wire) >= 2 and self._on_notice:
await self._on_notice(url, str(wire[1]))
elif typ == "OK": elif typ == "OK":
pass pass
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001
@ -157,7 +158,7 @@ class RelayManager:
r.state = RelayState.ERROR r.state = RelayState.ERROR
log.warning("relay %s error: %s", url, e) log.warning("relay %s error: %s", url, e)
finally: finally:
r._ws = None r.ws = None
r.state = RelayState.BACKOFF r.state = RelayState.BACKOFF
attempt += 1 attempt += 1
delay = min(60.0, 1.5**attempt) + random.random() delay = min(60.0, 1.5**attempt) + random.random()

12
src/imwald/ui/composer_dialog.py

@ -4,7 +4,7 @@ from __future__ import annotations
import json import json
import time import time
from typing import Any from typing import Any, cast
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QComboBox, QComboBox,
@ -17,6 +17,7 @@ from PySide6.QtWidgets import (
QMessageBox, QMessageBox,
QSpinBox, QSpinBox,
QVBoxLayout, QVBoxLayout,
QWidget,
) )
from imwald.core.accounts_store import StoredAccount, unlock_secret from imwald.core.accounts_store import StoredAccount, unlock_secret
@ -32,7 +33,7 @@ TAG_SUGGESTIONS = ["t", "client", "e", "p", "relay", "imeta"]
class ComposerDialog(QDialog): class ComposerDialog(QDialog):
def __init__( def __init__(
self, self,
parent=None, parent: QWidget | None = None,
*, *,
edit_from: StoredEventRow | dict[str, Any] | None = None, edit_from: StoredEventRow | dict[str, Any] | None = None,
account: StoredAccount, account: StoredAccount,
@ -96,9 +97,10 @@ class ComposerDialog(QDialog):
def _publish(self) -> None: def _publish(self) -> None:
try: try:
tags = json.loads(self._tags.text() or "[]") tags_raw = json.loads(self._tags.text() or "[]")
if not isinstance(tags, list): if not isinstance(tags_raw, list):
raise ValueError("tags must be a JSON array") raise ValueError("tags must be a JSON array")
tags = cast(list[list[str]], tags_raw)
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001
QMessageBox.warning(self, "Invalid tags", str(e)) QMessageBox.warning(self, "Invalid tags", str(e))
return return
@ -117,7 +119,7 @@ class ComposerDialog(QDialog):
def open_composer_for_edit( def open_composer_for_edit(
parent, parent: QWidget | None,
ev: dict[str, Any], ev: dict[str, Any],
account: StoredAccount, account: StoredAccount,
password: str | None, password: str | None,

18
src/imwald/ui/db_admin_page.py

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from PySide6.QtCore import Qt, Signal from PySide6.QtCore import Signal
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QComboBox, QComboBox,
QHBoxLayout, QHBoxLayout,
@ -24,7 +24,9 @@ class DbAdminPage(QWidget):
open_event = Signal(str) open_event = Signal(str)
request_nip09 = Signal(str, str) # event_id, signing_pubkey hex request_nip09 = Signal(str, str) # event_id, signing_pubkey hex
def __init__(self, db: Database, accounts: list[StoredAccount], parent=None) -> None: def __init__(
self, db: Database, accounts: list[StoredAccount], parent: QWidget | None = None
) -> None:
super().__init__(parent) super().__init__(parent)
self._db = db self._db = db
self._accounts = accounts self._accounts = accounts
@ -83,7 +85,7 @@ class DbAdminPage(QWidget):
self._grid.setHorizontalHeaderLabels(cols) self._grid.setHorizontalHeaderLabels(cols)
self._grid.setRowCount(len(rows)) self._grid.setRowCount(len(rows))
for ri, row in enumerate(rows): for ri, row in enumerate(rows):
for ci, c in enumerate(cols): for ci, _ in enumerate(cols):
v = row[ci] v = row[ci]
self._grid.setItem(ri, ci, QTableWidgetItem("" if v is None else str(v))) self._grid.setItem(ri, ci, QTableWidgetItem("" if v is None else str(v)))
self._grid.setProperty("current_table", name) self._grid.setProperty("current_table", name)
@ -93,7 +95,7 @@ class DbAdminPage(QWidget):
name = self._grid.property("current_table") name = self._grid.property("current_table")
if name != "events": if name != "events":
return None return None
cols = [] cols: list[str] = []
for i in range(self._grid.columnCount()): for i in range(self._grid.columnCount()):
hi = self._grid.horizontalHeaderItem(i) hi = self._grid.horizontalHeaderItem(i)
cols.append(hi.text() if hi is not None else "") cols.append(hi.text() if hi is not None else "")
@ -111,12 +113,12 @@ class DbAdminPage(QWidget):
name = self._grid.property("current_table") name = self._grid.property("current_table")
if name != "events": if name != "events":
return None return None
cols = [] cols_pk: list[str] = []
for i in range(self._grid.columnCount()): for i in range(self._grid.columnCount()):
hi = self._grid.horizontalHeaderItem(i) hi = self._grid.horizontalHeaderItem(i)
cols.append(hi.text() if hi is not None else "") cols_pk.append(hi.text() if hi is not None else "")
try: try:
ci = cols.index("pubkey") ci = cols_pk.index("pubkey")
except ValueError: except ValueError:
return None return None
r = self._grid.currentRow() r = self._grid.currentRow()
@ -139,7 +141,7 @@ class DbAdminPage(QWidget):
self._grid.setHorizontalHeaderLabels(cols) self._grid.setHorizontalHeaderLabels(cols)
self._grid.setRowCount(len(rows)) self._grid.setRowCount(len(rows))
for ri, row in enumerate(rows): for ri, row in enumerate(rows):
for ci, c in enumerate(cols): for ci, _ in enumerate(cols):
v = row[ci] v = row[ci]
self._grid.setItem(ri, ci, QTableWidgetItem("" if v is None else str(v))) self._grid.setItem(ri, ci, QTableWidgetItem("" if v is None else str(v)))
self._nip_btn.setVisible(False) self._nip_btn.setVisible(False)

27
src/imwald/ui/feed_page.py

@ -4,10 +4,12 @@ from __future__ import annotations
import html import html
import json import json
from collections.abc import Sequence
from typing import Any, cast from typing import Any, cast
from PySide6.QtCore import QEvent, QObject, Qt, QTimer from PySide6.QtCore import QEvent, QObject, Qt, QTimer
from PySide6.QtGui import QKeyEvent, QTextOption from PySide6.QtGui import QKeyEvent, QTextOption
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QFrame, QFrame,
QHBoxLayout, QHBoxLayout,
@ -37,8 +39,6 @@ FEED_KINDS = (1, 20, 21, 30023, 9802, 11)
def _set_plain_height_to_content(te: QPlainTextEdit) -> None: def _set_plain_height_to_content(te: QPlainTextEdit) -> None:
doc = te.document() doc = te.document()
lay = doc.documentLayout() lay = doc.documentLayout()
if lay is None:
return
vw = te.viewport().width() vw = te.viewport().width()
if vw < 50: if vw < 50:
outer = max(te.width(), 120) outer = max(te.width(), 120)
@ -68,9 +68,26 @@ def _format_engagement_html(stats: dict[str, Any]) -> str:
parts.append(f"💬&nbsp;<b>{q}</b>") parts.append(f"💬&nbsp;<b>{q}</b>")
if rep: if rep:
parts.append(f"↩&nbsp;<b>{rep}</b>") parts.append(f"↩&nbsp;<b>{rep}</b>")
rx = stats.get("reaction_breakdown") or [] rx_raw = stats.get("reaction_breakdown")
pairs: list[tuple[str, int]] = []
if isinstance(rx_raw, list):
for pair_obj in cast(list[object], rx_raw)[:18]:
if not isinstance(pair_obj, (list, tuple)):
continue
pseq = cast(Sequence[object], pair_obj)
if len(pseq) < 2:
continue
em_o, c_o = pseq[0], pseq[1]
em = em_o if isinstance(em_o, str) else str(em_o)
if isinstance(c_o, bool):
c = int(c_o)
elif isinstance(c_o, (int, float)):
c = int(c_o)
else:
c = int(str(c_o)) if str(c_o).isdigit() else 0
pairs.append((em, c))
emoji_bits: list[str] = [] emoji_bits: list[str] = []
for em, c in rx[:18]: for em, c in pairs:
e = html.escape(em if em != "+" else "", quote=False) e = html.escape(em if em != "+" else "", quote=False)
if c > 1: if c > 1:
emoji_bits.append(f'<span style="font-size:21px" title="{e}×{c}">{e}<sub style="font-size:13px">{c}</sub></span>') emoji_bits.append(f'<span style="font-size:21px" title="{e}×{c}">{e}<sub style="font-size:13px">{c}</sub></span>')
@ -86,7 +103,7 @@ def _format_engagement_html(stats: dict[str, Any]) -> str:
class FeedPage(QWidget): class FeedPage(QWidget):
def __init__(self, db: Database, engine: NostrEngine, parent=None) -> None: def __init__(self, db: Database, engine: NostrEngine, parent: QWidget | None = None) -> None:
super().__init__(parent) super().__init__(parent)
self.setObjectName("FeedPage") self.setObjectName("FeedPage")
self._db = db self._db = db

18
src/imwald/ui/main_window.py

@ -2,8 +2,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, cast
from PySide6.QtCore import Qt, QTimer from PySide6.QtCore import Qt, QTimer
from PySide6.QtGui import QAction from PySide6.QtGui import QAction, QCloseEvent
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QComboBox, QComboBox,
QDialog, QDialog,
@ -17,6 +19,7 @@ from PySide6.QtWidgets import (
QStackedWidget, QStackedWidget,
QToolBar, QToolBar,
QVBoxLayout, QVBoxLayout,
QWidget,
) )
from imwald.core.accounts_store import StoredAccount, load_accounts from imwald.core.accounts_store import StoredAccount, load_accounts
@ -34,7 +37,7 @@ from imwald.ui.search_page import SearchPage
class MainWindow(QMainWindow): class MainWindow(QMainWindow):
def __init__(self, *, db: Database, engine: NostrEngine, parent=None) -> None: def __init__(self, *, db: Database, engine: NostrEngine, parent: QWidget | None = None) -> None:
super().__init__(parent) super().__init__(parent)
self.setWindowTitle("imwald") self.setWindowTitle("imwald")
self.resize(1200, 820) self.resize(1200, 820)
@ -160,7 +163,7 @@ class MainWindow(QMainWindow):
m_file.addAction(a_onb) m_file.addAction(a_onb)
m_view = self.menuBar().addMenu("&View") m_view = self.menuBar().addMenu("&View")
for i, (title, idx) in enumerate( for _, (title, idx) in enumerate(
[ [
("&Feed", 0), ("&Feed", 0),
("&Search", 1), ("&Search", 1),
@ -190,12 +193,15 @@ class MainWindow(QMainWindow):
def _wire_engine(self) -> None: def _wire_engine(self) -> None:
self._engine.event_ingested.connect(self._on_event_ingested) self._engine.event_ingested.connect(self._on_event_ingested)
self._engine.relay_status.connect(lambda s: self.statusBar().showMessage(s, 8000)) self._engine.relay_status.connect(self._relay_status_message)
def _relay_status_message(self, s: str) -> None:
self.statusBar().showMessage(s, 8000)
def _on_event_ingested(self, relay_url: str, ev: object) -> None: def _on_event_ingested(self, relay_url: str, ev: object) -> None:
if not isinstance(ev, dict): if not isinstance(ev, dict):
return return
NostrEngine.apply_ingest_to_db(self._db, ev, relay_url) NostrEngine.apply_ingest_to_db(self._db, cast(dict[str, Any], ev), relay_url)
self._ingest_ui_timer.start() self._ingest_ui_timer.start()
def _wire_pages(self) -> None: def _wire_pages(self) -> None:
@ -301,6 +307,6 @@ class MainWindow(QMainWindow):
return None, None return None, None
return acc, self._password_for(pk) return acc, self._password_for(pk)
def closeEvent(self, event) -> None: # noqa: N802 def closeEvent(self, event: QCloseEvent) -> None: # noqa: N802
self._engine.stop_relays() self._engine.stop_relays()
super().closeEvent(event) super().closeEvent(event)

4
src/imwald/ui/notifications_page.py

@ -16,7 +16,9 @@ class NotificationsPage(QWidget):
open_event = Signal(str) open_event = Signal(str)
signing_pubkey_changed = Signal(str) signing_pubkey_changed = Signal(str)
def __init__(self, db: Database, accounts: list[StoredAccount], parent=None) -> None: def __init__(
self, db: Database, accounts: list[StoredAccount], parent: QWidget | None = None
) -> None:
super().__init__(parent) super().__init__(parent)
self._db = db self._db = db
self._accounts = accounts self._accounts = accounts

38
src/imwald/ui/onboarding_wizard.py

@ -13,6 +13,7 @@ from PySide6.QtWidgets import (
QMessageBox, QMessageBox,
QPlainTextEdit, QPlainTextEdit,
QVBoxLayout, QVBoxLayout,
QWidget,
QWizard, QWizard,
QWizardPage, QWizardPage,
) )
@ -83,6 +84,9 @@ class PageProfile(QWizardPage):
self._about = QPlainTextEdit() self._about = QPlainTextEdit()
form.addRow("About", self._about) form.addRow("About", self._about)
def about_text(self) -> str:
return self._about.toPlainText().strip()
def nextId(self) -> int: def nextId(self) -> int:
return PAGE_INTERESTS return PAGE_INTERESTS
@ -100,6 +104,14 @@ class PageInterests(QWizardPage):
self._list.addItem(it) self._list.addItem(it)
lay.addWidget(self._list) lay.addWidget(self._list)
def selected_interests(self) -> list[str]:
out: list[str] = []
for i in range(self._list.count()):
it = self._list.item(i)
if it.checkState() == Qt.CheckState.Checked:
out.append(it.text().lstrip("#"))
return out
def nextId(self) -> int: def nextId(self) -> int:
return PAGE_LANG return PAGE_LANG
@ -144,6 +156,9 @@ class PageSafety(QWizardPage):
self._hide.setChecked(True) self._hide.setChecked(True)
lay.addWidget(self._hide) lay.addWidget(self._hide)
def hide_nsfw_recommended(self) -> bool:
return self._hide.isChecked()
def nextId(self) -> int: def nextId(self) -> int:
wiz = self.wizard() wiz = self.wizard()
intro = wiz.page(PAGE_INTRO) if wiz else None intro = wiz.page(PAGE_INTRO) if wiz else None
@ -165,9 +180,12 @@ class PagePassword(QWizardPage):
form.addRow("Password", self._pw) form.addRow("Password", self._pw)
form.addRow("Repeat", self._pw2) form.addRow("Repeat", self._pw2)
def password_pair(self) -> tuple[str, str]:
return self._pw.text(), self._pw2.text()
def run_onboarding_wizard( def run_onboarding_wizard(
parent, parent: QWidget | None,
*, *,
db: Database, db: Database,
engine: NostrEngine, engine: NostrEngine,
@ -192,14 +210,14 @@ def run_onboarding_wizard(
if w.exec() != QWizard.DialogCode.Accepted: if w.exec() != QWizard.DialogCode.Accepted:
return False return False
hide_nsfw = "1" if p4._hide.isChecked() else "0" # noqa: SLF001 hide_nsfw = "1" if p4.hide_nsfw_recommended() else "0"
db.set_setting("hide_nsfw", hide_nsfw) db.set_setting("hide_nsfw", hide_nsfw)
if p0.lurk(): # noqa: SLF001 if p0.lurk():
return True return True
pw = p5._pw.text() # noqa: SLF001 pw, pw2 = p5.password_pair()
if pw != p5._pw2.text(): # noqa: SLF001 if pw != pw2:
QMessageBox.warning(parent, "Password mismatch", "Passwords do not match.") QMessageBox.warning(parent, "Password mismatch", "Passwords do not match.")
return False return False
password = pw if pw else None password = pw if pw else None
@ -211,18 +229,14 @@ def run_onboarding_wizard(
existing_accounts.append(acc) existing_accounts.append(acc)
save_accounts(existing_accounts) save_accounts(existing_accounts)
interests = [] interests = p2.selected_interests()
for i in range(p2._list.count()): # noqa: SLF001
it = p2._list.item(i) # noqa: SLF001
if it.checkState() == Qt.CheckState.Checked:
interests.append(it.text().lstrip("#"))
langs = p3.selected() # noqa: SLF001 langs = p3.selected()
engine.publish_kind0_and_lists( engine.publish_kind0_and_lists(
acc, acc,
password, password,
username=nature_label, username=nature_label,
about=p1._about.toPlainText().strip(), about=p1.about_text(),
interest_tags=interests, interest_tags=interests,
languages=langs, languages=langs,
) )

2
src/imwald/ui/search_page.py

@ -12,7 +12,7 @@ from imwald.core.md_render import markdown_plain_summary
class SearchPage(QWidget): class SearchPage(QWidget):
open_event = Signal(str) open_event = Signal(str)
def __init__(self, db: Database, parent=None) -> None: def __init__(self, db: Database, parent: QWidget | None = None) -> None:
super().__init__(parent) super().__init__(parent)
self._db = db self._db = db
self._q = QLineEdit() self._q = QLineEdit()

19
tests/test_kind30000_lists.py

@ -1,5 +1,6 @@
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any
from imwald.core.database import Database from imwald.core.database import Database
from imwald.core.nostr_crypto import build_signed_event from imwald.core.nostr_crypto import build_signed_event
@ -33,8 +34,22 @@ def test_ranker_follow_beats_kind30000() -> None:
me = "f" * 64 me = "f" * 64
follow_pk = "a" * 64 follow_pk = "a" * 64
list_pk = "b" * 64 list_pk = "b" * 64
ev_f = {"id": "1" * 64, "pubkey": follow_pk, "created_at": 1, "kind": 1, "tags": [], "content": "x"} ev_f: dict[str, Any] = {
ev_l = {"id": "2" * 64, "pubkey": list_pk, "created_at": 2, "kind": 1, "tags": [], "content": "y"} "id": "1" * 64,
"pubkey": follow_pk,
"created_at": 1,
"kind": 1,
"tags": [],
"content": "x",
}
ev_l: dict[str, Any] = {
"id": "2" * 64,
"pubkey": list_pk,
"created_at": 2,
"kind": 1,
"tags": [],
"content": "y",
}
sf, _ = r.score_event(ev_f, my_pubkey=me, following={follow_pk}, list30000_pubkeys={list_pk}) sf, _ = r.score_event(ev_f, my_pubkey=me, following={follow_pk}, list30000_pubkeys={list_pk})
sl, _ = r.score_event(ev_l, my_pubkey=me, following={follow_pk}, list30000_pubkeys={list_pk}) sl, _ = r.score_event(ev_l, my_pubkey=me, following={follow_pk}, list30000_pubkeys={list_pk})
assert sf > sl assert sf > sl

5
typings/quickjs.pyi

@ -0,0 +1,5 @@
"""Partial stubs for the ``quickjs`` module (``quickjs-ng`` runtime)."""
class Context:
def __init__(self) -> None: ...
def eval(self, code: str, /) -> object: ...
Loading…
Cancel
Save