博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TFLite: 从内存生成FlatBufferModel
阅读量:742 次
发布时间:2019-03-23

本文共 4689 字,大约阅读时间需要 15 分钟。

BuildFromBuffer 

  // Builds a model based on a pre-loaded flatbuffer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Caller retains ownership of `error_reporter` and must ensure
  // its lifetime is longer than the FlatBufferModel instance.
  // Returns a nullptr in case of failure.

 BuildFromBuffer:声明

  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* buffer, size_t buffer_size,
      ErrorReporter* error_reporter = DefaultErrorReporter());

BuildFromBuffer定义

  std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
      const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
    error_reporter = ValidateErrorReporter(error_reporter);
  
    std::unique_ptr<FlatBufferModel> model;
    Allocation* allocation =
        new MemoryAllocation(buffer, buffer_size, error_reporter);
    model.reset(new FlatBufferModel(allocation, error_reporter));
    if (!model->initialized()) model.reset();
    return model;
  }

MemoryAllocation

class MemoryAllocation : public Allocation {

 public:
  // Allocates memory with the pointer and the number of bytes of the memory.
  // The pointer has to remain alive and unchanged until the destructor is
  // called.
  MemoryAllocation(const void* ptr, size_t num_bytes,                                                                   
                   ErrorReporter* error_reporter);
  virtual ~MemoryAllocation();
  const void* base() const override;
  size_t bytes() const override;
  bool valid() const override;

 private:

  const void* buffer_;
  size_t buffer_size_bytes_ = 0;
};

// MemoryAllocation实现

{
    MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,                                                                     
                                   ErrorReporter* error_reporter)
        : Allocation(error_reporter) {
      buffer_ = ptr;
      buffer_size_bytes_ = num_bytes;
    }

    MemoryAllocation::~MemoryAllocation() {}

    const void* MemoryAllocation::base() const { return buffer_; }

    size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; }

    bool MemoryAllocation::valid() const { return buffer_ != nullptr; }

}

基类Allocation

// A memory allocation handle. This could be a mmap or shared memory.

class Allocation {
 public:  
  Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {}
  virtual ~Allocation() {}

  // Base pointer of this allocation

  virtual const void* base() const = 0;                                             
  // Size in bytes of the allocation
  virtual size_t bytes() const = 0;
  // Whether the allocation is valid
  virtual bool valid() const = 0;

 protected:

  ErrorReporter* error_reporter_;
};

FlatBufferModel

生成Allocation后,就可以调用FlatBufferModel的构造函数

  FlatBufferModel::FlatBufferModel(Allocation* allocation, ErrorReporter* error_reporter)

      : error_reporter_(ValidateErrorReporter(error_reporter)) {
    allocation_ = allocation;
    if (!allocation_->valid() || !CheckModelIdentifier()) return;
  
    model_ = ::tflite::GetModel(allocation_->base());                                                                                                      
  }

怎样把 tflite 文件读入内存(FileCopyAllocation)

class FileCopyAllocation : public Allocation {

 public:
  FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
  virtual ~FileCopyAllocation();
  const void* base() const override;
  size_t bytes() const override;
  bool valid() const override;

 private:

  // Data required for mmap.
  std::unique_ptr<const char[]> copied_buffer_;
  size_t buffer_size_bytes_ = 0;
};

FileCopyAllocation::FileCopyAllocation

构造函数为私有成员 copied_buffer_ 和 buffer_size_bytes_ 赋值。

外部通过公有接口 base() 和 bytes() 访问这两个私有成员。

FileCopyAllocation::FileCopyAllocation(const char* filename,

                                       ErrorReporter* error_reporter)
    : Allocation(error_reporter) {
  // Obtain the file size, using an alternative method that is does not
  // require fstat for more compatibility.
  std::unique_ptr<FILE, decltype(&fclose)> file(fopen(filename, "rb"), fclose);
  if (!file) {
    error_reporter_->Report("Could not open '%s'.", filename);
    return;
  }
  // TODO(ahentz): Why did you think using fseek here was better for finding
  // the size?
  struct stat sb; 
  if (fstat(fileno(file.get()), &sb) != 0) {
    error_reporter_->Report("Failed to get file size of '%s'.", filename);
    return;
  }
  buffer_size_bytes_ = sb.st_size;
  std::unique_ptr<char[]> buffer(new char[buffer_size_bytes_]);
  if (!buffer) {
    error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.",
                            filename);
    return;
  }
  size_t bytes_read =
      fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get());
  if (bytes_read != buffer_size_bytes_) {
    error_reporter_->Report("Read of '%s' failed (too few bytes read).",
                            filename);
    return;
  }
  // Versions of GCC before 6.2.0 don't support std::move from non-const
  // char[] to const char[] unique_ptrs.
  copied_buffer_.reset(const_cast<char const*>(buffer.release()));
}

// Start of the owned file copy; null when the constructor failed to load it.
const void* FileCopyAllocation::base() const {
  const char* const start = copied_buffer_.get();
  return start;
}

// Size of the owned file copy in bytes.
size_t FileCopyAllocation::bytes() const {
  const size_t size = buffer_size_bytes_;
  return size;
}

目的是先从 tflite 文件中读出数据(例如用于传输),拿到数据后再根据该 buffer 调用 BuildFromBuffer 创建 FlatBufferModel。

 

 

转载地址:http://ejuzk.baihongyu.com/

你可能感兴趣的文章