|
|
#pragma once
|
|
|
|
|
|
#include <ATen/quantized/Quantizer.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

#include <utility>
|
|
|
|
|
|
namespace at {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct TORCH_API QTensorImpl : public c10::TensorImpl {
|
|
|
public:
|
|
|
QTensorImpl(
|
|
|
Storage&& storage,
|
|
|
DispatchKeySet key_set,
|
|
|
const caffe2::TypeMeta data_type,
|
|
|
QuantizerPtr quantizer);
|
|
|
|
|
|
|
|
|
QTensorImpl(
|
|
|
ImplType type,
|
|
|
Storage&& storage,
|
|
|
DispatchKeySet key_set,
|
|
|
const caffe2::TypeMeta data_type,
|
|
|
QuantizerPtr quantizer);
|
|
|
|
|
|
|
|
|
|
|
|
QuantizerPtr quantizer() {
|
|
|
return quantizer_;
|
|
|
}
|
|
|
|
|
|
void set_quantizer_(QuantizerPtr quantizer) {
|
|
|
quantizer_ = quantizer;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
|
|
const c10::VariableVersion& version_counter,
|
|
|
bool allow_tensor_metadata_change) const override {
|
|
|
auto impl = c10::make_intrusive<QTensorImpl>(
|
|
|
Storage(storage()), key_set(), data_type_, quantizer_);
|
|
|
copy_tensor_metadata(
|
|
|
this,
|
|
|
impl.get(),
|
|
|
version_counter,
|
|
|
allow_tensor_metadata_change);
|
|
|
impl->refresh_numel();
|
|
|
impl->refresh_contiguous();
|
|
|
return impl;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
|
|
c10::VariableVersion&& version_counter,
|
|
|
bool allow_tensor_metadata_change) const override {
|
|
|
auto impl = c10::make_intrusive<QTensorImpl>(
|
|
|
Storage(storage()), key_set(), data_type_, quantizer_);
|
|
|
copy_tensor_metadata(
|
|
|
this,
|
|
|
impl.get(),
|
|
|
std::move(version_counter),
|
|
|
allow_tensor_metadata_change);
|
|
|
impl->refresh_numel();
|
|
|
impl->refresh_contiguous();
|
|
|
return impl;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
|
|
|
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
|
|
|
auto q_impl = static_cast<const QTensorImpl*>(impl.get());
|
|
|
copy_tensor_metadata(
|
|
|
q_impl,
|
|
|
this,
|
|
|
version_counter(),
|
|
|
allow_tensor_metadata_change());
|
|
|
refresh_numel();
|
|
|
refresh_contiguous();
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
QuantizerPtr quantizer_;
|
|
|
|
|
|
const char* tensorimpl_type_name() const override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static void copy_tensor_metadata(
|
|
|
const QTensorImpl* src_q_impl,
|
|
|
QTensorImpl* dest_q_impl,
|
|
|
const c10::VariableVersion& version_counter,
|
|
|
bool allow_tensor_metadata_change) {
|
|
|
TensorImpl::copy_tensor_metadata(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change);
|
|
|
|
|
|
|
|
|
dest_q_impl->quantizer_ = src_q_impl->quantizer_;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
}
|
|
|
|