Image Database

基于VGG16的图片特征数据库

项目起源于想在本地3TB图库中快速搜索色图,先是用Python(keras、numpy、h5py、matplotlib)实现了一个试验性的版本,觉得Python太慢了,决定用C++重构(还是C++写得爽),按照惯例,支持Windows、Linux、macOS。

项目使用VGG16模型,模型使用MMdnn下载。为了达到只提取特征而不进行识别的目的,对模型进行了修改:去掉顶层的3个fully connected层,并将softmax换成global max pool。

项目主要使用OpenCV来完成计算,是我目前依赖最多的一个项目,也是第一个使用VS Code开发的项目,也是第一个在macOS上开发的项目(Apple Clang不行)。依赖:

  • eigen3 用于矩阵运算
  • opencv4[dnn,eigen,jpeg,png,quirc,tiff,webp] 读取图片文件,计算vgg16
  • ffmpeg[avcodec,avformat,swscale] 读取opencv不支持的文件
  • libzip 读取压缩文件
  • boost-context 使用call/cc
  • (Linux/Mac) tbb 多线程查询

后续改进:打算直接去掉OpenCV,使用FFmpeg读取图片,使用caffe计算VGG16,再后面就直接去掉caffe,直接写cuda或者是boost对cuda的封装。

用法:ImageDatabase 操作 选项

操作:build,构建数据库;query,查询数据库

选项:

  • -d 数据库路径
  • -i 输入路径
  • --loglevel (可选,默认Info)日志级别 [None, Error, Warn, Log, Info, Debug]
  • --logfile (可选)日志文件

main.cpp

#include <unordered_map>
#include <filesystem>
#include <utility>
#include <fstream>
#include <exception>
#include <optional>
#include <thread>
#include <string>

#include <zip.h>
#include <boost/context/continuation.hpp>

extern "C" {
#include <libavformat/avformat.h>
#include <libavcodec/avcodec.h>
#include <libavutil/avutil.h>
#include <libavutil/imgutils.h>
#include <libavutil/pixdesc.h>
#include <libswscale/swscale.h>
}

#include "Arguments.h"
#include "Convert.h"
#include "ImageDatabase.h"
#include "StdIO.h"
#include "Thread.h"
#include "Log.h"
#include "Algorithm.h"

// Declares the `Operator` command line option with the values `build` and `query`.
// NOTE(review): the ArgumentOption macro is defined outside this file; main() uses
// Operator::build / Operator::query, ToOperator() and OperatorDesc(), which are
// presumably generated by this invocation - confirm against Arguments.h.
ArgumentOption(Operator, build, query)

// When defined, main() wraps its body in a try/catch that prints the exception
// message and the argument description to the error console.
#define CatchEx

// Logger shared by the whole program. main() starts a thread that reads the
// queued messages from Log.Chan and writes them to the console / log file.
static Logger<std::wstring> Log;

