#include <caffe/caffe.hpp>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include <algorithm>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "Classifier.h"

using namespace caffe;

Classifier::Classifier(const string& model_file,
                       const string& trained_file,
                       const string& mean_file,
                       const string& label_file) {
#ifdef CPU_ONLY
  Caffe::set_mode(Caffe::CPU);
#else
  Caffe::set_mode(Caffe::GPU);
#endif

  /* Load the network definition and the trained weights. */
  net_.reset(new Net<float>(model_file, TEST));
  net_->CopyTrainedLayersFrom(trained_file);

  CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
  CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";

  Blob<float>* input_layer = net_->input_blobs()[0];
  num_channels_ = input_layer->channels();
  CHECK(num_channels_ == 3 || num_channels_ == 1)
      << "Input layer should have 1 or 3 channels.";
  input_geometry_ = cv::Size(input_layer->width(), input_layer->height());

  /* Load the binaryproto mean file. */
  SetMean(mean_file);

  /* Load the labels, one per line. */
  std::ifstream labels(label_file.c_str());
  CHECK(labels) << "Unable to open labels file " << label_file;
  string line;
  while (std::getline(labels, line))
    labels_.push_back(string(line));

  Blob<float>* output_layer = net_->output_blobs()[0];
  CHECK_EQ(labels_.size(), output_layer->channels())
      << "Number of labels is different from the output layer dimension.";
}

static bool PairCompare(const std::pair<float, int>& lhs,
                        const std::pair<float, int>& rhs) {
  return lhs.first > rhs.first;
}

/* Return the indices of the top N values of vector v. */
static std::vector<int> Argmax(const std::vector<float>& v, int N) {
  std::vector<std::pair<float, int> > pairs;
  for (size_t i = 0; i < v.size(); ++i)
    pairs.push_back(std::make_pair(v[i], static_cast<int>(i)));
  std::partial_sort(pairs.begin(), pairs.begin() + N, pairs.end(), PairCompare);

  std::vector<int> result;
  for (int i = 0; i < N; ++i)
    result.push_back(pairs[i].second);
  return result;
}

/* Return the top N predictions as (label, confidence) pairs. */
std::vector<Prediction> Classifier::Classify(const cv::Mat& img, int N) {
  std::vector<float> output = Predict(img);

  N = std::min<int>(labels_.size(), N);
  std::vector<int> maxN = Argmax(output, N);
  std::vector<Prediction> predictions;
  for (int i = 0; i < N; ++i) {
    int idx = maxN[i];
    predictions.push_back(std::make_pair(labels_[idx], output[idx]));
  }

  return predictions;
}

/* Load the mean file in binaryproto format. */
void Classifier::SetMean(const string& mean_file) {
  BlobProto blob_proto;
  ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);

  /* Convert from BlobProto to Blob<float>. */
  Blob<float> mean_blob;
  mean_blob.FromProto(blob_proto);
  CHECK_EQ(mean_blob.channels(), num_channels_)
      << "Number of channels of mean file doesn't match input layer.";

  /* The format of the mean file is planar 32-bit float BGR or grayscale. */
  std::vector<cv::Mat> channels;
  float* data = mean_blob.mutable_cpu_data();
  for (int i = 0; i < num_channels_; ++i) {
    /* Extract an individual channel. */
    cv::Mat channel(mean_blob.height(), mean_blob.width(), CV_32FC1, data);
    channels.push_back(channel);
    data += mean_blob.height() * mean_blob.width();
  }

  /* Merge the separate channels into a single image. */
  cv::Mat mean;
  cv::merge(channels, mean);

  /* Compute the global mean pixel value and create a mean image
   * filled with this value. */
  cv::Scalar channel_mean = cv::mean(mean);
  mean_ = cv::Mat(input_geometry_, mean.type(), channel_mean);
}

std::vector<float> Classifier::Predict(const cv::Mat& img) {
  Blob<float>* input_layer = net_->input_blobs()[0];
  input_layer->Reshape(1, num_channels_,
                       input_geometry_.height, input_geometry_.width);
  /* Forward dimension change to all layers. */
  net_->Reshape();

  std::vector<cv::Mat> input_channels;
  WrapInputLayer(&input_channels);

  Preprocess(img, &input_channels);

  net_->ForwardPrefilled();

  /* Copy the output layer to a std::vector. */
  Blob<float>* output_layer = net_->output_blobs()[0];
  const float* begin = output_layer->cpu_data();
  const float* end = begin + output_layer->channels();
  return std::vector<float>(begin, end);
}

/* Wrap the input layer of the network in separate cv::Mat objects
 * (one per channel). This way we avoid an extra memcpy: the Mat
 * objects point directly into the network's input blob, so the last
 * preprocessing step writes the channels straight into the input
 * layer. */
void Classifier::WrapInputLayer(std::vector<cv::Mat>* input_channels) {
  Blob<float>* input_layer = net_->input_blobs()[0];

  int width = input_layer->width();
  int height = input_layer->height();
  float* input_data = input_layer->mutable_cpu_data();
  for (int i = 0; i < input_layer->channels(); ++i) {
    cv::Mat channel(height, width, CV_32FC1, input_data);
    input_channels->push_back(channel);
    input_data += width * height;
  }
}

void Classifier::Preprocess(const cv::Mat& img,
                            std::vector<cv::Mat>* input_channels) {
  /* Convert the input image to the input image format of the network. */
  cv::Mat sample;
  if (img.channels() == 3 && num_channels_ == 1)
    cv::cvtColor(img, sample, CV_BGR2GRAY);
  else if (img.channels() == 4 && num_channels_ == 1)
    cv::cvtColor(img, sample, CV_BGRA2GRAY);
  else if (img.channels() == 4 && num_channels_ == 3)
    cv::cvtColor(img, sample, CV_BGRA2BGR);
  else if (img.channels() == 1 && num_channels_ == 3)
    cv::cvtColor(img, sample, CV_GRAY2BGR);
  else
    sample = img;

  cv::Mat sample_resized;
  if (sample.size() != input_geometry_)
    cv::resize(sample, sample_resized, input_geometry_);
  else
    sample_resized = sample;

  cv::Mat sample_float;
  if (num_channels_ == 3)
    sample_resized.convertTo(sample_float, CV_32FC3);
  else
    sample_resized.convertTo(sample_float, CV_32FC1);

  /* Subtract the mean image. */
  cv::Mat sample_normalized;
  cv::subtract(sample_float, mean_, sample_normalized);

  /* This operation writes the separate BGR planes directly into the
   * input layer of the network, because it is wrapped by the cv::Mat
   * objects in input_channels. */
  cv::split(sample_normalized, *input_channels);

  CHECK(reinterpret_cast<float*>(input_channels->at(0).data)
        == net_->input_blobs()[0]->cpu_data())
      << "Input channels are not wrapping the input layer of the network.";
}
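
/* Example usage sketch. The file names below are illustrative
 * placeholders, and Classifier.h is assumed to declare the Classifier
 * class and the Prediction typedef (std::pair<std::string, float>)
 * used above:
 *
 *   Classifier classifier("deploy.prototxt",
 *                         "model.caffemodel",
 *                         "mean.binaryproto",
 *                         "labels.txt");
 *   cv::Mat img = cv::imread("cat.jpg", -1);
 *   std::vector<Prediction> top5 = classifier.Classify(img, 5);
 *   for (size_t i = 0; i < top5.size(); ++i)
 *     std::cout << top5[i].second << " - \"" << top5[i].first << "\"\n";
 */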