| | #include <stdint.h> |
| | #include <algorithm> |
| | #include <string> |
| | #include <utility> |
| | #include <vector> |
| |
|
| | #include "boost/scoped_ptr.hpp" |
| | #include "gflags/gflags.h" |
| | #include "glog/logging.h" |
| |
|
| | #include "caffe/proto/caffe.pb.h" |
| | #include "caffe/util/db.hpp" |
| | #include "caffe/util/io.hpp" |
| |
|
| | using namespace caffe; |
| |
|
| | using std::max; |
| | using std::pair; |
| | using boost::scoped_ptr; |
| |
|
| | DEFINE_string(backend, "lmdb", |
| | "The backend {leveldb, lmdb} containing the images"); |
| |
|
| | int main(int argc, char** argv) { |
| | #ifdef USE_OPENCV |
| | ::google::InitGoogleLogging(argv[0]); |
| | |
| | FLAGS_alsologtostderr = 1; |
| |
|
| | #ifndef GFLAGS_GFLAGS_H_ |
| | namespace gflags = google; |
| | #endif |
| |
|
| | gflags::SetUsageMessage("Compute the mean_image of a set of images given by" |
| | " a leveldb/lmdb\n" |
| | "Usage:\n" |
| | " compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n"); |
| |
|
| | gflags::ParseCommandLineFlags(&argc, &argv, true); |
| |
|
| | if (argc < 2 || argc > 3) { |
| | gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean"); |
| | return 1; |
| | } |
| |
|
| | scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend)); |
| | db->Open(argv[1], db::READ); |
| | scoped_ptr<db::Cursor> cursor(db->NewCursor()); |
| |
|
| | BlobProto sum_blob; |
| | int count = 0; |
| | |
| | Datum datum; |
| | datum.ParseFromString(cursor->value()); |
| |
|
| | if (DecodeDatumNative(&datum)) { |
| | LOG(INFO) << "Decoding Datum"; |
| | } |
| |
|
| | sum_blob.set_num(1); |
| | sum_blob.set_channels(datum.channels()); |
| | sum_blob.set_height(datum.height()); |
| | sum_blob.set_width(datum.width()); |
| | const int data_size = datum.channels() * datum.height() * datum.width(); |
| | int size_in_datum = std::max<int>(datum.data().size(), |
| | datum.float_data_size()); |
| | for (int i = 0; i < size_in_datum; ++i) { |
| | sum_blob.add_data(0.); |
| | } |
| | LOG(INFO) << "Starting iteration"; |
| | while (cursor->valid()) { |
| | Datum datum; |
| | datum.ParseFromString(cursor->value()); |
| | DecodeDatumNative(&datum); |
| |
|
| | const std::string& data = datum.data(); |
| | size_in_datum = std::max<int>(datum.data().size(), |
| | datum.float_data_size()); |
| | CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " << |
| | size_in_datum; |
| | if (data.size() != 0) { |
| | CHECK_EQ(data.size(), size_in_datum); |
| | for (int i = 0; i < size_in_datum; ++i) { |
| | sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]); |
| | } |
| | } else { |
| | CHECK_EQ(datum.float_data_size(), size_in_datum); |
| | for (int i = 0; i < size_in_datum; ++i) { |
| | sum_blob.set_data(i, sum_blob.data(i) + |
| | static_cast<float>(datum.float_data(i))); |
| | } |
| | } |
| | ++count; |
| | if (count % 10000 == 0) { |
| | LOG(INFO) << "Processed " << count << " files."; |
| | } |
| | cursor->Next(); |
| | } |
| |
|
| | if (count % 10000 != 0) { |
| | LOG(INFO) << "Processed " << count << " files."; |
| | } |
| | for (int i = 0; i < sum_blob.data_size(); ++i) { |
| | sum_blob.set_data(i, sum_blob.data(i) / count); |
| | } |
| | |
| | if (argc == 3) { |
| | LOG(INFO) << "Write to " << argv[2]; |
| | WriteProtoToBinaryFile(sum_blob, argv[2]); |
| | } |
| | const int channels = sum_blob.channels(); |
| | const int dim = sum_blob.height() * sum_blob.width(); |
| | std::vector<float> mean_values(channels, 0.0); |
| | LOG(INFO) << "Number of channels: " << channels; |
| | for (int c = 0; c < channels; ++c) { |
| | for (int i = 0; i < dim; ++i) { |
| | mean_values[c] += sum_blob.data(dim * c + i); |
| | } |
| | LOG(INFO) << "mean_value channel [" << c << "]: " << mean_values[c] / dim; |
| | } |
| | #else |
| | LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV."; |
| | #endif |
| | return 0; |
| | } |
| |
|