seshat-tts
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
"""Hotkey OCR capture to Pocket TTS."""
|
||||
|
||||
# Only the version string is exported from this module.
__all__ = ["__version__"]

# Version string of the package.
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,6 @@
|
||||
from .app import main
|
||||
|
||||
|
||||
# Start the application only when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mss
|
||||
from PIL import Image
|
||||
import win32gui
|
||||
import win32ui
|
||||
|
||||
from .config import Rect
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MonitorInfo:
|
||||
index: int
|
||||
left: int
|
||||
top: int
|
||||
width: int
|
||||
height: int
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return f"{self.index}: {self.width}x{self.height} at {self.left},{self.top}"
|
||||
|
||||
|
||||
def list_monitors() -> list[MonitorInfo]:
    """Describe every monitor reported by ``mss`` except entry 0.

    Each remaining entry of ``sct.monitors`` becomes a ``MonitorInfo`` that
    keeps its position in that list as ``index``. Entry 0 is skipped (in mss
    that entry describes all monitors combined - see the mss documentation).
    """
    found: list[MonitorInfo] = []
    with mss.mss() as sct:
        for position, geometry in enumerate(sct.monitors):
            if position == 0:
                continue
            found.append(
                MonitorInfo(
                    index=position,
                    left=int(geometry["left"]),
                    top=int(geometry["top"]),
                    width=int(geometry["width"]),
                    height=int(geometry["height"]),
                )
            )
    return found
|
||||
|
||||
|
||||
def capture_absolute_region(left: int, top: int, width: int, height: int) -> Image.Image:
    """Grab a screen rectangle with ``mss`` and return it as an RGB PIL image.

    The four arguments are handed to ``sct.grab`` as the ``left``, ``top``,
    ``width`` and ``height`` of the area to capture.
    """
    area = dict(left=left, top=top, width=width, height=height)
    with mss.mss() as sct:
        pixels = sct.grab(area)
        # Convert while the mss session is still open, as the original did.
        return Image.frombytes("RGB", pixels.size, pixels.rgb)
|
||||
|
||||
|
||||
def capture_monitor_region(monitor_index: int, rect: Rect) -> Image.Image:
    """Capture ``rect``, interpreted relative to one monitor's top-left corner.

    Args:
        monitor_index: Position of the monitor in ``sct.monitors``; must be
            at least 1 and below the length of that list.
        rect: Region to capture; its ``left``/``top`` are offsets from the
            monitor origin.

    Raises:
        ValueError: If ``monitor_index`` is outside the valid range.
    """
    with mss.mss() as sct:
        available = sct.monitors
        if not 0 < monitor_index < len(available):
            raise ValueError(f"Monitor {monitor_index} is not available.")
        origin = available[monitor_index]
        # Translate the monitor-relative rectangle into absolute coordinates.
        absolute_left = int(origin["left"]) + rect.left
        absolute_top = int(origin["top"]) + rect.top
        return capture_absolute_region(absolute_left, absolute_top, rect.width, rect.height)
|
||||
|
||||
|
||||
def capture_window_region(hwnd: int, rect: Rect) -> Image.Image:
    """Capture a window and crop the result to ``rect``.

    ``rect`` is measured in pixels from the top-left corner of the captured
    window image.

    Raises:
        ValueError: If ``rect`` has a negative origin, a non-positive size, or
            extends past the captured image.
    """
    image = capture_window(hwnd)
    origin_is_negative = rect.left < 0 or rect.top < 0
    size_is_empty = rect.width <= 0 or rect.height <= 0
    if origin_is_negative or size_is_empty:
        raise ValueError("Capture region must be inside the selected window.")

    right_edge = rect.left + rect.width
    bottom_edge = rect.top + rect.height
    if right_edge > image.width or bottom_edge > image.height:
        raise ValueError("Capture region is outside the selected window. Select the region again in window mode.")
    return image.crop((rect.left, rect.top, right_edge, bottom_edge))
|
||||
|
||||
|
||||
def capture_window(hwnd: int) -> Image.Image:
    """Render a window into an off-screen bitmap and return it as an RGB image.

    The bitmap is sized from ``win32gui.GetWindowRect`` and filled by
    ``PrintWindow`` (through ``_print_window``). Flag ``2`` is tried first and
    flag ``0`` is the fallback when the first call does not return 1.

    Raises:
        ValueError: If the window rectangle has no positive width and height.
        RuntimeError: If both ``PrintWindow`` attempts fail.
    """
    left, top, right, bottom = win32gui.GetWindowRect(hwnd)
    width = right - left
    height = bottom - top
    if width <= 0 or height <= 0:
        raise ValueError("Selected window has no capturable size.")

    # NOTE(review): these GDI objects are created before the ``try`` below, so
    # an exception raised while creating them skips the cleanup in ``finally``.
    hwnd_dc = win32gui.GetWindowDC(hwnd)
    source_dc = win32ui.CreateDCFromHandle(hwnd_dc)
    memory_dc = source_dc.CreateCompatibleDC()
    bitmap = win32ui.CreateBitmap()
    bitmap.CreateCompatibleBitmap(source_dc, width, height)
    memory_dc.SelectObject(bitmap)

    try:
        # Flag 2 is presumably PW_RENDERFULLCONTENT (see the Win32 PrintWindow
        # documentation); flag 0 is the plain call used as a fallback.
        result = _print_window(hwnd, memory_dc.GetSafeHdc(), 2)
        if result != 1:
            result = _print_window(hwnd, memory_dc.GetSafeHdc(), 0)
        if result != 1:
            raise RuntimeError("PrintWindow failed for the selected window.")

        info = bitmap.GetInfo()
        bits = bitmap.GetBitmapBits(True)
        # The raw bitmap bytes are decoded as BGRX. ``.copy()`` detaches the
        # image from ``bits`` before the GDI objects are released.
        return Image.frombuffer(
            "RGB",
            (info["bmWidth"], info["bmHeight"]),
            bits,
            "raw",
            "BGRX",
            0,
            1,
        ).copy()
    finally:
        # Release everything acquired above: bitmap, both DCs, then the window DC.
        win32gui.DeleteObject(bitmap.GetHandle())
        memory_dc.DeleteDC()
        source_dc.DeleteDC()
        win32gui.ReleaseDC(hwnd, hwnd_dc)
|
||||
|
||||
|
||||
def _print_window(hwnd: int, hdc: int, flags: int) -> int:
    """Call ``user32.PrintWindow`` through ctypes and return its result as ``int``."""
    user32 = ctypes.windll.user32
    outcome = user32.PrintWindow(hwnd, hdc, flags)
    return int(outcome)
|
||||
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .tesseract import find_tesseract
|
||||
|
||||
|
||||
# Per-user application directory and the JSON settings file stored in it.
# CONFIG_PATH is the default path for load_config() and save_config().
APP_DIR = Path.home() / ".seshat-tts"
CONFIG_PATH = APP_DIR / "config.json"
|
||||
|
||||
|
||||
@dataclass(slots=True)
class Rect:
    """Pixel rectangle given by its top-left corner and its size.

    The field defaults are also the initial value of
    ``AppConfig.dialogue_rect`` (which uses ``Rect`` as its default factory).
    """

    left: int = 0
    top: int = 25
    width: int = 720
    height: int = 305
|
||||
|
||||
|
||||
@dataclass(slots=True)
class AppConfig:
    """All persisted application settings with their defaults.

    Instances are written to JSON by ``save_config`` and rebuilt by
    ``load_config``.
    """

    # Capture target selection.
    capture_mode: str = "monitor"
    monitor_index: int = 1
    window_title: str = ""
    # Hotkey strings.
    hotkey: str = "ctrl+alt+n"
    capture_region_hotkey: str = "ctrl+alt+r"
    stop_hotkey: str = "ctrl+alt+s"
    # Region to capture; starts as a default Rect().
    dialogue_rect: Rect = field(default_factory=Rect)
    # Tesseract executable; auto-detected by find_tesseract() when not given.
    tesseract_cmd: str = field(default_factory=find_tesseract)
    # Voice selection and TTS settings.
    voice_source: str = "default"
    default_voice: str = "alba"
    custom_voice_name: str = ""
    voice_path: str = ""
    language: str = "english"
    quantize_tts: bool = False
    volume_gain: float = 1.0
    tts_backend: str = "uvx-server"
    tts_host: str = "localhost"
    tts_port: int = 8000
    # LLM settings.
    llm_enabled: bool = False
    llm_base_url: str = "http://127.0.0.1:8000/v1"
    llm_api_key: str = ""
    llm_model: str = "current"
    llm_timeout: float = 5.0
    llm_max_tokens: int = 256
    llm_disable_thinking: bool = True
    llm_image_extraction: bool = False
    llm_system_prompt: str = (
        "Clean OCR text for text-to-speech. Return only the corrected text. "
        "Do not explain, add commentary, summarize, or change the meaning."
    )
    # Last stored text; load_config() passes it through _clean_last_text().
    last_text: str = ""