/// Entry point of the ImageDatabase command line tool.
///
/// Usage: `ImageDatabase <build|query> -d <database> -i <input> [--loglevel <level>] [--logfile <file>]`
///
/// - `build` walks the input directory recursively. Every supported image,
///   every entry of a `.zip` archive and every decoded frame of a video/GIF
///   is handed to the consumer loop through a boost.context continuation;
///   the consumer computes the MD5 and the VGG16 feature vector and finally
///   saves all records to the database file.
/// - `query` loads the database, computes the feature vector of the input
///   image, sorts the records by dot product with it and logs every record
///   whose score is at least 0.8.
///
/// All log messages go through `Log`; a dedicated thread drains `Log.Chan`
/// and terminates when it receives a message with level `LogLevel::None`.
///
/// @param argc number of command line arguments
/// @param argv command line arguments
/// @return 0 (implicit); argument errors are reported on the error console
///         together with the usage text when `CatchEx` is defined.
int main(int argc, char* argv[])
{
#define ArgumentsFunc(arg) [&](decltype(arg)::ConvertFuncParamType value) -> decltype(arg)::ConvertResult
	
	ArgumentsParse::Argument<Operator> opArg
	{
		"",
		"operator " + OperatorDesc(),
		ArgumentsFunc(opArg) { return { *ToOperator(std::string(value)), {} }; }
	};
	ArgumentsParse::Argument<std::filesystem::path> dbArg
	{
		"-d",
		"database path"
	};
	ArgumentsParse::Argument<std::filesystem::path> pathArg
	{
		"-i",
		"input path"
	};
	ArgumentsParse::Argument<LogLevel> logLevelArg
	{
		"--loglevel",
		"log level " + LogLevelDesc(ToString(LogLevel::Info)),
		LogLevel::Info,
		ArgumentsFunc(logLevelArg) { return { *ToLogLevel(std::string(value)), {} };	}
	};
	ArgumentsParse::Argument<std::filesystem::path> logFileArg
	{
		"--logfile",
		// The original reused the --loglevel description here (copy-paste slip).
		"log file path",
		""
	};
#undef ArgumentsFunc
	
	ArgumentsParse::Arguments args;
	args.Add(opArg);
	args.Add(dbArg);
	args.Add(pathArg);
	args.Add(logLevelArg);
	args.Add(logFileArg);

#ifdef CatchEx
	try
#endif
	{
		std::thread logThread;
		args.Parse(argc, argv);

		// Logging thread: installs the OpenCV / FFmpeg log hooks, then drains Log.Chan
		// until a LogLevel::None message arrives.
		logThread = std::thread([](const LogLevel& level, std::filesystem::path logFile)
		{
			// Route OpenCV errors into our logger instead of stderr.
			cv::redirectError([](const int status, const char* funcName, const char* errMsg,
			                     const char* fileName, const int line, void*)
			{
				auto msg =
					L"[" +
						*Convert::ToWString(fileName) + L":" +
						*Convert::ToWString(funcName) + L":" +
						*Convert::ToWString(line) +
					L"] " +
					L"status: " + *Convert::ToWString(status) + L": " +
					*Convert::ToWString(errMsg);
				// Cut the message at its last line break (reuse the position already found).
				if (const auto pos = msg.find_last_of(L'\n'); pos != std::wstring::npos) msg.erase(pos);
				Log.Write<LogLevel::Error>(msg);
				return 0;
			});

			// Map our log level onto the closest FFmpeg level.
			av_log_set_level([&]()
			{
				switch (level)
				{
				case LogLevel::None:
					return AV_LOG_QUIET;
				case LogLevel::Error:
					return AV_LOG_ERROR;
				case LogLevel::Warn:
					return AV_LOG_WARNING;
				case LogLevel::Debug:
					return AV_LOG_VERBOSE;
				default:
					return AV_LOG_INFO;
				}
			}());
			av_log_set_callback([](void* avc, const int ffLevel, const char* fmt, const va_list vl)
			{
				// NOTE(review): the static buffer is shared by every caller of this callback;
				// if FFmpeg logs from several threads at once this is a race - confirm.
				static char buf[4096]{ 0 };
				int ret = 1;
				av_log_format_line2(avc, ffLevel, fmt, vl, buf, 4096, &ret);
				auto data = *Convert::ToWString(static_cast<char*>(buf));
				if (const auto pos = data.find_last_of(L'\n'); pos != std::wstring::npos) data.erase(pos);
				// 16 / 24 / 32 are the numeric values of AV_LOG_ERROR / AV_LOG_WARNING / AV_LOG_INFO.
				if (ffLevel <= 16) Log.Write<LogLevel::Error>(data);
				else if (ffLevel <= 24) Log.Write<LogLevel::Warn>(data);
				else if (ffLevel <= 32) Log.Write<LogLevel::Log>(data);
				else Log.Write<LogLevel::Debug>(data);
			});

			Log.level = level;
			// The stream buffer must outlive the stream that uses it, so it is declared
			// before `fs` (the original freed it at the end of the `if` block below
			// while the stream kept writing through it).
			const auto fileBuf = std::make_unique<char[]>(4096);
			std::ofstream fs;
			if (!logFile.empty())
			{
				fs.open(logFile);
				fs.rdbuf()->pubsetbuf(fileBuf.get(), 4096);
				if (!fs)
				{
					Log.Write<LogLevel::Error>(L"log file: " + logFile.wstring() + L": open fail");
					logFile = "";
				}
			}

			while (true)
			{
				const auto [level, msg] = Log.Chan.Read();
				std::string out;
				String::StringCombine(out, "[", ToString(level), "] ");
				auto utf8 = false;
				try
				{
					// Prefer the native narrow encoding for the console ...
					String::StringCombine(out, std::filesystem::path(msg).string());
				}
				catch (...)
				{
					// ... and fall back to UTF-8 when the message cannot be converted.
					utf8 = true;
					String::StringCombine(out, std::filesystem::path(msg).u8string());
				}
				Console::WriteLine(out);
				if (!logFile.empty())
				{
					// The log file is always written as UTF-8; flush once per message.
					fs << (utf8 ? out : std::filesystem::path(out).u8string()) << '\n';
					fs.flush();
				}

				// A LogLevel::None message is the shutdown signal.
				if (level == LogLevel::None)
				{
					fs.close();
					break;
				}
			}
		}, args.Value(logLevelArg), args.Value(logFileArg));

		try
		{
			// Dispatch table: operator -> implementation.
			std::unordered_map<Operator, std::function<void()>>
			{
				{Operator::build, [&]()
				{
					const auto dbPath = args.Value(dbArg);
					const auto buildPath = args.Value(pathArg);
					ImageDatabase db(dbPath);

					// Slot shared between the producer continuation and the consumer loop.
					// An image with an empty Path marks the end of the enumeration.
					ImageDatabase::Image img;

					// Debug output: names of the codecs known to FFmpeg.
					Log.Write<LogLevel::Debug>([]()
					{
						static void* iterateData = nullptr;
						const auto encoderGetNextCodecName = []()->std::string
						{
							auto currentCodec = av_codec_iterate(&iterateData);
							while (currentCodec != nullptr)
							{
								// NOTE(review): this skips only video *encoders*; decoders of every media
								// type and non-video encoders are listed. Verify that this is the intent.
								if (!av_codec_is_decoder(currentCodec) && currentCodec->type == AVMEDIA_TYPE_VIDEO)
								{
									currentCodec = av_codec_iterate(&iterateData);
									continue;
								}
								return currentCodec->name;
							}
							return "";
						};
						std::wstring buf;
						auto dt = encoderGetNextCodecName();
						while (!dt.empty())
						{
							buf.append(*Convert::ToWString(dt.c_str()) + L" ");
							dt = encoderGetNextCodecName();
						}
						return buf;
					}());

					// Producer: stores the next image in `img` and switches back to the consumer
					// with sink.resume() after each one.
					boost::context::continuation source = boost::context::callcc(
						[&](boost::context::continuation&& sink)
						{
							try
							{
								for (const auto& file : std::filesystem::recursive_directory_iterator(buildPath))
								{
									if (!file.is_regular_file()) continue;

									const auto& filePath = file.path();
									Log.Write<LogLevel::Log>(L"Scan file: " + filePath.wstring());
									if (filePath.extension() == ".zip")
									{
										// The archive is read into memory; libzip works on that buffer
										// (freep = 0: we keep ownership of the bytes).
										const auto zipFile = OpenCvUtility::ReadToEnd(filePath);
										if (!zipFile)
										{
											Log.Write<LogLevel::Error>(L"load file: " + filePath.wstring() + L": open fail");
											continue;
										}

										zip_error_t error{};
										zip_source_t* src = zip_source_buffer_create(zipFile->data(), zipFile->length(), 0, &error);
										if (src == nullptr)
										{
											Log.Write<LogLevel::Error>(L"load file: " + filePath.wstring() + L": " + *Convert::ToWString(error.zip_err));
											continue;
										}

										zip_t* za = zip_open_from_source(src, ZIP_RDONLY, &error);
										if (za == nullptr)
										{
											// On failure the archive did not take ownership of the source.
											zip_source_free(src);
											Log.Write<LogLevel::Error>(L"load file: " + filePath.wstring() + L": " + *Convert::ToWString(error.zip_err));
											continue;
										}

										const auto entries = zip_get_num_entries(za, 0);
										if (entries < 0)
										{
											Log.Write<LogLevel::Error>(L"load file: " + filePath.wstring() + L": " + *Convert::ToWString(zip_get_error(za)->zip_err));
											zip_close(za);
											continue;
										}

										for (zip_int64_t i = 0; i < entries; ++i)
										{
											struct zip_stat zs;
											if (zip_stat_index(za, i, 0, &zs) != 0 || zs.name == nullptr) continue;

											// Names ending in '/' are directories; an empty name cannot be indexed.
											const std::string_view filename(zs.name);
											if (filename.empty() || filename.back() == '/') continue;

											auto* const zf = zip_fopen_index(za, i, 0);
											if (zf == nullptr)
											{
												Log.Write<LogLevel::Error>(L"load file: zip_fopen_index: fail");
												continue;
											}
											const auto buf = std::make_unique<char[]>(zs.size);
											const auto readLen = zip_fread(zf, buf.get(), zs.size);
											zip_fclose(zf);
											if (readLen < 0)
											{
												Log.Write<LogLevel::Error>(L"load file: zip_fread: fail");
												continue;
											}
											// Only the bytes actually read belong to the entry.
											img = ImageDatabase::Image(filePath / filename, std::string(buf.get(), static_cast<std::size_t>(readLen)));

											sink = sink.resume();
										}
										zip_close(za);
									}
									else if (const auto ext = filePath.extension();
										ext == ".gif" ||
										ext == ".mp4" ||
										ext == ".mkv" ||
										ext == ".flv" ||
										ext == ".avi" ||
										ext == ".mpg" ||
										ext == ".vob" ||
										ext == ".mov" ||
										ext == ".wmv" ||
										ext == ".swf" ||
										ext == ".3gp" ||
										ext == ".mts" ||
										ext == ".rm" ||
										ext == ".ts" ||
										ext == ".m2ts" ||
										ext == ".rmvb" ||
										ext == ".mpeg" ||
										ext == ".webm")
									{
										// Animated / video files: decode with FFmpeg and emit every frame as a BMP image.
										std::string gifPath = filePath.u8string();

										AVFormatContext* fmtCtx = nullptr;
										AVCodec* codec = nullptr;
										AVCodecContext* codecCtx = nullptr;
										AVFrame* frame = nullptr;
										AVFrame* decFrame = nullptr;
										AVPacket* pkt = nullptr;
										SwsContext* swsCtx = nullptr;

										try
										{
											int ret;
											ret = avformat_open_input(&fmtCtx, gifPath.c_str(), nullptr, nullptr);
											if (ret < 0) throw std::runtime_error("avforamt_open_input fail: " + *Convert::ToString(ret));

											ret = avformat_find_stream_info(fmtCtx, nullptr);
											if (ret < 0) throw std::runtime_error("avformat_find_stream_info fail: " + *Convert::ToString(ret));

											ret = av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, &codec, 0);
											if (ret < 0) throw std::runtime_error("av_find_best_stream fail: " + *Convert::ToString(ret));

											const int streamId = ret;
											auto codecParams = fmtCtx->streams[streamId]->codecpar;

											codecCtx = avcodec_alloc_context3(codec);
											if (!codecCtx) throw std::runtime_error("avcodec_alloc_context3 fail");

											ret = avcodec_parameters_to_context(codecCtx, codecParams);
											if (ret < 0) throw std::runtime_error("avcodec_parameters_to_context fail: " + *Convert::ToString(ret));
											ret = avcodec_open2(codecCtx, codec, nullptr);
											if (ret < 0) throw std::runtime_error("avcodec_open2 fail: " + *Convert::ToString(ret));

											// Frames are converted to BGR24 at their native size.
											const int dstWidth = codecParams->width;
											const int dstHeight = codecParams->height;
											const AVPixelFormat dstPixFmt = AV_PIX_FMT_BGR24;
											if (codecCtx->pix_fmt != AV_PIX_FMT_NONE)
											{
												swsCtx = sws_getCachedContext(
													nullptr, codecParams->width, codecParams->height, codecCtx->pix_fmt,
													dstWidth, dstHeight, dstPixFmt, 0, nullptr, nullptr, nullptr);
												if (!swsCtx) throw std::runtime_error("sws_getCachedContext fail");
											}

											// `frame` only describes the layout of frameBuf, which owns the pixels.
											frame = av_frame_alloc();
											if (!frame) throw std::runtime_error("av_frame_alloc fail");
											std::string frameBuf(av_image_get_buffer_size(dstPixFmt, dstWidth, dstHeight, dstWidth), 0);
											av_image_fill_arrays(
												frame->data, frame->linesize,
												reinterpret_cast<uint8_t*>(frameBuf.data()),
												dstPixFmt, dstWidth, dstHeight, dstWidth);

											decFrame = av_frame_alloc();
											if (!decFrame) throw std::runtime_error("av_frame_alloc fail");
											pkt = av_packet_alloc();
											if (!pkt) throw std::runtime_error("av_packet_alloc fail");

											bool eof = false;
											do
											{
												if (!eof)
												{
													ret = av_read_frame(fmtCtx, pkt);
													if (ret < 0 && ret != AVERROR_EOF) throw std::runtime_error("av_read_frame fail: " + *Convert::ToString(ret));
													
													// Ignore packets of other streams (audio, subtitles, ...).
													if (ret == 0 && pkt->stream_index != streamId)
													{
														av_packet_unref(pkt);
														continue;
													}
													eof = (ret == AVERROR_EOF);
												}
												if (eof)
												{
													// NOTE(review): no flush packet is sent at end of file, so frames still
													// buffered inside the decoder are not retrieved - confirm this is intended.
													av_packet_unref(pkt);
													ret = 0;
												}
												else
												{
													ret = avcodec_send_packet(codecCtx, pkt);
													if (ret < 0) throw std::runtime_error("avcodec_send_packet: error sending a packet for decoding: " + *Convert::ToString(ret));
												}
												while (ret >= 0)
												{
													ret = avcodec_receive_frame(codecCtx, decFrame);
													// The decoder needs more input or is finished.
													if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break;
													if (ret < 0) throw std::runtime_error("avcodec_receive_frame: error during decoding: " + *Convert::ToString(ret));

													// The pixel format may only be known after the first decoded frame.
													if (swsCtx == nullptr)
													{
														swsCtx = sws_getCachedContext(
															nullptr, codecParams->width, codecParams->height, codecCtx->pix_fmt,
															dstWidth, dstHeight, dstPixFmt, 0, nullptr, nullptr, nullptr);
														if (!swsCtx) throw std::runtime_error("sws_getCachedContext fail");
													}
							
													sws_scale(swsCtx, decFrame->data, decFrame->linesize, 0, decFrame->height, frame->data, frame->linesize);

													// Wrap the converted pixels (no copy) and encode them as BMP so the
													// frame can travel through the same path as a regular image file.
													cv::Mat image(dstHeight, dstWidth, CV_8UC3, (frameBuf.data()), frame->linesize[0]);
													//cv::imshow("", image);
													//cv::waitKey(0);
													std::vector<uchar> rawData;
													imencode(".bmp", image, rawData);
													std::string dataStr;
													std::copy_n(rawData.begin(), rawData.size(), std::back_inserter(dataStr));
													rawData.clear();
													rawData.shrink_to_fit();
													const auto subPath = filePath / *Convert::ToString(codecCtx->frame_number);
													Log.Write<LogLevel::Info>(L"load file: " + subPath.wstring());
													img = ImageDatabase::Image(subPath, dataStr);
													sink = sink.resume();
												}
												av_packet_unref(pkt);
											} while (!eof);
										}
										catch (const std::exception& ex)
										{
											Log.Write<LogLevel::Error>(*Convert::ToWString(ex.what()) + L": " + filePath.wstring());
										}
										// Release everything on both the success and the error path
										// (the packet used to leak when decoding threw).
										if (pkt != nullptr) av_packet_free(&pkt);
										if (decFrame != nullptr) av_frame_free(&decFrame);
										if (frame != nullptr) av_frame_free(&frame);
										if (codecCtx != nullptr) avcodec_free_context(&codecCtx);
										if (fmtCtx != nullptr) avformat_close_input(&fmtCtx);
										if (swsCtx != nullptr) sws_freeContext(swsCtx);
									}
									else if (OpenCvUtility::IsImage(filePath.extension().string()))
									{
										Log.Write<LogLevel::Info>(L"load file: " + filePath.wstring());
										try
										{
											img = ImageDatabase::Image(filePath);
										}
										catch (const std::exception& ex)
										{
											// An unreadable file must not escape the continuation; skip it.
											Log.Write<LogLevel::Error>(*Convert::ToWString(ex.what()) + L": " + filePath.wstring());
											continue;
										}
										sink = sink.resume();
									}
									else
									{
										Log.Write<LogLevel::Warn>(L"Unsupported format: " + filePath.wstring());
									}
								}
							}
							catch (const std::exception& ex)
							{
								// Directory iteration can throw (e.g. permission denied). Stop scanning
								// but still hand control back so what was collected gets saved.
								Log.Write<LogLevel::Error>(L"scan: " + *Convert::ToWString(ex.what()));
							}
							// Empty image = end of enumeration.
							img = ImageDatabase::Image{};
							return std::move(sink);
						});

					// Consumer: process the image currently in `img`, then resume the producer.
					for (uint64_t i = 0; !img.Path.empty(); source = source.resume(), ++i)
					{
						Log.Write<LogLevel::Info>(*Convert::ToWString(Convert::ToString(i)->c_str()) + L": compute md5: " + img.Path.wstring());
						img.ComputeMd5();
						Log.Write<LogLevel::Info>(L"compute vgg16: " + img.Path.wstring());
						try
						{
							img.ComputeVgg16();
							Log.Write<LogLevel::Info>(L"file: " + img.Path.wstring() + L": vgg16 start with: " + *Convert::ToWString(img.Vgg16[0]));
						}
						catch (const cv::Exception& ex)
						{
							Log.Write<LogLevel::Error>(L"compute vgg16: " + img.Path.wstring() + L": " + *Convert::ToWString(ex.what()));
						}
						// Drop the raw file bytes before storing the record.
						img.FreeMemory();
						db.Images.push_back(img);
					}

					Log.Write<LogLevel::Info>(L"save database: " + dbPath.wstring());
					db.Save(dbPath);
					Log.Write<LogLevel::Info>(L"database size: " + *Convert::ToWString(Convert::ToString(db.Images.size())->c_str()));
				}},
				{Operator::query, [&]()
				{
					const auto dbPath = args.Value(dbArg);
					const auto input = args.Value(pathArg);
					ImageDatabase db(dbPath);
					Log.Write<LogLevel::Info>(L"load database: " + dbPath.wstring());
					db.Load(dbPath);
					Log.Write<LogLevel::Info>(L"database size: " + *Convert::ToWString(db.Images.size()));

					// This step loads the query image, not the database.
					Log.Write<LogLevel::Info>(L"load file: " + input.wstring());
					ImageDatabase::Image img(input);
					Log.Write<LogLevel::Info>(L"compute md5: " + input.wstring());
					img.ComputeMd5();
					Log.Write<LogLevel::Info>(L"compute vgg16: " + input.wstring());
					img.ComputeVgg16();
					Log.Write<LogLevel::Info>(L"file: " + img.Path.wstring() + L": vgg16 start with: " + *Convert::ToWString(img.Vgg16[0]));
					img.FreeMemory();
					
					// Order the records by descending dot product with the query vector.
					Log.Write<LogLevel::Info>(L"search start ...");
					Algorithm::Sort<true>(db.Images.begin(), db.Images.end(), [&](const ImageDatabase::Image& a, const ImageDatabase::Image& b)
					{
						return std::greater()(a.Vgg16.dot(img.Vgg16), b.Vgg16.dot(img.Vgg16));
					});

					Log.Write<LogLevel::Info>(L"search done.");
					// Records are sorted, so the first score below the threshold ends the report.
					for (const auto& i : db.Images)
					{
						if (const auto v = i.Vgg16.dot(img.Vgg16); v >= 0.8f)
						{
							Log.Write<LogLevel::Log>(L"found " + std::filesystem::path(*Convert::ToString(v)).wstring() + L": " + i.Path.wstring());
						}
						else
						{
							break;
						}
					}
				}}
			}[args.Value(opArg)]();
		}
		catch (const std::exception& ex)
		{
			Log.Write<LogLevel::Error>(*Convert::ToWString(ex.what()));
		}

		// LogLevel::None tells the logging thread to finish.
		Log.Write<LogLevel::None>(L"{ok}.");
		logThread.join();
	}
