Source code for sample_id.ann.annoy

import logging
from typing import Any, Sequence

import annoy

from sample_id.fingerprint import Fingerprint, Keypoint

from . import Matcher, MatcherMetadata
from .query import Match, Neighbor

logger = logging.getLogger(__name__)


[docs]class AnnoyMatcher(Matcher): """Nearest neighbor matcher using annoy.""" def __init__(self, metadata: MatcherMetadata): metadata.metric = vars(metadata).get("metric", "angular") metadata.n_features = vars(metadata).get("n_features", 128) metadata.n_trees = vars(metadata).get("n_trees", 40) metadata.n_jobs = vars(metadata).get("n_jobs", -1) super().__init__(metadata) self.on_disk = None self.built = False
[docs] def init_model(self) -> Any: logger.info(f"Initializing Annoy Index with {self.meta}...") return annoy.AnnoyIndex(self.meta.n_features, metric=self.meta.metric)
[docs] def save_model(self, filepath: str, prefault: bool = False) -> str: if not self.built: self.build() else: logger.info(f"Annoy Index already built.") if self.on_disk: logger.info(f"Annoy index already built_on_disk at {self.on_disk}.") return self.on_disk logger.info(f"Saving matcher model to {filepath}...") self.model.save(filepath, prefault=prefault) return filepath
[docs] def load_model(self, filepath: str, prefault: bool = False) -> None: logger.info(f"Loading Annoy Index from {filepath}...") self.model.load(filepath, prefault=prefault) self.built = True return self.model
[docs] def build(self) -> None: logger.info(f"Building Annoy Index with {self.meta}...") self.model.build(self.meta.n_trees, self.meta.n_jobs) self.built = True
[docs] def on_disk_build(self, filename: str) -> None: logger.info(f"Building Annoy Index straight to disk: {filename}...") self.model.on_disk_build(filename) self.on_disk = filename
[docs] def nearest_neighbors(self, fp: Fingerprint, k: int = 1) -> Sequence[Match]: matches = [] for kp, desc in zip(fp.keypoints, fp.descriptors): indices, distances = self.model.get_nns_by_vector(desc, k, include_distances=True) kp_neighbors = [Neighbor(index, distance, self.meta) for index, distance in zip(indices, distances)] matches.append(Match(Keypoint(kp), kp_neighbors)) return matches