|
||||
|
||||
|
||||
def _rect_from_dict(value: dict[str, Any] | None) -> Rect:
    """Build a ``Rect`` from a stored mapping, filling gaps with defaults.

    Args:
        value: Mapping that may contain any of the ``Rect`` field names, or
            ``None``/empty when nothing was stored.

    Returns:
        A default ``Rect`` when ``value`` is falsy; otherwise a ``Rect`` whose
        fields are ``int(...)`` of the stored values, with missing keys taken
        from the ``Rect`` defaults.

    Raises:
        TypeError, ValueError: If a stored value cannot be converted by ``int``.
    """
    # One default instance serves every field. The loop variable is named
    # ``name`` so it does not shadow ``dataclasses.field``.
    defaults = Rect()
    if not value:
        return defaults
    return Rect(**{name: int(value.get(name, getattr(defaults, name))) for name in Rect.__dataclass_fields__})
|
||||
|
||||
|
||||
def _clean_last_text(value: Any) -> str:
    """Return ``value`` as text with region-status lines removed.

    A line is dropped when, after stripping and case-folding, it starts with
    ``"capture region:"`` or ``"text region:"``. The remaining lines are
    re-joined with newlines and the result is stripped. Falsy input gives ``""``.
    """
    status_prefixes = ("capture region:", "text region:")
    kept: list[str] = []
    for raw_line in str(value or "").splitlines():
        if raw_line.strip().casefold().startswith(status_prefixes):
            continue
        kept.append(raw_line)
    return "\n".join(kept).strip()
|
||||
|
||||
|
||||
def _tesseract_from_config(value: Any) -> str:
    """Choose the Tesseract executable path for a loaded configuration.

    When ``sys.frozen`` is truthy and ``find_tesseract()`` found an
    executable, the detected path wins. Otherwise the stored ``value`` is
    used, falling back to the detected path when ``value`` is falsy.
    """
    detected = find_tesseract()
    running_frozen = getattr(sys, "frozen", False)
    if running_frozen and detected:
        return detected
    chosen = value if value else detected
    return str(chosen)
|
||||
|
||||
|
||||
def load_config(path: Path = CONFIG_PATH) -> AppConfig:
    """Load the application settings from a JSON file.

    Args:
        path: Location of the JSON settings file.

    Returns:
        A default ``AppConfig`` when ``path`` does not exist; otherwise an
        ``AppConfig`` built from the stored values. Keys missing from the file
        fall back to the ``AppConfig`` field defaults, so a partially written
        file and a fresh install agree (previously ``llm_model`` fell back to
        ``"unsloth"`` here but to ``"current"`` in ``AppConfig``).

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
        TypeError, ValueError: If a stored value cannot be converted to the
            field's type.
    """
    if not path.exists():
        return AppConfig()
    data = json.loads(path.read_text(encoding="utf-8"))
    field_specs = AppConfig.__dataclass_fields__

    def stored(name: str) -> Any:
        # Single source of truth for fallbacks: the dataclass field default.
        return data.get(name, field_specs[name].default)

    return AppConfig(
        capture_mode=str(stored("capture_mode")),
        monitor_index=int(stored("monitor_index")),
        window_title=str(stored("window_title")),
        hotkey=str(stored("hotkey")),
        capture_region_hotkey=str(stored("capture_region_hotkey")),
        stop_hotkey=str(stored("stop_hotkey")),
        dialogue_rect=_rect_from_dict(data.get("dialogue_rect")),
        tesseract_cmd=_tesseract_from_config(data.get("tesseract_cmd")),
        voice_source=str(stored("voice_source")),
        default_voice=str(stored("default_voice")),
        custom_voice_name=str(stored("custom_voice_name")),
        voice_path=str(stored("voice_path")),
        # The stored language is ignored; it is always set to "english".
        language="english",
        quantize_tts=bool(stored("quantize_tts")),
        volume_gain=float(stored("volume_gain")),
        tts_backend=str(stored("tts_backend")),
        tts_host=str(stored("tts_host")),
        tts_port=int(stored("tts_port")),
        llm_enabled=bool(stored("llm_enabled")),
        llm_base_url=str(stored("llm_base_url")),
        llm_api_key=str(stored("llm_api_key")),
        llm_model=str(stored("llm_model")),
        llm_timeout=float(stored("llm_timeout")),
        llm_max_tokens=int(stored("llm_max_tokens")),
        llm_disable_thinking=bool(stored("llm_disable_thinking")),
        llm_image_extraction=bool(stored("llm_image_extraction")),
        llm_system_prompt=str(stored("llm_system_prompt")),
        last_text=_clean_last_text(stored("last_text")),
    )