#ifdef CatchEx
	catch (const std::exception& e)
	{
		Console::Error::WriteLine(e.what());
		Console::Error::WriteLine(args.GetDesc());
	}
#endif
}

ImageDatabase.h

#pragma once

#include <cstdint>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <eigen3/Eigen/Eigen>
#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/dnn.hpp>

#include "Cryptography.h"
#include "OpenCvUtility.h"
#include "String.h"

/// In-memory list of image records with a flat binary on-disk format.
///
/// Each record on disk is:
///   - 16 ASCII hex digits: length of the UTF-8 path in bytes
///   - the UTF-8 path
///   - 32 bytes: MD5 text
///   - 512 floats: the feature vector, in the machine's native representation
class ImageDatabase
{
public:
	/// One image: its path, MD5 and VGG16 feature vector, plus (until
	/// FreeMemory() is called) the raw encoded file content.
	struct Image
	{
		/// Size of the fixed MD5 field of a serialized record.
		static constexpr std::size_t Md5Length = 32;
		/// Size in bytes of the serialized feature vector.
		static constexpr std::size_t FeatureBytes = 512 * sizeof(float);

		std::filesystem::path Path{};
		std::string Md5{};
		// Zero-initialised so that a record whose feature computation failed
		// is saved with a well-defined vector (Eigen leaves fixed-size
		// matrices uninitialised by default).
		Eigen::Matrix<float, 512, 1> Vgg16 = Eigen::Matrix<float, 512, 1>::Zero();

