from __future__ import annotations
import logging
import operator
from typing import Any, Iterable, List, Optional, Sequence
from sample_id.fingerprint import Fingerprint
from . import Matcher, MatcherMetadata
from .query import Match
logger = logging.getLogger(__name__)
# TODO: Refactor, maybe to separate training from inference, because this should not be trainable
[docs]class HiveMatcher(Matcher):
"""A wrapper around a list of Matchers so that they act like a single Matcher (for inference only)."""
def __init__(self, matchers: List[Matcher]):
sr = next((matcher.meta.sr for matcher in matchers), None)
hop_length = next((matcher.meta.hop_length for matcher in matchers), None)
for matcher in matchers:
if matcher.meta.sr != sr or matcher.meta.hop_length != hop_length:
raise ValueError(f"Hive must all have the same sr and hop_length, can't add {matcher}")
self.meta = MatcherMetadata(sr=sr, hop_length=hop_length)
self.matchers = matchers
[docs] def add_matcher(self, matcher: Matcher) -> HiveMatcher:
if self.meta.sr is None:
self.meta.sr = matcher.meta.sr
if self.meta.hop_length is None:
self.meta.hop_length = matcher.meta.hop_length
if matcher.meta.sr != self.meta.sr or matcher.meta.hop_length != self.meta.hop_length:
raise ValueError(f"Hive must all have the same sr and hop_length, can't add {matcher}")
self.matchers.append(matcher)
return self
[docs] def init_model(self) -> Any:
raise NotImplementedError(f"Don't do this.")
[docs] def save_model(self, filepath: str, **kwargs) -> str:
raise NotImplementedError(f"Don't do this.")
[docs] def load_model(self, filepath: str, **kwargs) -> Any:
raise NotImplementedError(f"Don't do this.")
[docs] def add_fingerprint(self, fingerprint: Fingerprint, dedupe=True) -> Matcher:
raise NotImplementedError(f"Don't do this.")
[docs] def add_fingerprints(self, fingerprints: Iterable[Fingerprint], **kwargs) -> Matcher:
raise NotImplementedError(f"Don't do this.")
[docs] def can_add_fingerprint(self, fingerprint: Fingerprint) -> bool:
return False
[docs] def save(self, filepath: str, compress: bool = True, **kwargs) -> str:
raise NotImplementedError(f"Don't do this, save the matchers individually.")
[docs] @classmethod
def create(cls, sr: Optional[int] = None, hop_length: Optional[int] = None, **kwargs) -> Matcher:
raise NotImplementedError(f"Don't do this.")
[docs] @classmethod
def from_fingerprint(cls, fp: Fingerprint, **kwargs) -> Matcher:
raise NotImplementedError(f"Don't do this.")
[docs] @classmethod
def from_fingerprints(cls, fingerprints: Sequence[Fingerprint], **kwargs) -> Matcher:
raise NotImplementedError(f"Don't do this.")
[docs] @classmethod
def load(cls, filepaths: Iterable[str], **kwargs) -> Matcher:
"""Load multiple matchers from disk into hive."""
matchers = []
for filepath in filepaths:
matcher = Matcher.load(filepath, **kwargs)
matchers.append(matcher)
return cls(matchers)
[docs] def nearest_neighbors(self, fp: Fingerprint, k: int = 1) -> Sequence[Match]:
hive_matches = []
for matcher in self.matchers:
matches = matcher.nearest_neighbors(fp, k)
hive_matches.append(matches)
resorted_matches = []
for kp_hive_matches in zip(*hive_matches):
top_k_neighbors = sorted(
(neighbor for match in kp_hive_matches for neighbor in match.neighbors),
key=operator.attrgetter("distance"),
)[:k]
head = next(match for match in kp_hive_matches)
resorted_matches.append(Match(head.keypoint, top_k_neighbors))
return resorted_matches