1#![allow(missing_docs)]
2
3use crate::class_prelude::*;
4use crate::descriptor::lang_id::LangID;
5use crate::device::{StringDescriptors, UsbDevice, UsbDeviceBuilder, UsbVidPid};
6use crate::Result;
7use core::cmp;
8
/// Buffer and endpoint sizes used when the `test-class-high-speed` feature is enabled.
#[cfg(feature = "test-class-high-speed")]
mod sizes {
    /// Length in bytes of each of the control, bulk and interrupt buffers held by `TestClass`.
    pub const BUFFER: usize = 2048;

    /// Maximum packet size configured for control endpoint 0 by the device builder.
    pub const CONTROL_ENDPOINT: u8 = 64;

    /// Maximum packet size requested for the bulk IN and bulk OUT endpoints.
    pub const BULK_ENDPOINT: u16 = 512;

    /// Maximum packet size requested for the interrupt IN and interrupt OUT endpoints.
    pub const INTERRUPT_ENDPOINT: u16 = 1024;
}
16
/// Buffer and endpoint sizes used when the `test-class-high-speed` feature is NOT enabled.
#[cfg(not(feature = "test-class-high-speed"))]
mod sizes {
    /// Length in bytes of each of the control, bulk and interrupt buffers held by `TestClass`.
    pub const BUFFER: usize = 256;

    /// Maximum packet size configured for control endpoint 0 by the device builder.
    pub const CONTROL_ENDPOINT: u8 = 8;

    /// Maximum packet size requested for the bulk IN and bulk OUT endpoints.
    pub const BULK_ENDPOINT: u16 = 64;

    /// Maximum packet size requested for the interrupt IN and interrupt OUT endpoints.
    pub const INTERRUPT_ENDPOINT: u16 = 31;
}
24
25pub struct TestClass<'a, B: UsbBus> {
28 custom_string: StringIndex,
29 interface_string: StringIndex,
30 iface: InterfaceNumber,
31 ep_bulk_in: EndpointIn<'a, B>,
32 ep_bulk_out: EndpointOut<'a, B>,
33 ep_interrupt_in: EndpointIn<'a, B>,
34 ep_interrupt_out: EndpointOut<'a, B>,
35 ep_iso_in: EndpointIn<'a, B>,
36 control_buf: [u8; sizes::BUFFER],
37 bulk_buf: [u8; sizes::BUFFER],
38 interrupt_buf: [u8; sizes::BUFFER],
39 len: usize,
40 i: usize,
41 bench: bool,
42 expect_bulk_in_complete: bool,
43 expect_bulk_out: bool,
44 expect_interrupt_in_complete: bool,
45 expect_interrupt_out: bool,
46}
47
/// Vendor ID the test device is built with.
pub const VID: u16 = 0x16c0;

/// Product ID the test device is built with.
pub const PID: u16 = 0x05dc;

/// Manufacturer string placed in the device's string descriptors.
pub const MANUFACTURER: &str = "TestClass Manufacturer";

/// Product string placed in the device's string descriptors.
pub const PRODUCT: &str = "virkkunen.net usb-device TestClass";

/// Serial number string placed in the device's string descriptors.
pub const SERIAL_NUMBER: &str = "TestClass Serial";

/// String returned by the class for its custom string index (US English only).
pub const CUSTOM_STRING: &str = "TestClass Custom String";

/// String returned by the class for its interface string index (US English only).
pub const INTERFACE_STRING: &str = "TestClass Interface";
55
/// Control OUT request: record the fields of this very request in the first eight bytes of the
/// control buffer.
pub const REQ_STORE_REQUEST: u8 = 1;

/// Control IN request: return the first `length` bytes of the control buffer.
pub const REQ_READ_BUFFER: u8 = 2;

/// Control OUT request: copy the transferred data to the start of the control buffer.
pub const REQ_WRITE_BUFFER: u8 = 3;

/// Control OUT request: enable bench mode when `value` is non-zero, disable it otherwise.
pub const REQ_SET_BENCH_ENABLED: u8 = 4;

/// Control IN request: return [`LONG_DATA`].
pub const REQ_READ_LONG_DATA: u8 = 5;