		explicit Image() = default;

		/// Deserialization: `data` points at the record body (path, MD5,
		/// feature vector) and `pathLen` is the length of the path in bytes.
		/// The caller must provide `pathLen + Md5Length + FeatureBytes` bytes.
		explicit Image(const uint8_t* data, const uint64_t pathLen)
		{
			const auto* bytes = reinterpret_cast<const char*>(data);
			Path = std::filesystem::u8path(std::string(bytes, pathLen));
			Md5 = std::string(bytes + pathLen, Md5Length);
			// The floats start at an arbitrary offset, so copy the bytes
			// instead of reading through a possibly misaligned float pointer.
			std::memcpy(Vgg16.data(), bytes + pathLen + Md5Length, FeatureBytes);
		}

		/// Image whose encoded content is already in memory (zip entry, video frame).
		explicit Image(const std::filesystem::path& path, const std::string& data): Path(path), data(data)
		{
			
		}
		
		/// Image read from disk.
		/// @throws std::bad_optional_access when the file cannot be read.
		explicit Image(std::filesystem::path path) : Path(std::move(path))
		{
			data = OpenCvUtility::ReadToEnd(Path).value();
		}

	private:
		// Raw encoded file content; released by FreeMemory().
		std::string data{};

	public:
		/// Runs the image through the VGG16 network and stores the
		/// L2-normalised 512-element output in Vgg16.
		/// The model files `vgg16-deploy.prototxt` and `vgg16.caffemodel` are
		/// read from the current working directory on first use.
		/// @throws std::bad_optional_access when a model file cannot be read.
		void ComputeVgg16()
		{
			static std::string ProtoTxt = OpenCvUtility::ReadToEnd(R"(vgg16-deploy.prototxt)").value();
			static std::string CaffeModel = OpenCvUtility::ReadToEnd(R"(vgg16.caffemodel)").value();

			// The network is built once and shared by all images.
			static auto vgg16 = []()
			{
				auto vgg16 = cv::dnn::readNetFromCaffe(
					ProtoTxt.data(), ProtoTxt.length(),
					CaffeModel.data(), CaffeModel.length());
				vgg16.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
				vgg16.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);
				return vgg16;
			}();
			vgg16.setInput(OpenCvUtility::ReadBlob(data.data(), data.length()));
			auto out = vgg16.forward();
			// Normalise so that the dot product of two vectors is their cosine similarity.
			out = out / norm(out);
			for (int i = 0; i < 512; ++i)
			{
				Vgg16(i, 0) = out.at<float>(0, i, 0);
			}
			//std::cout << *Convert::ToString(Vgg16(0, 0)) << std::endl;
		}

