# computer-vision/vision/visualsearch/__init__.py
# 45 lines, 1.4 KiB, Python
# Last modified: 2021-01-29 22:05:37 +00:00
from dataclasses import dataclass
from typing import List
import random
import numpy as np
from vision.model import Image
import vision.maths.precision_recall as pr
import logging
logger = logging.getLogger(__name__)
2021-01-29 22:05:37 +00:00
@dataclass
class QueryResult:
    """Bundle everything produced by one visual-search query.

    Fields are declared in constructor order, so both keyword and positional
    construction ``QueryResult(images, query, precision_recall)`` work.
    """

    # Candidate images returned for the query; as built by run_query, the
    # query image itself (matched on category + name) is not included.
    sorted_images: List[Image]
    # The image whose descriptor every candidate was compared against.
    query_image: Image
    # Precision/recall summary for this query (run_query reads its ``ap``).
    precision_recall: pr.PrecisionRecall
def run_query(images: List[Image], distance_measure=None, query_index=None) -> "QueryResult":
    """Rank ``images`` by descriptor distance to a query image.

    Args:
        images: Images to search. Every image must have a non-None
            ``descriptor``. Each image's ``distance`` attribute is overwritten
            in place with its distance to the query image.
        distance_measure: Optional callable applied to the *difference*
            ``image.descriptor - query_image.descriptor``. It must return a
            value usable as a sort key. Defaults to ``np.linalg.norm``
            (Euclidean norm).
        query_index: Index into ``images`` of the query image. When None, the
            query image is picked with ``random.choice``.

    Returns:
        A ``QueryResult`` whose ``sorted_images`` contains every image except
        the query (matched on ``category`` and ``name``). These images are
        ordered by ascending distance, and ties keep their input order. The
        result also carries the precision/recall from ``pr.get_pr``.

    Raises:
        ValueError: If any image has no descriptor.
        IndexError: If ``images`` is empty or ``query_index`` is out of range.
    """
    logger.info('running query on %d images, query index %s', len(images), query_index)
    if query_index is not None:
        query_image = images[query_index]
    else:
        query_image = random.choice(images)
    # Test each descriptor explicitly; the old ``any(i for i in images if ...)``
    # relied on the truthiness of the Image objects themselves.
    if any(image.descriptor is None for image in images):
        raise ValueError('descriptors required for all images')

    # Resolve the distance function once instead of branching per image.
    measure = np.linalg.norm if distance_measure is None else distance_measure
    query_descriptor = query_image.descriptor
    for image in images:
        image.distance = measure(image.descriptor - query_descriptor)

    # Drop the query itself and order the remaining candidates nearest-first
    # so ``sorted_images`` is actually sorted.
    # NOTE(review): if pr.get_pr already sorts by distance this is redundant
    # but harmless (stable sort on the same key).
    candidates = sorted(
        (
            image
            for image in images
            if not (image.category == query_image.category and image.name == query_image.name)
        ),
        key=lambda image: image.distance,
    )
    query_pr = pr.get_pr(candidates, query=query_image)
    results = QueryResult(sorted_images=candidates,
                          query_image=query_image,
                          precision_recall=query_pr)
    logger.info('query finished AP: %s', results.precision_recall.ap)
    return results