from __future__ import annotations
import abc
import bisect
import dataclasses
import datetime
import itertools
import logging
import math
import os
import statistics
import tempfile
from collections import defaultdict
from dataclasses import InitVar, dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
from sample_id import util
from sample_id.fingerprint import Fingerprint
from . import query
# Member names of the files stored inside the archive that Matcher.save writes
# and Matcher.load extracts.
MATCHER_FILENAME: str = "matcher.ann"
META_FILENAME: str = "meta.npz"

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# TODO: Make this a proper interface; for now only Annoy is implemented.
class Matcher(abc.ABC):
    """Nearest neighbor matcher that may use one of various implementations under the hood.

    Subclasses provide the concrete nearest-neighbor model by implementing
    :meth:`init_model`, :meth:`save_model`, :meth:`load_model` and
    :meth:`nearest_neighbors`.  This base class keeps per-item bookkeeping in
    ``self.meta`` (item id -> fingerprint id / keypoint), and handles
    (de)serialization to a single gzipped tar archive.
    """

    def __init__(self, metadata: MatcherMetadata):
        """Wrap ``metadata`` and create an empty model via :meth:`init_model`."""
        # Id assigned to the next descriptor added to the model (one per descriptor).
        self.index = 0
        # NOTE(review): assumed to count fingerprints (one per add), not descriptors — confirm.
        self.num_items = 0
        self.meta = metadata
        self.model = self.init_model()

    @abc.abstractmethod
    def init_model(self) -> Any:
        """Initialize the model."""

    @abc.abstractmethod
    def save_model(self, filepath: str, **kwargs) -> str:
        """Save this matcher's model to disk and return the path actually written."""

    @abc.abstractmethod
    def load_model(self, filepath: str, **kwargs) -> Any:
        """Load this matcher's model from disk."""

    @abc.abstractmethod
    def nearest_neighbors(self, fp: Fingerprint, k: int = 1) -> Iterable[query.Match]:
        """Fetch the ``k`` nearest neighbors to each of this fingerprint's keypoints."""

    def add_fingerprint(self, fingerprint: Fingerprint, dedupe=True) -> Matcher:
        """Add a Fingerprint to the matcher.

        Fingerprints rejected by :meth:`can_add_fingerprint` (mismatched ``sr``
        or ``hop_length``) are skipped; a warning is logged there.

        Args:
            fingerprint: Fingerprint whose descriptors are added to the model.
            dedupe: If true and the fingerprint is not already deduped, call
                ``fingerprint.remove_similar_keypoints()`` first (this mutates
                the passed fingerprint).

        Returns:
            ``self``, so calls can be chained.
        """
        if not self.can_add_fingerprint(fingerprint):
            return self
        if dedupe and not fingerprint.is_deduped:
            fingerprint.remove_similar_keypoints()
        logger.info(f"Adding {fingerprint} to index.")
        # These arrays grow in lockstep with the model: model item i maps to
        # index_to_id[i] and index_to_kp[i].
        self.meta.index_to_id = np.hstack([self.meta.index_to_id, fingerprint.keypoint_index_ids()])
        self.meta.index_to_kp = np.vstack([self.meta.index_to_kp, fingerprint.keypoints])
        for descriptor in fingerprint.descriptors:
            self.model.add_item(self.index, descriptor)
            self.index += 1
        self.num_items += 1
        return self

    def add_fingerprints(self, fingerprints: Iterable[Fingerprint], **kwargs) -> Matcher:
        """Add each Fingerprint via :meth:`add_fingerprint`, forwarding ``kwargs``; returns ``self``."""
        for fingerprint in fingerprints:
            self.add_fingerprint(fingerprint, **kwargs)
        return self

    def can_add_fingerprint(self, fingerprint: Fingerprint) -> bool:
        """Check if fingerprint can be added to matcher.

        If the matcher's ``sr`` / ``hop_length`` are unset (falsy), they are
        adopted from ``fingerprint`` first.  A fingerprint is compatible only
        when both values then equal the matcher's; each mismatch is logged.

        Returns:
            True if compatible, False otherwise.
        """
        if not self.meta.sr:
            self.meta.sr = fingerprint.sr
        if not self.meta.hop_length:
            self.meta.hop_length = fingerprint.hop_length
        compatible = True
        if self.meta.sr != fingerprint.sr:
            logger.warning(f"Can't add fingerprint with sr={fingerprint.sr}, must equal matcher sr={self.meta.sr}")
            compatible = False
        if self.meta.hop_length != fingerprint.hop_length:
            logger.warning(
                f"Can't add fingerprint with hop_length={fingerprint.hop_length}, must equal matcher hop_length={self.meta.hop_length}"
            )
            compatible = False
        return compatible

    def save(
        self,
        filepath: str,
        compress: bool = True,
        compress_level: int = 9,
        blocksize: int = 10 * 1024 * 1024,
        threads: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Save this matcher to disk.

        The model (via :meth:`save_model`, receiving ``kwargs``) and the
        metadata are written to a temporary directory, tarred under
        ``MATCHER_FILENAME`` / ``META_FILENAME``, then gzipped into ``filepath``.

        Returns:
            ``filepath``.
        """
        with tempfile.NamedTemporaryFile(suffix=".tar") as tmp_tarf:
            with tempfile.TemporaryDirectory() as tmpdir:
                logger.info(f"Saving {self} to temporary dir: {tmpdir}")
                tmp_model_path = os.path.join(tmpdir, MATCHER_FILENAME)
                tmp_meta_path = os.path.join(tmpdir, META_FILENAME)
                # save_model may write to a different path; use the one it reports.
                tmp_model_path = self.save_model(tmp_model_path, **kwargs)
                self.meta.save(tmp_meta_path, compress=compress)
                logger.debug(f"Model file {tmp_model_path} size: {util.filesize(tmp_model_path)}")
                logger.debug(f"Metadata file {tmp_meta_path} size: {util.filesize(tmp_meta_path)}")
                util.tar_files(tmp_tarf.name, [tmp_model_path, tmp_meta_path], [MATCHER_FILENAME, META_FILENAME])
            logger.debug(f"Tar file {tmp_tarf.name} size: {util.filesize(tmp_tarf.name)}")
            logger.info(f"Zipping {tmp_tarf.name} into {filepath}")
            util.gzip_file(filepath, tmp_tarf.name, compress_level=compress_level, blocksize=blocksize, threads=threads)
            logger.info(f"Zipped file {filepath} size: {util.filesize(filepath)}")
        return filepath

    def unload(self) -> None:
        """Unload things from memory and cleanup any temporary files.

        Calls ``self.model.unload()`` and, for matchers created by :meth:`load`,
        removes the temporary directory the archive was extracted into.
        """
        self.model.unload()
        if "tempdir" in vars(self):
            self.tempdir.cleanup()

    @classmethod
    def create(cls, sr: Optional[int] = None, hop_length: Optional[int] = None, **kwargs) -> Matcher:
        """Create an instance, pass any kwargs needed by the subclass (forwarded to MatcherMetadata)."""
        meta = MatcherMetadata(sr=sr, hop_length=hop_length, **kwargs)
        return cls(meta)

    @classmethod
    def from_fingerprint(cls, fp: Fingerprint, **kwargs) -> Matcher:
        """Create a matcher whose metadata is derived from ``fp``, then add ``fp`` to it.

        ``dedupe`` (if given) goes to :meth:`add_fingerprint`; every other
        keyword argument goes to :meth:`create`.
        """
        # add_fingerprint only understands `dedupe`; forwarding metadata kwargs to it raised TypeError.
        dedupe = kwargs.pop("dedupe", True)
        matcher = cls.create(sr=fp.sr, hop_length=fp.hop_length, n_features=fp.descriptors.shape[1], **kwargs)
        return matcher.add_fingerprint(fp, dedupe=dedupe)

    @classmethod
    def from_fingerprints(cls, fingerprints: Sequence[Fingerprint], **kwargs) -> Matcher:
        """My data is small, just create and train the entire matcher.

        Metadata is derived from ``fingerprints[0]`` (IndexError if empty).
        ``dedupe`` (if given) goes to :meth:`add_fingerprints`; every other
        keyword argument goes to :meth:`create`.
        """
        dedupe = kwargs.pop("dedupe", True)
        fp = fingerprints[0]
        matcher = cls.create(sr=fp.sr, hop_length=fp.hop_length, n_features=fp.descriptors.shape[1], **kwargs)
        return matcher.add_fingerprints(fingerprints, dedupe=dedupe)

    @classmethod
    def load(cls, filepath: str, blocksize: int = 10 * 1024 * 1024, threads: Optional[int] = None, **kwargs) -> Matcher:
        """Load a matcher from disk.

        The archive is extracted into a temporary directory that is kept on the
        returned matcher as ``tempdir`` (the model is loaded from a file in it)
        and removed by :meth:`unload`.  If loading fails, the directory is
        removed before the exception propagates.
        """
        tempdir = tempfile.TemporaryDirectory()
        try:
            with tempfile.NamedTemporaryFile(suffix=".tar") as tmp_tarf:
                logger.debug(f"Unzipping {filepath} to {tmp_tarf.name}...")
                util.gunzip_file(filepath, tmp_tarf.name, blocksize=blocksize, threads=threads)
                util.untar(tmp_tarf.name, [MATCHER_FILENAME, META_FILENAME], tempdir.name)
            tmp_model_path = os.path.join(tempdir.name, MATCHER_FILENAME)
            tmp_meta_path = os.path.join(tempdir.name, META_FILENAME)
            meta = MatcherMetadata.load(tmp_meta_path)
            matcher = cls(meta)
            matcher.tempdir = tempdir
            matcher.load_model(tmp_model_path, **kwargs)
        except BaseException:
            # Don't leave the extracted files behind when loading fails.
            tempdir.cleanup()
            raise
        return matcher

    def __repr__(self):
        return f"{self.__class__.__name__}({self.meta})"

    def filter_matches(
        self,
        matches: List[query.Match],
        abs_thresh: Optional[float] = 0.25,
        ratio_thresh: Optional[float] = None,
        cluster_dist: float = 4.0,
        cluster_size: int = 2,
        match_orientation: bool = True,
        ordered: bool = False,
        cluster_filter: Optional[Callable[[query.Cluster], bool]] = None,
    ) -> List[query.Cluster]:
        """Filter and cluster ``matches`` via ``query.filter_matches``.

        ``cluster_dist`` is multiplied by ``sr / hop_length`` and truncated to
        an int before being forwarded (presumably seconds -> frames, assuming
        sr is in Hz and hop_length in samples — confirm).  All other arguments
        are forwarded unchanged.
        """
        cluster_sample_dist = int(cluster_dist * self.meta.sr / self.meta.hop_length)
        return query.filter_matches(
            matches,
            abs_thresh=abs_thresh,
            ratio_thresh=ratio_thresh,
            cluster_dist=cluster_sample_dist,
            cluster_size=cluster_size,
            match_orientation=match_orientation,
            ordered=ordered,
            cluster_filter=cluster_filter,
        )

    def find_samples(
        self,
        fp: Fingerprint,
        k: int = 1,
        abs_thresh: Optional[float] = 0.25,
        ratio_thresh: Optional[float] = None,
        cluster_dist: float = 20.0,
        cluster_size: int = 2,
        match_orientation: bool = True,
        ordered: bool = False,
        cluster_filter: Optional[Callable[[query.Cluster], bool]] = None,
    ) -> query.Result:
        """Query ``fp`` against the matcher and return its clustered matches.

        Runs :meth:`nearest_neighbors` with ``k``, passes the matches through
        :meth:`filter_matches` with the remaining arguments, and wraps the
        resulting clusters in a ``query.Result``.
        """
        matches = self.nearest_neighbors(fp, k)
        clusters = self.filter_matches(
            matches,
            abs_thresh=abs_thresh,
            ratio_thresh=ratio_thresh,
            cluster_dist=cluster_dist,
            cluster_size=cluster_size,
            match_orientation=match_orientation,
            ordered=ordered,
            cluster_filter=cluster_filter,
        )
        return query.Result(fp, clusters)