		/// Stores the MD5 of the raw file content in Md5; on any failure Md5
		/// becomes 32 NUL characters.
		void ComputeMd5()
		{
			try
			{
				Cryptography::Md5 md5;
				md5.Append(reinterpret_cast<const std::uint8_t*>(data.data()), data.length());
				Md5 = md5.Digest();
			}
			catch (...)
			{
				Md5 = std::string(Md5Length, 0);
			}
		}

		/// Appends this record to `fs` in the on-disk format described on ImageDatabase.
		void Serialization(std::ofstream& fs) const
		{
			const auto p = Path.u8string();
			const uint64_t pl = p.length();
			auto pls = *Convert::ToString(pl, 16);
			String::PadLeft(pls, 16, '0');
			fs.write(pls.data(), pls.length());
			fs.write(p.data(), p.length());
			// The MD5 field is always exactly Md5Length bytes; pad with NUL
			// (or cut) so a missing or odd-sized Md5 never reads out of bounds.
			std::string md5Field(Md5);
			md5Field.resize(Md5Length, '\0');
			fs.write(md5Field.data(), Md5Length);
			fs.write(reinterpret_cast<const char*>(Vgg16.data()), FeatureBytes);
		}

		/// Releases the raw file content once MD5 and feature vector are computed.
		void FreeMemory()
		{
			data = {};
		}
	};

	ImageDatabase(std::filesystem::path path) : path(std::move(path))
	{

	}

	/// Appends every record stored in `path` to Images.
	/// @throws std::runtime_error when the file cannot be opened or a record is truncated.
	void Load(const std::filesystem::path& path)
	{
		// Declared before the stream so the buffer outlives it.
		const auto fsbuf = std::make_unique<char[]>(4096);
		std::ifstream fs(path, std::ios::in | std::ios::binary);
		if (!fs) throw std::runtime_error("load file: bad stream");
		fs.rdbuf()->pubsetbuf(fsbuf.get(), 4096);

		while (true)
		{
			char lenBuf[16 + 1]{0};
			fs.read(lenBuf, 16);
			// A clean end of file falls exactly on a record boundary.
			if (fs.gcount() == 0) break;
			if (fs.gcount() != 16) throw std::runtime_error("load file: truncated record");

			const auto pathLen = *Convert::FromString<uint64_t>(std::string(lenBuf, 16), 16);
			const auto len = pathLen + Image::Md5Length + Image::FeatureBytes;
			// Guard against wrap-around of the record length on a corrupted header.
			if (len < pathLen) throw std::runtime_error("load file: bad record length");

			std::string buf(len, 0);
			fs.read(&buf[0], len);
			if (static_cast<uint64_t>(fs.gcount()) != len) throw std::runtime_error("load file: truncated record");
			Images.emplace_back(reinterpret_cast<const uint8_t*>(buf.data()), pathLen);
		}

		fs.close();
	}

	/// Writes all records of Images to `path`, replacing an existing file.
	/// @throws std::runtime_error when the file cannot be opened or written.
	void Save(const std::filesystem::path& path)
	{
		std::filesystem::remove(path);
		
		// Declared before the stream so the buffer outlives it, even when
		// an exception unwinds this function.
		const auto fsbuf = std::make_unique<char[]>(4096);
		std::ofstream fs(path, std::ios::out | std::ios::binary);
		if (!fs) throw std::runtime_error("save file: bad stream");
		fs.rdbuf()->pubsetbuf(fsbuf.get(), 4096);

		for (const auto& img : Images)
		{
			img.Serialization(fs);
		}

		fs.close();
		// Report write errors (e.g. disk full) instead of silently leaving a damaged file.
		if (!fs) throw std::runtime_error("save file: write fail");
	}
