@@ -15,6 +15,8 @@ class basic_scalar_mixin
1515 using data_ref = typename trait::ref_type;
1616 using data_t = typename trait::Data;
1717
18+ using Dim = typename S::dimension_type;
19+
1820 data_t data_;
1921
2022 protected:
@@ -33,23 +35,17 @@ class basic_scalar_mixin
3335
3436 basic_scalar_mixin (data_ptr data, const S &) : data_(data) {}
3537
38+ constexpr Dim size () const { return 1 ; }
39+
40+ constexpr auto dims () const { return S ().dims (); }
41+
3642 constexpr size_t data_size () const { return sizeof (R); }
3743
3844 data_ptr data () const { return data_.get (); }
3945
4046 data_ptr data_end () const { return data_.get () + 1 ; }
4147
4248 S shape () const { return S (); }
43-
44- void from_host (const void *data) const
45- {
46- basic_copier<D, host_memory>()(data_.get (), data, data_size ());
47- }
48-
49- void to_host (void *data) const
50- {
51- basic_copier<host_memory, D>()(data, data_.get (), data_size ());
52- }
5349};
5450
5551template <typename R, typename S, typename D, typename A>
@@ -121,6 +117,10 @@ class basic_tensor_mixin
121117
122118 static constexpr auto rank = S::rank;
123119
120+ Dim size () const { return shape_.size (); }
121+
122+ const auto &dims () const { return shape_.dims (); }
123+
124124 size_t data_size () const { return shape_.size () * sizeof (R); }
125125
126126 const S &shape () const { return shape_; }
@@ -158,16 +158,6 @@ class basic_tensor_mixin
158158 return slice_type (data_.get () + i * sub_shape.size (),
159159 batch (j - i, sub_shape));
160160 }
161-
162- void from_host (const void *data) const
163- {
164- basic_copier<D, host_memory>()(data_.get (), data, data_size ());
165- }
166-
167- void to_host (void *data) const
168- {
169- basic_copier<host_memory, D>()(data, data_.get (), data_size ());
170- }
171161};
172162} // namespace internal
173163} // namespace ttl
0 commit comments