|
||||
|
||||
|
||||
def save_config(config: AppConfig, path: Path = CONFIG_PATH) -> None:
    """Write ``config`` to ``path`` as indented UTF-8 JSON.

    Missing parent directories of ``path`` are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(asdict(config), indent=2)
    path.write_text(payload, encoding="utf-8")
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import keyboard
|
||||
|
||||
|
||||
class HotkeyManager:
    """Keep track of ``keyboard`` hotkeys under caller-chosen names."""

    def __init__(self) -> None:
        # Maps a binding name to the handle returned by keyboard.add_hotkey.
        self._handles: dict[str, object] = {}

    def register(self, name: str, hotkey: str, callback: Callable[[], None]) -> None:
        """Bind ``hotkey`` to ``callback`` under ``name``.

        Any binding already stored under ``name`` is removed first. A blank
        ``hotkey`` therefore only removes the existing binding.
        """
        self.unregister(name)
        if hotkey.strip():
            self._handles[name] = keyboard.add_hotkey(
                hotkey,
                callback,
                suppress=False,
                trigger_on_release=False,
            )

    def unregister(self, name: str | None = None) -> None:
        """Remove the binding stored under ``name``, or all bindings for ``None``."""
        if name is None:
            for handle in self._handles.values():
                keyboard.remove_hotkey(handle)
            self._handles.clear()
            return
        handle = self._handles.pop(name, None)
        if handle is not None:
            keyboard.remove_hotkey(handle)
|
||||
|
||||
|
||||
def listen_for_hotkey() -> str:
    """Return the hotkey string produced by ``keyboard.read_hotkey(suppress=False)``."""
    captured = keyboard.read_hotkey(suppress=False)
    return captured
|
||||
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Default file read by load_api_key_file().
DEFAULT_API_KEY_PATH = Path.home() / ".seshat-tts" / "llm_api_key.txt"
# System and user prompts sent by process_image_with_llm().
IMAGE_EXTRACTION_SYSTEM_PROMPT = (
    "Extract only the visible readable text from the supplied image for text-to-speech. "
    "Preserve the original wording and sentence order. Do not describe the image, "
    "do not add commentary, and do not include UI labels unless they are part of the text to read."
)
IMAGE_EXTRACTION_USER_PROMPT = "Read the text in this selected screen region and return only that text."
|
||||
|
||||
|
||||
class _ChatCompletions(Protocol):
    """Structural type for an object offering ``create(**kwargs)``."""

    def create(self, **kwargs: object) -> object: ...
|
||||
|
||||
|
||||
class _Chat(Protocol):
    """Structural type for an object with a ``completions`` attribute."""

    completions: _ChatCompletions
|
||||
|
||||
|
||||
class _OpenAIClient(Protocol):
    """Structural type for the ``client`` argument of the LLM helpers.

    Only ``client.chat.completions.create(...)`` is used by this module.
    """

    chat: _Chat
|
||||
|
||||
|
||||
def load_api_key_file(path: Path = DEFAULT_API_KEY_PATH) -> str:
    """Return the stripped UTF-8 contents of ``path``, or ``""`` if it is missing."""
    if path.exists():
        return path.read_text(encoding="utf-8").strip()
    return ""
|
||||
|
||||
|
||||
def process_text_with_llm(
    text: str,
    *,
    enabled: bool,
    base_url: str,
    api_key: str,
    model: str,
    system_prompt: str,
    timeout: float = 5.0,
    max_tokens: int = 256,
    disable_thinking: bool = True,
    client: _OpenAIClient | None = None,
) -> str:
    """Send OCR text through a chat-completions endpoint.

    Args:
        text: Text to process; it is stripped before use.
        enabled: When false, the stripped text is returned without a request.
        base_url, api_key, timeout: Used only to build an ``openai.OpenAI``
            client when ``client`` is ``None``. A blank key becomes
            ``"local"`` and the timeout is at least 0.1 seconds.
        model: Model name, stripped before sending.
        system_prompt: System message, stripped before sending.
        max_tokens: Completion limit, clamped to at least 1.
        disable_thinking: When true, an ``extra_body`` with three
            thinking/reasoning switches is added to the request.
        client: Optional ready-made client exposing
            ``chat.completions.create``.

    Returns:
        The stripped reply of the first choice, or the stripped input text
        when processing is disabled, the input is blank, or the reply is empty.
    """
    stripped = text.strip()
    if not (enabled and stripped):
        return stripped

    if client is None:
        # Imported lazily so the dependency is only needed when a client must be built.
        from openai import OpenAI

        client = OpenAI(
            api_key=api_key.strip() or "local",
            base_url=base_url.strip(),
            timeout=max(0.1, float(timeout)),
        )

    model_name = model.strip()
    conversation = [
        {"role": "system", "content": system_prompt.strip()},
        {"role": "user", "content": stripped},
    ]
    request: dict[str, object] = {
        "model": model_name,
        "messages": conversation,
        "temperature": 0,
        "max_tokens": max(1, int(max_tokens)),
        "stream": False,
    }
    if disable_thinking:
        thinking_off: dict[str, object] = {
            "chat_template_kwargs": {"enable_thinking": False},
            "enable_thinking": False,
            "reasoning_effort": "none",
        }
        request["extra_body"] = thinking_off

    reply = client.chat.completions.create(**request)
    cleaned = str(reply.choices[0].message.content or "").strip()
    return cleaned if cleaned else stripped
|
||||
|
||||
|
||||
def process_image_with_llm(
    image: Image.Image,
    *,
    base_url: str,
    api_key: str,
    model: str,
    timeout: float = 5.0,
    max_tokens: int = 256,
    disable_thinking: bool = True,
    client: _OpenAIClient | None = None,
) -> str:
    """Ask a chat-completions endpoint to read the text in ``image``.

    The image is sent as a base64 PNG data URL together with the module's
    image-extraction system and user prompts.

    Args:
        image: Picture to read.
        base_url, api_key, timeout: Used only to build an ``openai.OpenAI``
            client when ``client`` is ``None``. A blank key becomes
            ``"local"`` and the timeout is at least 0.1 seconds.
        model: Model name, stripped before sending.
        max_tokens: Completion limit, clamped to at least 1.
        disable_thinking: When true, an ``extra_body`` with three
            thinking/reasoning switches is added to the request.
        client: Optional ready-made client exposing
            ``chat.completions.create``.

    Returns:
        The stripped reply of the first choice (``""`` when it is empty).
    """
    if client is None:
        # Imported lazily so the dependency is only needed when a client must be built.
        from openai import OpenAI

        client = OpenAI(
            api_key=api_key.strip() or "local",
            base_url=base_url.strip(),
            timeout=max(0.1, float(timeout)),
        )

    model_name = model.strip()
    data_url = f"data:image/png;base64,{_image_to_base64_png(image)}"
    user_parts = [
        {"type": "text", "text": IMAGE_EXTRACTION_USER_PROMPT},
        {"type": "image_url", "image_url": {"url": data_url, "detail": "high"}},
    ]
    request: dict[str, object] = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": IMAGE_EXTRACTION_SYSTEM_PROMPT},
            {"role": "user", "content": user_parts},
        ],
        "temperature": 0,
        "max_tokens": max(1, int(max_tokens)),
        "stream": False,
    }
    if disable_thinking:
        thinking_off: dict[str, object] = {
            "chat_template_kwargs": {"enable_thinking": False},
            "enable_thinking": False,
            "reasoning_effort": "none",
        }
        request["extra_body"] = thinking_off

    reply = client.chat.completions.create(**request)
    extracted = reply.choices[0].message.content
    return str(extracted or "").strip()
|
||||
|
||||
|
||||
def _image_to_base64_png(image: Image.Image) -> str:
    """Convert ``image`` to RGB, encode it as PNG and return the base64 text."""
    with BytesIO() as sink:
        image.convert("RGB").save(sink, format="PNG")
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode("ascii")
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
|
||||
|
||||
from .tesseract import tesseract_help_message
|
||||
|
||||
|
||||
def preprocess_for_ocr(image: Image.Image) -> Image.Image:
    """Prepare a capture for Tesseract.

    Steps, in order: add a 12-pixel black border, convert to grayscale,
    double the size with Lanczos resampling, raise contrast by a factor of
    2.2, sharpen, and threshold every pixel to 255 (above 145) or 0.
    """
    padded = ImageOps.expand(image, border=12, fill=(0, 0, 0))
    gray = ImageOps.grayscale(padded)
    doubled_size = (gray.width * 2, gray.height * 2)
    enlarged = gray.resize(doubled_size, Image.Resampling.LANCZOS)
    boosted = ImageEnhance.Contrast(enlarged).enhance(2.2)
    sharpened = boosted.filter(ImageFilter.SHARPEN)

    def binarize(pixel: int) -> int:
        return 255 if pixel > 145 else 0

    return sharpened.point(binarize)
|
||||
|
||||
|
||||
def image_to_lines(image: Image.Image, tesseract_cmd: str = "") -> list[str]:
    """Run Tesseract on ``image`` and return its non-empty, normalized lines.

    Args:
        image: Picture handed to ``pytesseract.image_to_string``.
        tesseract_cmd: Optional path of the Tesseract executable. When given,
            it is installed as pytesseract's command and, if a ``tessdata``
            folder exists next to it, ``TESSDATA_PREFIX`` is pointed there.

    Returns:
        Each OCR line after ``normalize_line``, skipping lines that normalize
        to an empty string.

    Raises:
        RuntimeError: If pytesseract cannot find the Tesseract executable.
    """
    import pytesseract
    from pytesseract import TesseractNotFoundError

    if tesseract_cmd:
        pytesseract.pytesseract.tesseract_cmd = tesseract_cmd
        tessdata = _tessdata_dir(tesseract_cmd)
        if tessdata is not None:
            os.environ["TESSDATA_PREFIX"] = str(tessdata)
    config = "--psm 6 --oem 3"
    try:
        text = pytesseract.image_to_string(image, lang="eng", config=config)
    except TesseractNotFoundError as exc:
        raise RuntimeError(tesseract_help_message()) from exc
    # Normalize each line once, then drop the ones that end up empty.
    normalized = (normalize_line(line) for line in text.splitlines())
    return [line for line in normalized if line]
|
||||
|
||||
|
||||
def normalize_line(line: str) -> str:
    """Collapse whitespace runs to one space, trim, and straighten curly quotes."""
    import re

    collapsed = re.sub(r"\s+", " ", line).strip()
    straight_quotes = str.maketrans({"“": '"', "”": '"', "‘": "'", "’": "'"})
    return collapsed.translate(straight_quotes)
|
||||
|
||||
|
||||
def extract_text_from_lines(lines: list[str]) -> str:
    """Join ``lines`` with single spaces and strip the result."""
    joined = " ".join(lines)
    return joined.strip()
|
||||
|
||||
|
||||
def extract_ocr_text(image: Image.Image, tesseract_cmd: str = "") -> str:
    """Preprocess ``image``, OCR it, and return the recognized lines as one string."""
    prepared = preprocess_for_ocr(image)
    recognized = image_to_lines(prepared, tesseract_cmd)
    return extract_text_from_lines(recognized)
|
||||
|
||||
|
||||
def _tessdata_dir(tesseract_cmd: str) -> Path | None:
    """Return the ``tessdata`` folder beside the executable, if there is one.

    Gives ``None`` for an empty ``tesseract_cmd`` or when no ``tessdata``
    entry exists next to the resolved executable path.
    """
    if tesseract_cmd:
        candidate = Path(tesseract_cmd).resolve().parent / "tessdata"
        return candidate if candidate.exists() else None
    return None
|
||||
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tkinter as tk
|
||||
from collections.abc import Callable
|
||||
|
||||
from PIL import ImageEnhance, ImageTk
|
||||
|
||||
from .capture import capture_absolute_region
|
||||
from .config import Rect
|
||||
|
||||
|
||||
class RegionPicker(tk.Toplevel):
    """Borderless overlay on which the user drags a rectangle.

    The overlay covers ``bounds``, shows a dimmed screenshot of that area, and
    reports the dragged rectangle to ``on_selected`` as a ``Rect``. Escape
    closes the overlay without calling the callback.
    """

    def __init__(
        self,
        parent: tk.Tk,
        bounds: Rect,
        on_selected: Callable[[Rect], None],
    ) -> None:
        """Create the overlay over ``bounds`` and take input focus.

        Args:
            parent: Owning Tk root.
            bounds: Screen area the overlay covers and captures.
            on_selected: Called with the selected ``Rect`` after a drag of at
                least 4 pixels in each direction.
        """
        super().__init__(parent)
        self._bounds = bounds
        self._on_selected = on_selected
        self._start_x = 0
        self._start_y = 0
        # Canvas item ids of the selection rectangle and its size label.
        self._rect_id: int | None = None
        self._label_id: int | None = None

        # Undecorated, always-on-top window placed exactly over ``bounds``.
        self.overrideredirect(True)
        self.attributes("-topmost", True)
        self.geometry(f"{bounds.width}x{bounds.height}{bounds.left:+d}{bounds.top:+d}")
        self.configure(cursor="crosshair")

        screenshot = capture_absolute_region(bounds.left, bounds.top, bounds.width, bounds.height)
        dimmed = ImageEnhance.Brightness(screenshot).enhance(0.55)
        # Stored on self so the PhotoImage stays referenced while the canvas shows it.
        self._image = ImageTk.PhotoImage(dimmed)

        self.canvas = tk.Canvas(self, bg="#050608", highlightthickness=0, cursor="crosshair")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.canvas.create_image(0, 0, image=self._image, anchor=tk.NW)
        self.canvas.create_text(
            18,
            18,
            text="Drag to select. Esc cancels.",
            fill="#f4f7fb",
            anchor=tk.NW,
            font=("Segoe UI", 12, "bold"),
        )
        self.canvas.bind("<ButtonPress-1>", self._on_press)
        self.canvas.bind("<B1-Motion>", self._on_drag)
        self.canvas.bind("<ButtonRelease-1>", self._on_release)
        self.bind("<Escape>", lambda _event: self.destroy())

        self.focus_force()
        self.grab_set()

    def _on_press(self, event: tk.Event) -> None:
        """Remember the drag origin and create a zero-size selection rectangle."""
        self._start_x = int(event.x)
        self._start_y = int(event.y)
        self._rect_id = self.canvas.create_rectangle(
            self._start_x,
            self._start_y,
            self._start_x,
            self._start_y,
            outline="#ff365f",
            width=3,
        )

    def _on_drag(self, event: tk.Event) -> None:
        """Resize the selection rectangle and update its "W x H" label."""
        if self._rect_id is not None:
            # Sorting makes the rectangle valid whichever way the user drags.
            x1, x2 = sorted((self._start_x, int(event.x)))
            y1, y2 = sorted((self._start_y, int(event.y)))
            self.canvas.coords(self._rect_id, x1, y1, x2, y2)
            label = f"{x2 - x1} x {y2 - y1}"
            if self._label_id is None:
                self._label_id = self.canvas.create_text(
                    x1 + 8,
                    max(12, y1 - 18),
                    text=label,
                    fill="#f4f7fb",
                    anchor=tk.W,
                    font=("Segoe UI", 10, "bold"),
                )
            else:
                self.canvas.coords(self._label_id, x1 + 8, max(12, y1 - 18))
                self.canvas.itemconfigure(self._label_id, text=label)

    def _on_release(self, event: tk.Event) -> None:
        """Close the overlay and report the selection unless it is too small."""
        x1, x2 = sorted((self._start_x, int(event.x)))
        y1, y2 = sorted((self._start_y, int(event.y)))
        self.grab_release()
        self.destroy()
        # Selections under 4 pixels in either direction are ignored.
        if x2 - x1 < 4 or y2 - y1 < 4:
            return
        # The reported Rect is offset by the origin of ``bounds``.
        self._on_selected(
            Rect(
                left=self._bounds.left + x1,
                top=self._bounds.top + y1,
                width=x2 - x1,
                height=y2 - y1,
            )
        )
|
||||
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def resource_path(relative_path: str) -> Path:
    """Resolve ``relative_path`` against the application's resource root.

    The root is ``sys._MEIPASS`` when that attribute exists (PyInstaller sets
    it in bundled apps - see its run-time documentation); otherwise it is the
    directory two levels above the folder containing this file.
    """
    source_root = Path(__file__).resolve().parents[2]
    root = Path(getattr(sys, "_MEIPASS", source_root))
    return root / relative_path
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from .resources import resource_path
|
||||
|
||||
|
||||
# Install locations that find_tesseract() checks last, after the bundled
# copy and the PATH lookup.
COMMON_TESSERACT_PATHS = (
    Path(r"C:\Program Files\Tesseract-OCR\tesseract.exe"),
    Path(r"C:\Program Files (x86)\Tesseract-OCR\tesseract.exe"),
)
|
||||
|
||||
|
||||
def find_tesseract() -> str:
    """Locate a Tesseract executable and return its path, or ``""`` if none is found.

    Search order: the bundled ``tesseract/tesseract.exe`` resource, then
    ``tesseract`` on ``PATH``, then the entries of ``COMMON_TESSERACT_PATHS``.
    """
    bundled = resource_path("tesseract/tesseract.exe")
    if bundled.exists():
        return str(bundled)

    on_path = shutil.which("tesseract")
    if on_path:
        return on_path

    existing = (str(candidate) for candidate in COMMON_TESSERACT_PATHS if candidate.exists())
    return next(existing, "")
|
||||
|
||||
|
||||
def tesseract_help_message() -> str:
    """Return the user-facing hint shown when Tesseract cannot be found."""
    sentences = (
        "Tesseract OCR is not installed or the executable is not configured. ",
        "Install it with `winget install UB-Mannheim.TesseractOCR`, then restart the app, ",
        "or select tesseract.exe in the GUI.",
    )
    return "".join(sentences)
|
||||
@@ -0,0 +1,576 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import hashlib
|
||||
import http.server
|
||||
import importlib
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote, urljoin
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import sounddevice as sd
|
||||
|
||||
from .resources import resource_path
|
||||
from .voices import safe_voice_slug
|
||||
|
||||
|
||||
# Per-user directory for voice data. NOTE(review): presumably the cache used by
# the voice-state helpers; it is not referenced in the part of the file reviewed here.
VOICE_CACHE_DIR = Path.home() / ".seshat-tts" / "voices"
|
||||
|
||||
|
||||
class PocketTTSStreamer:
    """Speak text with the in-process ``pocket_tts`` Python API.

    Audio is generated and played on background threads. Progress and errors
    are reported as strings on ``status_queue`` rather than raised.
    """

    def __init__(
        self,
        voice_path: str | Path,
        language: str = "english",
        quantize: bool = False,
        voice_source: str = "default",
        default_voice: str = "alba",
        custom_voice_name: str = "",
        volume_gain: float = 1.0,
    ) -> None:
        """Store the settings; the model itself is loaded lazily by ``_load``."""
        self.voice_path = str(voice_path)
        self.language = language
        self.quantize = quantize
        self.voice_source = voice_source
        self.default_voice = default_voice
        self.custom_voice_name = custom_voice_name
        self.volume_gain = _clamp_volume_gain(volume_gain)
        self._model = None
        self._voice_state = None
        # _lock is held for the whole of _preload and _speak.
        self._lock = threading.Lock()
        # _cancel_lock guards swapping and setting the current cancel event.
        self._cancel_lock = threading.Lock()
        self._cancel_event = threading.Event()
        self._status_queue: queue.Queue[str] = queue.Queue()

    @property
    def status_queue(self) -> queue.Queue[str]:
        """Queue of human-readable status and error messages."""
        return self._status_queue

    def speak_async(self, text: str) -> None:
        """Cancel any current stream and speak ``text`` on a daemon thread.

        Blank text is ignored.
        """
        text = text.strip()
        if not text:
            return
        cancel_event = self._begin_new_stream()
        threading.Thread(target=self._speak, args=(text, cancel_event), daemon=True).start()

    def preload_async(self) -> None:
        """Load the model and voice on a daemon thread."""
        threading.Thread(target=self._preload, daemon=True).start()

    def test_async(self) -> None:
        """Speak a fixed test sentence."""
        self.speak_async("This is a Pocket TTS test.")

    def close(self) -> None:
        """Stop the current stream."""
        self.stop()

    def stop(self) -> None:
        """Signal the current stream to stop at its next cancel check."""
        with self._cancel_lock:
            self._cancel_event.set()

    def _begin_new_stream(self) -> threading.Event:
        """Cancel the previous stream and return a fresh cancel event."""
        with self._cancel_lock:
            self._cancel_event.set()
            self._cancel_event = threading.Event()
            return self._cancel_event

    def _preload(self) -> None:
        """Run ``_load`` under the lock, reporting failures on the status queue."""
        with self._lock:
            try:
                self._load()
            except Exception as exc:
                self._status_queue.put(f"TTS preload error: {exc}")

    def _load(self) -> None:
        """Import ``pocket_tts``, load the model and prepare the voice state.

        Does nothing when both are already loaded.

        Raises:
            RuntimeError: If importing ``pocket_tts`` fails with ``ImportError``
                or ``OSError``, or loading the model fails with ``OSError``.
        """
        if self._model is not None and self._voice_state is not None:
            return
        self._status_queue.put("Loading Pocket TTS model...")
        try:
            pocket_tts = importlib.import_module("pocket_tts")
            tts_model = getattr(pocket_tts, "TTSModel")
        except (ImportError, OSError) as exc:
            raise RuntimeError(
                "Pocket TTS failed to load through the in-process Python API. "
                "Use the uvx-server backend, especially from the bundled EXE."
            ) from exc

        try:
            self._model = tts_model.load_model(language=self.language, quantize=self.quantize)
        except OSError as exc:
            raise RuntimeError(
                "Pocket TTS/Torch DLL initialization failed in the in-process Python API. "
                "Use the uvx-server backend instead."
            ) from exc
        # Either the built-in voice name or a prepared custom audio prompt path.
        voice = self.default_voice if self.voice_source == "default" else self._custom_voice_path()
        self._status_queue.put(f"Loading voice: {voice}")
        self._voice_state = self._model.get_state_for_audio_prompt(voice)
        self._status_queue.put("Pocket TTS ready.")

    def _custom_voice_path(self) -> str:
        """Return the prepared audio prompt path for the custom voice.

        Raises:
            ValueError: If no voice file path is configured.
        """
        if not self.voice_path.strip():
            raise ValueError("Select a WAV or MP3 file, or change Voice Source to default.")
        return str(_prepared_audio_prompt_path(self.voice_path, self.language, self._status_queue))

    def _speak(self, text: str, cancel_event: threading.Event) -> None:
        """Generate audio for ``text`` and play it chunk by chunk.

        ``cancel_event`` is checked after loading and before every chunk. Any
        exception is reported on the status queue.
        """
        with self._lock:
            try:
                self._load()
                if cancel_event.is_set():
                    self._status_queue.put("Stopped previous TTS stream.")
                    return
                assert self._model is not None
                assert self._voice_state is not None
                sample_rate = int(self._model.sample_rate)
                self._status_queue.put("Speaking OCR text...")
                with sd.OutputStream(samplerate=sample_rate, channels=1, dtype="float32") as stream:
                    for chunk in self._model.generate_audio_stream(self._voice_state, text):
                        if cancel_event.is_set():
                            self._status_queue.put("Stopped previous TTS stream.")
                            return
                        # Flatten each generated chunk to a 1-D float32 array.
                        audio = chunk.detach().cpu().numpy()
                        audio = np.asarray(audio, dtype=np.float32).reshape(-1)
                        if audio.size:
                            stream.write(_apply_volume_gain(audio, self.volume_gain))
                self._status_queue.put("Done.")
            except Exception as exc:
                self._status_queue.put(f"TTS error: {exc}")
|
||||
|
||||
|
||||
class UvxPocketTTSServer:
|
||||
def __init__(
|
||||
self,
|
||||
voice_path: str | Path,
|
||||
language: str = "english",
|
||||
quantize: bool = False,
|
||||
host: str = "localhost",
|
||||
port: int = 8000,
|
||||
voice_source: str = "default",
|
||||
default_voice: str = "alba",
|
||||
custom_voice_name: str = "",
|
||||
volume_gain: float = 1.0,
|
||||
) -> None:
|
||||
self.voice_path = str(voice_path)
|
||||
self.language = language
|
||||
self.quantize = quantize
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.voice_source = voice_source
|
||||
self.default_voice = default_voice
|
||||
self.custom_voice_name = custom_voice_name
|
||||
self.volume_gain = _clamp_volume_gain(volume_gain)
|
||||
self._process: subprocess.Popen[str] | None = None
|
||||
self._lock = threading.Lock()
|
||||
self._speak_lock = threading.Lock()
|
||||
self._cancel_lock = threading.Lock()
|
||||
self._cancel_event = threading.Event()
|
||||
self._active_response: requests.Response | None = None
|
||||
self._server_output: collections.deque[str] = collections.deque(maxlen=80)
|
||||
self._status_queue: queue.Queue[str] = queue.Queue()
|
||||
|
||||
@property
|
||||
def status_queue(self) -> queue.Queue[str]:
|
||||
return self._status_queue
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
return f"http://{self.host}:{self.port}/"
|
||||
|
||||
def preload_async(self) -> None:
|
||||
threading.Thread(target=self._ensure_server, daemon=True).start()
|
||||
|
||||
def speak_async(self, text: str) -> None:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
cancel_event = self._begin_new_stream()
|
||||
threading.Thread(target=self._speak, args=(text, cancel_event), daemon=True).start()
|
||||
|
||||
def test_async(self) -> None:
|
||||
self.speak_async("This is a Pocket TTS test.")
|
||||
|
||||
def close(self) -> None:
|
||||
self.stop()
|
||||
if self._process and self._process.poll() is None:
|
||||
self._process.terminate()
|
||||
|
||||
def stop(self) -> None:
|
||||
with self._cancel_lock:
|
||||
self._cancel_event.set()
|
||||
if self._active_response is not None:
|
||||
self._active_response.close()
|
||||
|
||||
def _begin_new_stream(self) -> threading.Event:
|
||||
with self._cancel_lock:
|
||||
self._cancel_event.set()
|
||||
if self._active_response is not None:
|
||||
self._active_response.close()
|
||||
self._cancel_event = threading.Event()
|
||||
return self._cancel_event
|
||||
|
||||
def _is_healthy(self) -> bool:
|
||||
try:
|
||||
response = requests.get(urljoin(self.base_url, "health"), timeout=2)
|
||||
return response.ok
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
def _ensure_server(self) -> None:
|
||||
with self._lock:
|
||||
if self._is_healthy():
|
||||
self._status_queue.put("Pocket TTS server ready.")
|
||||
return
|
||||
if self._process is None or self._process.poll() is not None:
|
||||
uvx = _find_uvx()
|
||||
command = [
|
||||
str(uvx),
|
||||
"pocket-tts",
|
||||
"serve",
|
||||
"--host",
|
||||
self.host,
|
||||
"--port",
|
||||
str(self.port),
|
||||
"--language",
|
||||
self.language,
|
||||
]
|
||||
if self.quantize:
|
||||
command.append("--quantize")
|
||||
self._server_output.clear()
|
||||
self._status_queue.put(f"Starting Pocket TTS server with {uvx}...")
|
||||
self._process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
env=_clean_subprocess_env(),
|
||||
cwd=str(Path.home()),
|
||||
creationflags=_subprocess_creationflags(),
|
||||
)
|
||||
threading.Thread(target=self._read_server_output, daemon=True).start()
|
||||
deadline = time.monotonic() + 900
|
||||
while time.monotonic() < deadline:
|
||||
if self._is_healthy():
|
||||
self._status_queue.put("Pocket TTS server ready.")
|
||||
return
|
||||
if self._process and self._process.poll() is not None:
|
||||
output = self._server_output_tail()
|
||||
detail = f"\n{output}" if output else " No server output was captured."
|
||||
raise RuntimeError(f"Pocket TTS server exited with code {self._process.returncode}.{detail}")
|
||||
time.sleep(1)
|
||||
raise TimeoutError("Pocket TTS server did not become ready before timeout.")
|
||||
|
||||
def _read_server_output(self) -> None:
|
||||
process = self._process
|
||||
if process is None or process.stdout is None:
|
||||
return
|
||||
try:
|
||||
for line in process.stdout:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self._server_output.append(line)
|
||||
except Exception as exc:
|
||||
self._server_output.append(f"Failed to read server output: {exc}")
|
||||
|
||||
def _server_output_tail(self) -> str:
|
||||
if not self._server_output:
|
||||
return ""
|
||||
return "\n".join(list(self._server_output)[-12:])
|
||||
|
||||
def _speak(self, text: str, cancel_event: threading.Event) -> None:
|
||||
with self._speak_lock:
|
||||
if cancel_event.is_set():
|
||||
self._status_queue.put("Stopped previous TTS stream.")
|
||||
return
|
||||
try:
|
||||
self._ensure_server()
|
||||
if cancel_event.is_set():
|
||||
self._status_queue.put("Stopped previous TTS stream.")
|
||||
return
|
||||
self._status_queue.put("Requesting Pocket TTS audio...")
|
||||
if self.voice_source == "default":
|
||||
response = requests.post(
|
||||
urljoin(self.base_url, "tts"),
|
||||
data={"text": text, "voice_url": self.default_voice},
|
||||
stream=True,
|
||||
timeout=900,
|
||||
)
|
||||
else:
|
||||
voice_url = self._custom_voice_url()
|
||||
response = requests.post(
|
||||
urljoin(self.base_url, "tts"),
|
||||
data={"text": text, "voice_url": voice_url},
|
||||
stream=True,
|
||||
timeout=900,
|
||||
)
|
||||
with self._cancel_lock:
|
||||
self._active_response = response
|
||||
response.raise_for_status()
|
||||
self._play_streaming_wav(response, cancel_event)
|
||||
if not cancel_event.is_set():
|
||||
self._status_queue.put("Done.")
|
||||
except requests.RequestException as exc:
|
||||
if cancel_event.is_set():
|
||||
self._status_queue.put("Stopped previous TTS stream.")
|
||||
else:
|
||||
self._status_queue.put(f"TTS error: {exc}")
|
||||
except Exception as exc:
|
||||
self._status_queue.put(f"TTS error: {exc}")
|
||||
finally:
|
||||
with self._cancel_lock:
|
||||
self._active_response = None
|
||||
|
||||
def _custom_voice_path(self) -> str:
    """Return the configured custom voice file path.

    Raises:
        ValueError: If ``self.voice_path`` is empty or whitespace only.
    """
    configured = self.voice_path
    if configured.strip():
        return configured
    raise ValueError("Select a WAV or MP3 file, or change Voice Source to default.")
|
||||
|
||||
def _custom_voice_url(self) -> str:
    """Return a local HTTP URL for the cached state of the custom voice.

    The configured voice file is resolved to a cached voice-state file, and
    the module-level voice state server supplies the URL for that file.
    """
    source_path = self._custom_voice_path()
    cached_state = _cached_voice_state_path(
        source_path,
        self.language,
        self._status_queue,
        self.custom_voice_name,
    )
    return _voice_state_server.url_for(cached_state)
|
||||
|
||||
def _play_streaming_wav(self, response: requests.Response, cancel_event: threading.Event) -> None:
    """Play a WAV HTTP response through an output stream as it arrives.

    Bytes are buffered until the WAV header's data chunk is located; the
    output stream is then opened with the header's sample rate and channel
    count, and every later chunk is converted to float32 and written after
    the volume gain is applied. Only whole frames are written; a trailing
    partial frame stays in the buffer until more bytes arrive.

    Args:
        response: Streaming response whose body is expected to be WAV data.
        cancel_event: Checked once per received chunk. When set, the
            response is closed, a "stopped" status is queued and playback
            ends.

    Raises:
        ValueError: If the header is not a WAV header, reports an unusable
            frame size, or uses an unsupported sample width.
    """
    buffer = bytearray()
    stream: sd.OutputStream | None = None
    sample_width = 0
    channels = 0
    frame_size = 0
    try:
        for chunk in response.iter_content(chunk_size=16384):
            if cancel_event.is_set():
                response.close()
                self._status_queue.put("Stopped previous TTS stream.")
                return
            if not chunk:
                continue
            buffer.extend(chunk)
            if stream is None:
                header_end = _find_wav_data_offset(buffer)
                if header_end is None:
                    # Header not complete yet; keep accumulating.
                    continue
                channels, sample_rate, sample_width = _read_wav_format(buffer)
                frame_size = sample_width * channels
                if frame_size <= 0:
                    # Fail with a clear message instead of a ZeroDivisionError
                    # from the modulo below.
                    raise ValueError(
                        "WAV header reports an unusable frame size "
                        f"({channels} channel(s), {sample_width} byte(s) per sample)."
                    )
                stream = sd.OutputStream(samplerate=sample_rate, channels=channels, dtype="float32")
                stream.start()
                del buffer[:header_end]
                self._status_queue.put("Streaming Pocket TTS audio...")
            usable = len(buffer) - (len(buffer) % frame_size)
            if usable <= 0:
                continue
            pcm = bytes(buffer[:usable])
            del buffer[:usable]
            audio = _pcm_to_float32(pcm, sample_width, channels)
            if audio.size:
                stream.write(_apply_volume_gain(audio, self.volume_gain))
    finally:
        if stream is not None:
            # Close the stream even if stopping it raises.
            try:
                stream.stop()
            finally:
                stream.close()
|
||||
|
||||
|
||||
class _QuietStaticFileHandler(http.server.SimpleHTTPRequestHandler):
    """Static file request handler that produces no per-request log output."""

    def log_message(self, _format: str, *args: object) -> None:
        """Discard the log message instead of emitting it."""
        return None
|
||||
|
||||
|
||||
class _VoiceStateServer:
    """Lazily started loopback HTTP server exposing ``VOICE_CACHE_DIR``.

    The server is created on the first ``url_for`` call, bound to 127.0.0.1
    on port 0 (so the operating system picks the port), and served from a
    daemon thread.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._server: http.server.ThreadingHTTPServer | None = None
        self._thread: threading.Thread | None = None

    def url_for(self, path: Path) -> str:
        """Return the HTTP URL under which ``path``'s file name is served.

        Only ``path.name`` is used, URL-quoted; the file is expected to live
        directly inside ``VOICE_CACHE_DIR``.
        """
        with self._lock:
            VOICE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
            if self._server is None:
                self._start_locked()
            port = self._server.server_address[1]
            return f"http://127.0.0.1:{port}/{quote(path.name)}"

    def _start_locked(self) -> None:
        """Create the server and its serving thread; caller holds the lock."""
        handler = functools.partial(_QuietStaticFileHandler, directory=str(VOICE_CACHE_DIR))
        self._server = http.server.ThreadingHTTPServer(("127.0.0.1", 0), handler)
        self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
        self._thread.start()