private:
	std::filesystem::path path;

public:
	std::vector<Image> Images;
};

OpenCvUtility.h

#pragma once

#include <filesystem>
#include <optional>
#include <string>
#include <unordered_set>

#include <opencv2/core.hpp>

// Helpers for reading files and turning encoded images into network input blobs.
namespace OpenCvUtility
{
	// Returns true when `extension` (including the leading dot, e.g. ".png")
	// is in the built-in list of image file extensions.
	bool IsImage(const std::string& extension);

	// Reads the whole file in binary mode.
	// Returns std::nullopt when the file cannot be opened.
	std::optional<std::string> ReadToEnd(const std::filesystem::path& path);

	// Reads `file` and decodes it with cv::imdecode (flag 1).
	cv::Mat ReadImage(const std::filesystem::path& file);

	// Decodes `len` bytes of encoded image data starting at `data` with cv::imdecode (flag 1).
	cv::Mat ReadImage(const char* data, const uint64 len);

	// Converts a decoded image into a blob via cv::dnn::blobFromImage:
	// size 224x224, scale factor 1, mean (123.68, 116.779, 103.939), no R/B swap.
	cv::Mat ReadBlob(const cv::Mat& img);

	// Decodes the in-memory image (see ReadImage) and converts it into a blob.
	cv::Mat ReadBlob(const char* data, const uint64 len);