/// Request code that the class does not handle; vendor/device requests carrying it are rejected.
pub const REQ_UNKNOWN: u8 = 42;
62
/// Static payload returned for `REQ_READ_LONG_DATA`: 257 bytes, each equal to `0x17`.
pub const LONG_DATA: &[u8] = &[0x17; 257];
64
65impl<B: UsbBus> TestClass<'_, B> {
66 pub fn new(alloc: &UsbBusAllocator<B>) -> TestClass<'_, B> {
68 TestClass {
69 custom_string: alloc.string(),
70 interface_string: alloc.string(),
71 iface: alloc.interface(),
72 ep_bulk_in: alloc.bulk(sizes::BULK_ENDPOINT),
73 ep_bulk_out: alloc.bulk(sizes::BULK_ENDPOINT),
74 ep_interrupt_in: alloc.interrupt(sizes::INTERRUPT_ENDPOINT, 1),
75 ep_interrupt_out: alloc.interrupt(sizes::INTERRUPT_ENDPOINT, 1),
76 ep_iso_in: alloc.isochronous(
77 IsochronousSynchronizationType::Asynchronous,
78 IsochronousUsageType::ImplicitFeedbackData,
79 500, 1, ),
82 control_buf: [0; sizes::BUFFER],
83 bulk_buf: [0; sizes::BUFFER],
84 interrupt_buf: [0; sizes::BUFFER],
85 len: 0,
86 i: 0,
87 bench: false,
88 expect_bulk_in_complete: false,
89 expect_bulk_out: false,
90 expect_interrupt_in_complete: false,
91 expect_interrupt_out: false,
92 }
93 }
94
95 pub fn make_device<'a>(&self, usb_bus: &'a UsbBusAllocator<B>) -> UsbDevice<'a, B> {
97 self.make_device_builder(usb_bus).build()
98 }
99
100 pub fn make_device_builder<'a>(
112 &self,
113 usb_bus: &'a UsbBusAllocator<B>,
114 ) -> UsbDeviceBuilder<'a, B> {
115 UsbDeviceBuilder::new(usb_bus, UsbVidPid(VID, PID))
116 .strings(&[StringDescriptors::default()
117 .manufacturer(MANUFACTURER)
118 .product(PRODUCT)
119 .serial_number(SERIAL_NUMBER)])
120 .unwrap()
121 .max_packet_size_0(sizes::CONTROL_ENDPOINT)
122 .unwrap()
123 }
124
125 pub fn poll(&mut self) {
127 if self.bench {
128 match self.ep_bulk_out.read(&mut self.bulk_buf) {
129 Ok(_) | Err(UsbError::WouldBlock) => {}
130 Err(err) => panic!("bulk bench read {:?}", err),
131 };
132
133 match self
134 .ep_bulk_in
135 .write(&self.bulk_buf[0..self.ep_bulk_in.max_packet_size() as usize])
136 {
137 Ok(_) | Err(UsbError::WouldBlock) => {}
138 Err(err) => panic!("bulk bench write {:?}", err),
139 };
140
141 return;
142 }
143
144 let temp_i = self.i;
145 match self.ep_bulk_out.read(&mut self.bulk_buf[temp_i..]) {
146 Ok(count) => {
147 if self.expect_bulk_out {
148 self.expect_bulk_out = false;
149 } else {
150 panic!("unexpectedly read data from bulk out endpoint");
151 }
152
153 self.i += count;
154
155 if count < self.ep_bulk_out.max_packet_size() as usize {
156 self.len = self.i;
157 self.i = 0;
158
159 self.write_bulk_in(count == 0);
160 }
161 }
162 Err(UsbError::WouldBlock) => {}
163 Err(err) => panic!("bulk read {:?}", err),
164 };
165
166 match self.ep_interrupt_out.read(&mut self.interrupt_buf) {
167 Ok(count) => {
168 if self.expect_interrupt_out {
169 self.expect_interrupt_out = false;
170 } else {
171 panic!("unexpectedly read data from interrupt out endpoint");
172 }
173
174 self.ep_interrupt_in
175 .write(&self.interrupt_buf[0..count])
176 .expect("interrupt write");
177
178 self.expect_interrupt_in_complete = true;
179 }
180 Err(UsbError::WouldBlock) => {}
181 Err(err) => panic!("interrupt read {:?}", err),
182 };
183 }
184
185 fn write_bulk_in(&mut self, write_empty: bool) {
186 let to_write = cmp::min(
187 self.len - self.i,
188 self.ep_bulk_in.max_packet_size() as usize,
189 );
190
191 if to_write == 0 && !write_empty {
192 self.len = 0;
193 self.i = 0;
194
195 return;
196 }
197
198 match self
199 .ep_bulk_in
200 .write(&self.bulk_buf[self.i..self.i + to_write])
201 {
202 Ok(count) => {
203 assert_eq!(count, to_write);
204 self.expect_bulk_in_complete = true;
205 self.i += count;
206 }
207 Err(UsbError::WouldBlock) => {}
208 Err(err) => panic!("bulk write {:?}", err),
209 };
210 }
211}
212
213impl<B: UsbBus> UsbClass<B> for TestClass<'_, B> {
214 fn reset(&mut self) {
215 self.len = 0;
216 self.i = 0;
217 self.bench = false;
218 self.expect_bulk_in_complete = false;
219 self.expect_bulk_out = false;
220 self.expect_interrupt_in_complete = false;
221 self.expect_interrupt_out = false;
222 }
223
224 fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<()> {
225 writer.interface(self.iface, 0xff, 0x00, 0x00)?;
226 writer.endpoint(&self.ep_bulk_in)?;
227 writer.endpoint(&self.ep_bulk_out)?;
228 writer.endpoint(&self.ep_interrupt_in)?;
229 writer.endpoint(&self.ep_interrupt_out)?;
230 writer.interface_alt(self.iface, 1, 0xff, 0x01, 0x00, Some(self.interface_string))?;
231 writer.endpoint(&self.ep_iso_in)?;
232 Ok(())
233 }
234
235 fn get_string(&self, index: StringIndex, lang_id: LangID) -> Option<&str> {
236 if lang_id == LangID::EN_US {
237 if index == self.custom_string {
238 return Some(CUSTOM_STRING);
239 } else if index == self.interface_string {
240 return Some(INTERFACE_STRING);
241 }
242 }
243
244 None
245 }
246
247 fn endpoint_in_complete(&mut self, addr: EndpointAddress) {
248 if self.bench {
249 return;
250 }
251
252 if addr == self.ep_bulk_in.address() {
253 if self.expect_bulk_in_complete {
254 self.expect_bulk_in_complete = false;
255
256 self.write_bulk_in(false);
257 } else {
258 panic!("unexpected endpoint_in_complete");
259 }
260 } else if addr == self.ep_interrupt_in.address() {
261 if self.expect_interrupt_in_complete {
262 self.expect_interrupt_in_complete = false;
263 } else {
264 panic!("unexpected endpoint_in_complete");
265 }
266 }
267 }
268
269 fn endpoint_out(&mut self, addr: EndpointAddress) {
270 if addr == self.ep_bulk_out.address() {
271 self.expect_bulk_out = true;
272 } else if addr == self.ep_interrupt_out.address() {
273 self.expect_interrupt_out = true;
274 }
275 }
276
277 fn control_in(&mut self, xfer: ControlIn<B>) {
278 let req = *xfer.request();
279
280 if !(req.request_type == control::RequestType::Vendor
281 && req.recipient == control::Recipient::Device)
282 {
283 return;
284 }
285
286 match req.request {
287 REQ_READ_BUFFER if req.length as usize <= self.control_buf.len() => xfer
288 .accept_with(&self.control_buf[0..req.length as usize])
289 .expect("control_in REQ_READ_BUFFER failed"),
290 REQ_READ_LONG_DATA => xfer
291 .accept_with_static(LONG_DATA)
292 .expect("control_in REQ_READ_LONG_DATA failed"),
293 _ => xfer.reject().expect("control_in reject failed"),
294 }
295 }
296
297 fn control_out(&mut self, xfer: ControlOut<B>) {
298 let req = *xfer.request();
299
300 if !(req.request_type == control::RequestType::Vendor
301 && req.recipient == control::Recipient::Device)
302 {
303 return;
304 }
305
306 match req.request {
307 REQ_STORE_REQUEST => {
308 self.control_buf[0] =
309 (req.direction as u8) | (req.request_type as u8) << 5 | (req.recipient as u8);
310 self.control_buf[1] = req.request;
311 self.control_buf[2..4].copy_from_slice(&req.value.to_le_bytes());
312 self.control_buf[4..6].copy_from_slice(&req.index.to_le_bytes());
313 self.control_buf[6..8].copy_from_slice(&req.length.to_le_bytes());
314
315 xfer.accept().expect("control_out REQ_STORE_REQUEST failed");
316 }
317 REQ_WRITE_BUFFER if xfer.data().len() <= self.control_buf.len() => {
318 assert_eq!(
319 xfer.data().len(),
320 req.length as usize,
321 "xfer data len == req.length"
322 );
323
324 self.control_buf[0..xfer.data().len()].copy_from_slice(xfer.data());
325
326 xfer.accept().expect("control_out REQ_WRITE_BUFFER failed");
327 }
328 REQ_SET_BENCH_ENABLED => {
329 self.bench = req.value != 0;
330
331 xfer.accept()
332 .expect("control_out REQ_SET_BENCH_ENABLED failed");
333 }
334 _ => xfer.reject().expect("control_out reject failed"),
335 }
336 }
337}