added image model and loading/saving methods

This commit is contained in:
aj 2019-12-08 12:02:15 +00:00
parent 284519c51d
commit e19f13bbe5
9 changed files with 288 additions and 3 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
scratch.py
descriptors
node_modules/

File diff suppressed because one or more lines are too long

44
vision.ipynb Normal file
View File

@ -0,0 +1,44 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### vision"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import vision.io\n",
"from vision.descriptor.avg_RGB import extract_average_rgb\n",
"\n",
"x = vision.io.load_msrc('msrc/Images')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

4
vision/__init__.py Normal file
View File

@ -0,0 +1,4 @@
import logging
logger = logging.getLogger(__name__)
logger.setLevel('DEBUG')

View File

View File

@ -0,0 +1,7 @@
from vision.model import Image
from typing import List
def extract_average_rgb(images: List[Image]):
for image in images:
image.descriptor = image.mean(axis=(0, 1))

95
vision/io/__init__.py Normal file
View File

@ -0,0 +1,95 @@
import glob
import os
import pickle
import logging
from typing import List
import cv2
from vision.model import Image
logger = logging.getLogger(__name__)
def load_path(path: str) -> List[str]:
if not os.path.exists(path):
logger.error(f'folder {path} does not exist')
raise FileNotFoundError('path does not exist')
files = []
for extension_set in [glob.glob(f'{path}/*.%s' % ext) for ext in ["jpg", "bmp", "png"]]:
if len(extension_set) > 0:
files += extension_set
return files
def load_set(path: str) -> List[Image]:
logger.info(f'loading set from {path}')
files = load_path(path)
images = [Image(cv2.imread(i))[:, :, ::-1] for i in files]
return images
def load_msrc(path: str, descriptor_path=None) -> List[Image]:
logger.info(f'loading msrc from {path}, descriptor path {descriptor_path}')
files = load_path(path)
images = []
for image in files:
file_name = image.split('/')[-1]
file_name_split = file_name.split('_')
category = int(file_name_split[0])
name = int(file_name_split[1])
images.append(Image(cv2.imread(image)[:, :, ::-1],
category=category,
name=name))
if descriptor_path is not None:
load_descriptors(descriptor_path, images)
return images
def save_descriptors(images: List[Image], path: str = 'descriptors/default'):
logger.info(f'saving {len(images)} descriptors to {path}')
if not os.path.exists(path):
os.makedirs(path)
counter = 0
for image in images:
if image.name is not None and image.category is not None:
name = f'{image.category}_{image.name}'
else:
name = f'{counter}'
counter += 1
with open(os.path.join(path, name), 'wb') as file:
pickle.dump(image.descriptor, file)
def load_descriptors(path: str = 'descriptors/default', images: List[Image] = None):
logger.info(f'loading descriptors from {path}, {len(images)} images')
if not os.path.exists(path):
logger.error(f'folder {path} does not exist')
raise FileNotFoundError('folder does not exist')
descriptors = []
for file_name in [i for i in os.listdir(path) if os.path.isfile(os.path.join(path, i))]:
with open(os.path.join(path, file_name), 'rb') as file:
descriptor = pickle.load(file)
if images is not None:
desc_cat = int(file_name.split('_')[0])
desc_name = int(file_name.split('_')[1])
image = next((i for i in images if i.category == desc_cat and i.name == desc_name), None)
if image is not None:
image.descriptor = descriptor
else:
logger.error(f'no corresponding image found to hold descriptor {file_name}')
else:
descriptors.append(descriptor)
return descriptors

83
vision/model/__init__.py Normal file
View File

@ -0,0 +1,83 @@
import numpy as np
class Image:
def __init__(self,
pixels: np.array,
category=None,
name=None,
descriptor=None):
self.pixels = pixels
self.category = category
self.name = name
self.descriptor = descriptor
@property
def shape(self):
return self.pixels.shape
@property
def height(self):
return self.pixels.shape[0]
@property
def width(self):
return self.pixels.shape[1]
@property
def T(self):
return self.pixels.T
@property
def flat(self):
return self.pixels.flat
@property
def max(self):
return self.pixels.max()
@property
def min(self):
return self.pixels.min()
def round(self, decimals):
return self.pixels.round(decimals)
def sum(self, axis=None):
if axis is not None:
return self.pixels.sum(axis=axis)
else:
return self.pixels.sum()
def mean(self, axis=None):
if axis is not None:
return self.pixels.mean(axis=axis)
else:
return self.pixels.mean()
def __add__(self, other):
return self.pixels + other
def __sub__(self, other):
return self.pixels - other
def __mul__(self, other):
return self.pixels * other
def __truediv__(self, other):
return self.pixels / other
def __floordiv__(self, other):
return self.pixels // other
def __mod__(self, other):
return self.pixels % other
def __pow__(self, other):
return pow(self.pixels, other)
def __eq__(self, other):
return isinstance(other, Image) and self.pixels == other.pixels
def __repr__(self):
return f'Image: {self.shape} ({self.descriptor})'

7
vision/util/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from vision.model import Image
from typing import List
import numpy as np
def get_category_histogram(images: List[Image], bins: int):
return np.histogram([i.category for i in images], bins=bins)