	// Reads and decodes `file` (see ReadImage) and converts it into a blob.
	cv::Mat ReadBlob(const std::filesystem::path& file);
}

OpenCvUtility.cpp

#include "OpenCvUtility.h"

#include <fstream>

#include <opencv2/imgcodecs.hpp>
#include <opencv2/dnn.hpp>

namespace OpenCvUtility
{
	/// Tells whether `extension` names an image format from the built-in list.
	///
	/// @param extension file extension including the leading dot (e.g. ".png").
	/// @return true when the extension, compared without regard to ASCII
	///         letter case, is in the list; false otherwise (also for "").
	bool IsImage(const std::string& extension)
	{
		static const std::unordered_set<std::string> ImageExtension {
			".bmp", ".dib",
			".jpeg", ".jpg", ".jpe",
			".jp2",
			".png",
			".webp",
			".pbm", ".pgm", ".ppm", ".pxm", ".pnm",
			".pfm",
			".sr", ".ras",
			".tiff", ".tif",
			".exr",
			".hdr", ".pic"
		};

		// Files are often named "IMG_0001.JPG"; fold ASCII upper-case letters
		// so such files are not rejected as unsupported.
		std::string lowered(extension);
		for (auto& ch : lowered)
		{
			if (ch >= 'A' && ch <= 'Z') ch = static_cast<char>(ch - 'A' + 'a');
		}

		return ImageExtension.find(lowered) != ImageExtension.end();
	}