# Module-wide instance shared by every caller that needs a voice state URL.
_voice_state_server = _VoiceStateServer()
|
||||
|
||||
|
||||
def _cached_voice_state_path(
    source_path: str,
    language: str,
    status_queue: queue.Queue[str],
    voice_name: str = "",
) -> Path:
    """Return the cached ``.safetensors`` voice state for a voice file.

    The cache file name combines a slug (from ``voice_name`` when it is not
    blank, otherwise from the source file's stem) with a digest of the
    resolved source path, its modification time, its size and the language.
    A ``.safetensors`` source is copied into the cache; any other source is
    exported with ``uvx pocket-tts export-voice`` (MP3 input is converted to
    WAV first).

    Args:
        source_path: Path of the voice file selected by the user.
        language: Language passed to the export command; part of the digest.
        status_queue: Receives human-readable progress messages.
        voice_name: Optional display name used for the cache file prefix.

    Returns:
        Path of the cached voice state inside ``VOICE_CACHE_DIR``.

    Raises:
        FileNotFoundError: If ``source_path`` does not exist.
        subprocess.CalledProcessError: If the export command fails.
    """
    source = Path(source_path)
    if not source.exists():
        raise FileNotFoundError(f"Voice file not found: {source}")
    VOICE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    stat = source.stat()
    digest = hashlib.sha256(
        f"{source.resolve()}|{stat.st_mtime_ns}|{stat.st_size}|{language}".encode("utf-8")
    ).hexdigest()[:24]
    prefix = voice_name if voice_name.strip() else source.stem
    target = VOICE_CACHE_DIR / f"{safe_voice_slug(prefix)}-{digest}.safetensors"
    if target.exists():
        status_queue.put("Using cached custom voice state.")
        return target

    if source.suffix.casefold() == ".safetensors":
        try:
            shutil.copy2(source, target)
        except BaseException:
            # A half-written copy would be accepted as a valid cache entry
            # on the next call, so remove it before propagating the error.
            try:
                target.unlink(missing_ok=True)
            except OSError:
                pass  # Best effort; the original failure is what matters.
            raise
        status_queue.put("Using cached custom voice state.")
        return target

    prompt_source = _prepared_audio_prompt_path(source, language, status_queue, digest)
    status_queue.put("Exporting custom voice cache; first run can take a while.")
    command = [
        str(_find_uvx()),
        "pocket-tts",
        "export-voice",
        str(prompt_source),
        str(target),
        "--language",
        language,
        "--quiet",
    ]
    try:
        subprocess.run(command, check=True, env=_clean_subprocess_env(), creationflags=_subprocess_creationflags())
    except BaseException:
        # Same reasoning as above: never leave a partial export in the cache.
        try:
            target.unlink(missing_ok=True)
        except OSError:
            pass  # Best effort; the original failure is what matters.
        raise
    status_queue.put("Custom voice cache ready.")
    return target
