47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
|
from typing import List
|
||
|
import random
|
||
|
import numpy as np
|
||
|
from vision.model import Image
|
||
|
import vision.maths.precision_recall as pr
|
||
|
import logging
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class QueryResult:
|
||
|
def __init__(self,
|
||
|
sorted_images: List[Image],
|
||
|
query_image: Image,
|
||
|
precision_recall: pr.PrecisionRecall):
|
||
|
self.sorted_images = sorted_images
|
||
|
self.query_image = query_image
|
||
|
self.precision_recall = precision_recall
|
||
|
|
||
|
|
||
|
def run_query(images: List[Image], distance_measure=None, query_index=None):
|
||
|
logger.info(f'running query on {len(images)} images, query index {query_index}')
|
||
|
|
||
|
if query_index is not None:
|
||
|
query_image = images[query_index]
|
||
|
else:
|
||
|
query_image = random.choice(images)
|
||
|
|
||
|
if any(i for i in images if i.descriptor is None):
|
||
|
raise ValueError('descriptors required for all images')
|
||
|
|
||
|
for image in images:
|
||
|
if distance_measure is None:
|
||
|
image.distance = np.linalg.norm(image.descriptor-query_image.descriptor)
|
||
|
else:
|
||
|
image.distance = distance_measure(image.descriptor - query_image.descriptor)
|
||
|
|
||
|
images = [i for i in images if not (i.category == query_image.category and i.name == query_image.name)]
|
||
|
|
||
|
query_pr = pr.get_pr(images, query=query_image)
|
||
|
|
||
|
results = QueryResult(sorted_images=images,
|
||
|
query_image=query_image,
|
||
|
precision_recall=query_pr)
|
||
|
logger.info(f'query finished AP: {results.precision_recall.ap}')
|
||
|
return results
|