	/// Reads the complete content of a file in binary mode.
	///
	/// @param path file to read.
	/// @return the file content, or std::nullopt when the file cannot be
	///         opened or an unrecoverable read error occurs.
	std::optional<std::string> ReadToEnd(const std::filesystem::path& path)
	{
		std::ifstream fs(path, std::ios::in | std::ios::binary);
		if (!fs) return std::nullopt;

		std::string data;
		constexpr std::streamsize bufSize = 4096;
		char buf[bufSize];
		// Loop on the stream state rather than on eof(): a read error sets
		// failbit/badbit without eofbit, which made the eof()-based loop spin forever.
		while (fs)
		{
			fs.read(buf, bufSize);
			data.append(buf, static_cast<std::size_t>(fs.gcount()));
		}

		// badbit means the data is incomplete; reaching end of file does not set it.
		if (fs.bad()) return std::nullopt;

		return data;
	}

	/// Reads an image file and decodes it as a 3-channel colour image.
	///
	/// @param file image file to read.
	/// @return the result of cv::imdecode for the file content.
	/// @throws std::bad_optional_access when the file cannot be read
	///         (the unchecked dereference used before was undefined behaviour).
	cv::Mat ReadImage(const std::filesystem::path& file)
	{
		const auto raw = ReadToEnd(file).value();
		// Range construction copies the bytes in one step instead of
		// growing the vector element by element.
		const std::vector<char> data(raw.begin(), raw.end());
		return cv::imdecode(data, cv::IMREAD_COLOR);
	}

	/// Decodes an in-memory encoded image as a 3-channel colour image.
	///
	/// @param data pointer to the encoded image bytes.
	/// @param len  number of bytes at `data`.
	/// @return the result of cv::imdecode for the given bytes.
	cv::Mat ReadImage(const char* data, const uint64 len)
	{
		const std::vector<char> buf(data, data + len);
		return cv::imdecode(buf, cv::IMREAD_COLOR);
	}

	/// Builds the network input blob for a decoded image: resized to 224x224,
	/// unscaled, mean-subtracted, channels left in their original order.
	// NOTE(review): the mean is listed as (123.68, 116.779, 103.939) while the
	// R/B swap is disabled; confirm this matches the channel order the model expects.
	cv::Mat ReadBlob(const cv::Mat& img)
	{
		const double scaleFactor = 1.;
		const cv::Size inputSize(224, 224);
		const cv::Scalar channelMean(123.68, 116.779, 103.939);
		const bool swapRedBlue = false;
		return cv::dnn::blobFromImage(img, scaleFactor, inputSize, channelMean, swapRedBlue);
	}

	/// Decodes `len` bytes of encoded image data and builds the input blob from it.
	cv::Mat ReadBlob(const char* data, const uint64 len)
	{
		const cv::Mat decoded = ReadImage(data, len);
		return ReadBlob(decoded);
	}

	/// Reads and decodes the image file and builds the input blob from it.
	cv::Mat ReadBlob(const std::filesystem::path& file)
	{
		const cv::Mat decoded = ReadImage(file);
		return ReadBlob(decoded);
	}
}

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注