|
||||
|
||||
|
||||
def _prepared_audio_prompt_path(
    source_path: str | Path,
    language: str,
    status_queue: queue.Queue[str],
    digest: str | None = None,
) -> Path:
    """Return an audio prompt path, converting MP3 input to a cached WAV.

    Files whose suffix is not ``.mp3`` (compared case-insensitively) are
    returned unchanged. For MP3 files a WAV named from the source stem and
    ``digest`` is looked up in ``VOICE_CACHE_DIR`` and created when missing.

    Args:
        source_path: Voice reference file.
        language: Only used when ``digest`` has to be computed here.
        status_queue: Receives progress messages for the MP3 path.
        digest: Precomputed cache key; derived from the resolved path,
            modification time, size and language when omitted.
    """
    source = Path(source_path)
    is_mp3 = source.suffix.casefold() == ".mp3"
    if not is_mp3:
        return source

    VOICE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    if digest is None:
        info = source.stat()
        fingerprint = f"{source.resolve()}|{info.st_mtime_ns}|{info.st_size}|{language}"
        digest = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()[:24]

    wav_path = VOICE_CACHE_DIR / f"{safe_voice_slug(source.stem)}-{digest}.wav"
    if wav_path.exists():
        status_queue.put("Using cached WAV conversion for MP3 voice.")
    else:
        status_queue.put("Converting MP3 voice reference to WAV...")
        _convert_mp3_to_wav(source, wav_path)
        status_queue.put("MP3 voice conversion ready.")
    return wav_path
|
||||
|
||||
|
||||
def _convert_mp3_to_wav(source: Path, target: Path) -> None:
    """Convert ``source`` to a mono, 24 kHz, 16-bit WAV at ``target``.

    The ffmpeg binary supplied by ``imageio_ffmpeg`` does the conversion.
    ``target`` is overwritten if it exists and removed again if the
    conversion fails.

    Raises:
        RuntimeError: If ``imageio_ffmpeg`` cannot be imported.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    try:
        import imageio_ffmpeg
    except ImportError as exc:
        raise RuntimeError("MP3 custom voices require imageio-ffmpeg. Reinstall Seshat TTS dependencies.") from exc

    command = [
        imageio_ffmpeg.get_ffmpeg_exe(),
        "-hide_banner",
        "-loglevel",
        "error",
        "-y",
        "-i",
        str(source),
        "-vn",
        "-ac",
        "1",
        "-ar",
        "24000",
        "-sample_fmt",
        "s16",
        str(target),
    ]
    try:
        subprocess.run(command, check=True, env=_clean_subprocess_env(), creationflags=_subprocess_creationflags())
    except BaseException:
        # Callers treat an existing target as a finished conversion, so a
        # truncated file left by a failed run must not survive.
        try:
            target.unlink(missing_ok=True)
        except OSError:
            pass  # Best effort; the original failure is what matters.
        raise
|
||||
|
||||
|
||||
def _find_wav_data_offset(data: bytearray) -> int | None:
    """Return the offset of the first audio byte in a buffered WAV stream.

    For a buffer that starts with a RIFF/WAVE header the chunks are walked
    in order, so the bytes ``data`` occurring inside an earlier chunk (for
    example the word "metadata" in a LIST chunk) are not mistaken for the
    data chunk. Buffers without that header fall back to a plain byte
    search for ``data``.

    Returns:
        The index just past the data chunk's 8-byte header, or ``None``
        when that header has not been fully received yet.
    """
    if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WAVE":
        offset = 12
        while len(data) >= offset + 8:
            if data[offset : offset + 4] == b"data":
                return offset + 8
            chunk_size = int.from_bytes(data[offset + 4 : offset + 8], "little")
            # Chunk bodies are padded to an even number of bytes.
            offset += 8 + chunk_size + (chunk_size & 1)
        return None

    marker = data.find(b"data")
    if marker < 0 or len(data) < marker + 8:
        return None
    return marker + 8
|
||||
|
||||
|
||||
def _read_wav_format(data: bytearray) -> tuple[int, int, int]:
    """Extract channel count, sample rate and sample width from a WAV header.

    The ``fmt `` chunk is located by byte search and its fields are read as
    little-endian integers. The audio format tag is not inspected.

    Returns:
        ``(channels, sample_rate, bytes_per_sample)``.

    Raises:
        ValueError: If the buffer does not start with a RIFF/WAVE header of
            at least 36 bytes, or the ``fmt `` chunk is missing or truncated.
    """
    looks_like_wav = len(data) >= 36 and data[:4] == b"RIFF" and data[8:12] == b"WAVE"
    if not looks_like_wav:
        raise ValueError("Response is not a WAV stream.")

    fmt_at = data.find(b"fmt ")
    if fmt_at < 0 or len(data) < fmt_at + 24:
        raise ValueError("WAV stream is missing fmt chunk.")

    def field(start: int, end: int) -> int:
        # Offsets are relative to the start of the "fmt " tag.
        return int.from_bytes(data[fmt_at + start : fmt_at + end], "little")

    channel_count = field(10, 12)
    rate = field(12, 16)
    bit_depth = field(22, 24)
    return channel_count, rate, bit_depth // 8
|
||||
|
||||
|
||||
def _pcm_to_float32(pcm: bytes, sample_width: int, channels: int) -> np.ndarray:
    """Convert integer PCM bytes to a float32 array of shape (frames, channels).

    16-bit samples are divided by 32768 and 32-bit samples by 2147483648.
    A channel count of 1 or less yields a single column.

    Raises:
        ValueError: If ``sample_width`` is neither 2 nor 4 bytes.
    """
    if sample_width == 2:
        sample_type, full_scale = np.int16, 32768.0
    elif sample_width == 4:
        sample_type, full_scale = np.int32, 2147483648.0
    else:
        raise ValueError(f"Unsupported WAV sample width: {sample_width}")

    samples = np.frombuffer(pcm, dtype=sample_type).astype(np.float32) / full_scale
    columns = channels if channels > 1 else 1
    return samples.reshape(-1, columns)
|
||||
|
||||
|
||||
def _clamp_volume_gain(value: float) -> float:
    """Coerce ``value`` to float and limit it to the range 0.0 through 3.0."""
    capped = min(float(value), 3.0)
    return max(0.0, capped)


def _apply_volume_gain(audio: np.ndarray, volume_gain: float) -> np.ndarray:
    """Scale ``audio`` by the clamped gain and clip the result to [-1, 1].

    When the clamped gain is exactly 1.0 the input array is returned as is;
    otherwise a float32 array is returned.
    """
    gain = _clamp_volume_gain(volume_gain)
    if gain != 1.0:
        scaled = np.clip(audio * gain, -1.0, 1.0)
        return scaled.astype(np.float32, copy=False)
    return audio
|
||||
|
||||
|
||||
def _find_uvx() -> Path:
    """Locate the ``uvx`` executable.

    Search order: the bundled ``tools/uvx.exe`` resource, ``uvx`` on PATH,
    then ``uvx.exe`` under ``~/.local/bin`` and ``~/.cargo/bin``.

    Raises:
        FileNotFoundError: If none of those locations has the executable.
    """
    bundled = resource_path("tools/uvx.exe")
    if bundled.exists():
        return bundled

    on_path = shutil.which("uvx")
    if on_path:
        return Path(on_path)

    fallbacks = (
        Path.home() / ".local" / "bin" / "uvx.exe",
        Path.home() / ".cargo" / "bin" / "uvx.exe",
    )
    located = next((option for option in fallbacks if option.exists()), None)
    if located is None:
        raise FileNotFoundError("uvx.exe was not found on PATH. Install uv or add uvx.exe to PATH.")
    return located
|
||||
|
||||
|
||||
def _clean_subprocess_env() -> dict[str, str]:
    """Return a copy of the environment suitable for child processes.

    Variables whose names start with ``_PYI`` or ``PYINSTALLER`` are
    dropped, as are ``PYTHONHOME`` and ``PYTHONPATH``. When ``~/.local/bin``
    exists it is put at the front of ``PATH``.
    """
    bundler_prefixes = ("_PYI", "PYINSTALLER")
    interpreter_overrides = ("PYTHONHOME", "PYTHONPATH")
    env = {
        key: value
        for key, value in os.environ.items()
        if not key.startswith(bundler_prefixes) and key not in interpreter_overrides
    }

    user_bin = Path.home() / ".local" / "bin"
    if user_bin.exists():
        env["PATH"] = os.pathsep.join((str(user_bin), env.get("PATH", "")))
    return env
|
||||
|
||||
|
||||
def _subprocess_creationflags() -> int:
    """Return subprocess creation flags for the current platform.

    On Windows (``os.name == "nt"``) this is ``subprocess.CREATE_NO_WINDOW``
    when that attribute exists; everywhere else, and as a fallback, it is 0.
    """
    if os.name == "nt":
        return int(getattr(subprocess, "CREATE_NO_WINDOW", 0))
    return 0
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import APP_DIR
|
||||
|
||||
|
||||
# JSON file under APP_DIR that the load/save/upsert helpers in this module
# use as their default storage location.
VOICE_PROFILES_PATH = APP_DIR / "voice_profiles.json"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class VoiceProfile:
|
||||
name: str
|
||||
path: str
|
||||
|
||||
|
||||
def safe_voice_slug(name: str) -> str:
    """Turn ``name`` into a string safe to use inside a file name.

    Each run of characters other than ASCII letters, digits, ``.``, ``_``
    and ``-`` becomes a single ``-``; leading and trailing ``-``, ``.`` and
    ``_`` are removed. ``"custom-voice"`` is returned when nothing is left.
    """
    collapsed = re.sub(r"[^a-zA-Z0-9._-]+", "-", name.strip())
    trimmed = collapsed.strip("-._")
    return trimmed if trimmed else "custom-voice"
|
||||
|
||||
|
||||
def load_voice_profiles(path: Path = VOICE_PROFILES_PATH) -> list[VoiceProfile]:
    """Load saved voice profiles from a JSON file.

    The file is expected to hold a list of objects with ``name`` and
    ``path`` entries. Anything else at the top level yields an empty list;
    list items that are not objects, or whose name or path is missing,
    null or blank, are skipped.

    Args:
        path: JSON file to read; a missing file yields an empty list.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not path.exists():
        return []
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        return []

    profiles: list[VoiceProfile] = []
    for item in data:
        if not isinstance(item, dict):
            continue
        raw_name = item.get("name")
        raw_path = item.get("path")
        # JSON null arrives as None; str(None) would produce the literal
        # text "None" and be accepted as a real name or path.
        name = "" if raw_name is None else str(raw_name).strip()
        voice_path = "" if raw_path is None else str(raw_path).strip()
        if name and voice_path:
            profiles.append(VoiceProfile(name=name, path=voice_path))
    return profiles
|
||||
|
||||
|
||||
def save_voice_profiles(profiles: list[VoiceProfile], path: Path = VOICE_PROFILES_PATH) -> None:
    """Write ``profiles`` to ``path`` as indented UTF-8 JSON.

    The parent directory is created when it does not exist yet.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps([asdict(profile) for profile in profiles], indent=2)
    path.write_text(serialized, encoding="utf-8")
|
||||
|
||||
|
||||
def upsert_voice_profile(profile: VoiceProfile, path: Path = VOICE_PROFILES_PATH) -> list[VoiceProfile]:
    """Insert ``profile`` or replace the stored profile with the same name.

    The stored list is reloaded from ``path``, any entry whose name equals
    ``profile.name`` is dropped, ``profile`` is added, and the result is
    sorted by case-folded name, saved back and returned.
    """
    others = [existing for existing in load_voice_profiles(path) if existing.name != profile.name]
    updated = sorted([*others, profile], key=lambda entry: entry.name.casefold())
    save_voice_profiles(updated, path)
    return updated
|
||||
|
||||
|
||||
def voice_profile_by_name(name: str, profiles: list[VoiceProfile]) -> VoiceProfile | None:
    """Return the first profile whose name equals ``name`` exactly, else None."""
    for candidate in profiles:
        if candidate.name == name:
            return candidate
    return None
|
||||
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import win32gui
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WindowInfo:
|
||||
hwnd: int
|
||||
title: str
|
||||
left: int
|
||||
top: int
|
||||
right: int
|
||||
bottom: int
|
||||
|
||||
@property
|
||||
def width(self) -> int:
|
||||
return self.right - self.left
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
return self.bottom - self.top
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return f"{self.title} [{self.width}x{self.height} at {self.left},{self.top}]"
|
||||
|
||||
|
||||
def _is_candidate(hwnd: int) -> bool:
    """Return True for a visible, titled window larger than 50x50 pixels.

    The window must be visible, have a title that is not blank, and be more
    than 50 pixels wide and more than 50 pixels tall.
    """
    if not win32gui.IsWindowVisible(hwnd):
        return False
    if not win32gui.GetWindowText(hwnd).strip():
        return False
    left, top, right, bottom = win32gui.GetWindowRect(hwnd)
    wide_enough = (right - left) > 50
    tall_enough = (bottom - top) > 50
    return wide_enough and tall_enough
|
||||
|
||||
|
||||
def list_visible_windows() -> list[WindowInfo]:
    """Return every candidate top-level window, sorted by case-folded title.

    Windows are enumerated with ``win32gui.EnumWindows``; those accepted by
    ``_is_candidate`` are recorded with their stripped title and rectangle.
    """
    found: list[WindowInfo] = []

    def collect(hwnd: int, _extra: object) -> None:
        if not _is_candidate(hwnd):
            return
        left, top, right, bottom = win32gui.GetWindowRect(hwnd)
        stripped_title = win32gui.GetWindowText(hwnd).strip()
        found.append(
            WindowInfo(
                hwnd=hwnd,
                title=stripped_title,
                left=left,
                top=top,
                right=right,
                bottom=bottom,
            )
        )

    win32gui.EnumWindows(collect, None)
    return sorted(found, key=lambda entry: entry.title.casefold())
|
||||
|
||||
|
||||
def find_window_by_title(title: str) -> WindowInfo | None:
    """Find a visible window by title.

    An exact match on the stripped title wins. Otherwise the first window
    whose title contains ``title`` case-insensitively is returned. Both
    passes use the same window enumeration, so they see the same set of
    windows and the enumeration is paid for only once.

    Args:
        title: Title to look for; surrounding whitespace is ignored.

    Returns:
        The matching window, or ``None`` if ``title`` is blank or nothing
        matches.
    """
    title = title.strip()
    if not title:
        return None

    windows = list_visible_windows()
    for window in windows:
        if window.title == title:
            return window

    needle = title.casefold()
    return next((window for window in windows if needle in window.title.casefold()), None)
|
||||
Reference in New Issue
Block a user