usb_device/
test_class.rs

1#![allow(missing_docs)]
2
3use crate::class_prelude::*;
4use crate::descriptor::lang_id::LangID;
5use crate::device::{StringDescriptors, UsbDevice, UsbDeviceBuilder, UsbVidPid};
6use crate::Result;
7use core::cmp;
8
/// Buffer and endpoint sizes for the high-speed build (`test-class-high-speed` feature).
#[cfg(feature = "test-class-high-speed")]
mod sizes {
    /// Length in bytes of each of `TestClass`'s control, bulk and interrupt buffers.
    pub const BUFFER: usize = 2 * 1024;
    /// Maximum packet size passed to `max_packet_size_0` for control endpoint 0.
    pub const CONTROL_ENDPOINT: u8 = 64;
    /// Maximum packet size of the bulk IN and OUT endpoints.
    pub const BULK_ENDPOINT: u16 = 512;
    /// Maximum packet size of the interrupt IN and OUT endpoints.
    pub const INTERRUPT_ENDPOINT: u16 = 1024;
}

/// Buffer and endpoint sizes for the default (non-high-speed) build.
#[cfg(not(feature = "test-class-high-speed"))]
mod sizes {
    /// Length in bytes of each of `TestClass`'s control, bulk and interrupt buffers.
    pub const BUFFER: usize = 256;
    /// Maximum packet size passed to `max_packet_size_0` for control endpoint 0.
    pub const CONTROL_ENDPOINT: u8 = 8;
    /// Maximum packet size of the bulk IN and OUT endpoints.
    pub const BULK_ENDPOINT: u16 = 64;
    /// Maximum packet size of the interrupt IN and OUT endpoints (an odd, non-power-of-two size).
    pub const INTERRUPT_ENDPOINT: u16 = 31;
}
24
25/// Test USB class for testing USB driver implementations. Supports various endpoint types and
26/// requests for testing USB peripheral drivers on actual hardware.
27pub struct TestClass<'a, B: UsbBus> {
28    custom_string: StringIndex,
29    interface_string: StringIndex,
30    iface: InterfaceNumber,
31    ep_bulk_in: EndpointIn<'a, B>,
32    ep_bulk_out: EndpointOut<'a, B>,
33    ep_interrupt_in: EndpointIn<'a, B>,
34    ep_interrupt_out: EndpointOut<'a, B>,
35    ep_iso_in: EndpointIn<'a, B>,
36    control_buf: [u8; sizes::BUFFER],
37    bulk_buf: [u8; sizes::BUFFER],
38    interrupt_buf: [u8; sizes::BUFFER],
39    len: usize,
40    i: usize,
41    bench: bool,
42    expect_bulk_in_complete: bool,
43    expect_bulk_out: bool,
44    expect_interrupt_in_complete: bool,
45    expect_interrupt_out: bool,
46}
47
/// USB vendor ID used by `make_device_builder`.
pub const VID: u16 = 0x16C0;
/// USB product ID used by `make_device_builder`.
pub const PID: u16 = 0x05DC;

// Device-level string descriptors. Changing these may make the test host misbehave.
/// Manufacturer string descriptor.
pub const MANUFACTURER: &str = "TestClass Manufacturer";
/// Product string descriptor.
pub const PRODUCT: &str = "virkkunen.net usb-device TestClass";
/// Serial number string descriptor.
pub const SERIAL_NUMBER: &str = "TestClass Serial";
/// Text returned by `get_string` for the class's custom string index.
pub const CUSTOM_STRING: &str = "TestClass Custom String";
/// Text returned by `get_string` for the interface string (alternate setting 1).
pub const INTERFACE_STRING: &str = "TestClass Interface";

// Vendor request codes (bRequest), honoured only for vendor requests addressed to the device.
/// OUT: snapshot the SETUP packet into the control buffer.
pub const REQ_STORE_REQUEST: u8 = 0x01;
/// IN: return the first `wLength` bytes of the control buffer.
pub const REQ_READ_BUFFER: u8 = 0x02;
/// OUT: copy the data stage into the control buffer.
pub const REQ_WRITE_BUFFER: u8 = 0x03;
/// OUT: enable (`wValue != 0`) or disable benchmark mode.
pub const REQ_SET_BENCH_ENABLED: u8 = 0x04;
/// IN: return `LONG_DATA`.
pub const REQ_READ_LONG_DATA: u8 = 0x05;
/// Matched by no handler, so it always falls through to the reject arm.
pub const REQ_UNKNOWN: u8 = 0x2A;

/// Payload for REQ_READ_LONG_DATA: 257 bytes of 0x17 (one more than the default 256-byte buffer).
pub const LONG_DATA: &[u8] = &[0x17; 257];
64
65impl<B: UsbBus> TestClass<'_, B> {
66    /// Creates a new TestClass.
67    pub fn new(alloc: &UsbBusAllocator<B>) -> TestClass<'_, B> {
68        TestClass {
69            custom_string: alloc.string(),
70            interface_string: alloc.string(),
71            iface: alloc.interface(),
72            ep_bulk_in: alloc.bulk(sizes::BULK_ENDPOINT),
73            ep_bulk_out: alloc.bulk(sizes::BULK_ENDPOINT),
74            ep_interrupt_in: alloc.interrupt(sizes::INTERRUPT_ENDPOINT, 1),
75            ep_interrupt_out: alloc.interrupt(sizes::INTERRUPT_ENDPOINT, 1),
76            ep_iso_in: alloc.isochronous(
77                IsochronousSynchronizationType::Asynchronous,
78                IsochronousUsageType::ImplicitFeedbackData,
79                500, // These last two args are arbitrary in this usage, they
80                1,   // let the host know how much bandwidth to reserve.
81            ),
82            control_buf: [0; sizes::BUFFER],
83            bulk_buf: [0; sizes::BUFFER],
84            interrupt_buf: [0; sizes::BUFFER],
85            len: 0,
86            i: 0,
87            bench: false,
88            expect_bulk_in_complete: false,
89            expect_bulk_out: false,
90            expect_interrupt_in_complete: false,
91            expect_interrupt_out: false,
92        }
93    }
94
95    /// Convenience method to create a UsbDevice that is configured correctly for TestClass.
96    pub fn make_device<'a>(&self, usb_bus: &'a UsbBusAllocator<B>) -> UsbDevice<'a, B> {
97        self.make_device_builder(usb_bus).build()
98    }
99
100    /// Convenience method to create a UsbDeviceBuilder that is configured correctly for TestClass.
101    ///
102    /// The methods sets
103    ///
104    /// - manufacturer
105    /// - product
106    /// - serial number
107    /// - max_packet_size_0
108    ///
109    /// on the returned builder. If you change the manufacturer, product, or serial number fields,
110    /// the test host may misbehave.
111    pub fn make_device_builder<'a>(
112        &self,
113        usb_bus: &'a UsbBusAllocator<B>,
114    ) -> UsbDeviceBuilder<'a, B> {
115        UsbDeviceBuilder::new(usb_bus, UsbVidPid(VID, PID))
116            .strings(&[StringDescriptors::default()
117                .manufacturer(MANUFACTURER)
118                .product(PRODUCT)
119                .serial_number(SERIAL_NUMBER)])
120            .unwrap()
121            .max_packet_size_0(sizes::CONTROL_ENDPOINT)
122            .unwrap()
123    }
124
125    /// Must be called after polling the UsbDevice.
126    pub fn poll(&mut self) {
127        if self.bench {
128            match self.ep_bulk_out.read(&mut self.bulk_buf) {
129                Ok(_) | Err(UsbError::WouldBlock) => {}
130                Err(err) => panic!("bulk bench read {:?}", err),
131            };
132
133            match self
134                .ep_bulk_in
135                .write(&self.bulk_buf[0..self.ep_bulk_in.max_packet_size() as usize])
136            {
137                Ok(_) | Err(UsbError::WouldBlock) => {}
138                Err(err) => panic!("bulk bench write {:?}", err),
139            };
140
141            return;
142        }
143
144        let temp_i = self.i;
145        match self.ep_bulk_out.read(&mut self.bulk_buf[temp_i..]) {
146            Ok(count) => {
147                if self.expect_bulk_out {
148                    self.expect_bulk_out = false;
149                } else {
150                    panic!("unexpectedly read data from bulk out endpoint");
151                }
152
153                self.i += count;
154
155                if count < self.ep_bulk_out.max_packet_size() as usize {
156                    self.len = self.i;
157                    self.i = 0;
158
159                    self.write_bulk_in(count == 0);
160                }
161            }
162            Err(UsbError::WouldBlock) => {}
163            Err(err) => panic!("bulk read {:?}", err),
164        };
165
166        match self.ep_interrupt_out.read(&mut self.interrupt_buf) {
167            Ok(count) => {
168                if self.expect_interrupt_out {
169                    self.expect_interrupt_out = false;
170                } else {
171                    panic!("unexpectedly read data from interrupt out endpoint");
172                }
173
174                self.ep_interrupt_in
175                    .write(&self.interrupt_buf[0..count])
176                    .expect("interrupt write");
177
178                self.expect_interrupt_in_complete = true;
179            }
180            Err(UsbError::WouldBlock) => {}
181            Err(err) => panic!("interrupt read {:?}", err),
182        };
183    }
184
185    fn write_bulk_in(&mut self, write_empty: bool) {
186        let to_write = cmp::min(
187            self.len - self.i,
188            self.ep_bulk_in.max_packet_size() as usize,
189        );
190
191        if to_write == 0 && !write_empty {
192            self.len = 0;
193            self.i = 0;
194
195            return;
196        }
197
198        match self
199            .ep_bulk_in
200            .write(&self.bulk_buf[self.i..self.i + to_write])
201        {
202            Ok(count) => {
203                assert_eq!(count, to_write);
204                self.expect_bulk_in_complete = true;
205                self.i += count;
206            }
207            Err(UsbError::WouldBlock) => {}
208            Err(err) => panic!("bulk write {:?}", err),
209        };
210    }
211}
212
213impl<B: UsbBus> UsbClass<B> for TestClass<'_, B> {
214    fn reset(&mut self) {
215        self.len = 0;
216        self.i = 0;
217        self.bench = false;
218        self.expect_bulk_in_complete = false;
219        self.expect_bulk_out = false;
220        self.expect_interrupt_in_complete = false;
221        self.expect_interrupt_out = false;
222    }
223
224    fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<()> {
225        writer.interface(self.iface, 0xff, 0x00, 0x00)?;
226        writer.endpoint(&self.ep_bulk_in)?;
227        writer.endpoint(&self.ep_bulk_out)?;
228        writer.endpoint(&self.ep_interrupt_in)?;
229        writer.endpoint(&self.ep_interrupt_out)?;
230        writer.interface_alt(self.iface, 1, 0xff, 0x01, 0x00, Some(self.interface_string))?;
231        writer.endpoint(&self.ep_iso_in)?;
232        Ok(())
233    }
234
235    fn get_string(&self, index: StringIndex, lang_id: LangID) -> Option<&str> {
236        if lang_id == LangID::EN_US {
237            if index == self.custom_string {
238                return Some(CUSTOM_STRING);
239            } else if index == self.interface_string {
240                return Some(INTERFACE_STRING);
241            }
242        }
243
244        None
245    }
246
247    fn endpoint_in_complete(&mut self, addr: EndpointAddress) {
248        if self.bench {
249            return;
250        }
251
252        if addr == self.ep_bulk_in.address() {
253            if self.expect_bulk_in_complete {
254                self.expect_bulk_in_complete = false;
255
256                self.write_bulk_in(false);
257            } else {
258                panic!("unexpected endpoint_in_complete");
259            }
260        } else if addr == self.ep_interrupt_in.address() {
261            if self.expect_interrupt_in_complete {
262                self.expect_interrupt_in_complete = false;
263            } else {
264                panic!("unexpected endpoint_in_complete");
265            }
266        }
267    }
268
269    fn endpoint_out(&mut self, addr: EndpointAddress) {
270        if addr == self.ep_bulk_out.address() {
271            self.expect_bulk_out = true;
272        } else if addr == self.ep_interrupt_out.address() {
273            self.expect_interrupt_out = true;
274        }
275    }
276
277    fn control_in(&mut self, xfer: ControlIn<B>) {
278        let req = *xfer.request();
279
280        if !(req.request_type == control::RequestType::Vendor
281            && req.recipient == control::Recipient::Device)
282        {
283            return;
284        }
285
286        match req.request {
287            REQ_READ_BUFFER if req.length as usize <= self.control_buf.len() => xfer
288                .accept_with(&self.control_buf[0..req.length as usize])
289                .expect("control_in REQ_READ_BUFFER failed"),
290            REQ_READ_LONG_DATA => xfer
291                .accept_with_static(LONG_DATA)
292                .expect("control_in REQ_READ_LONG_DATA failed"),
293            _ => xfer.reject().expect("control_in reject failed"),
294        }
295    }
296
297    fn control_out(&mut self, xfer: ControlOut<B>) {
298        let req = *xfer.request();
299
300        if !(req.request_type == control::RequestType::Vendor
301            && req.recipient == control::Recipient::Device)
302        {
303            return;
304        }
305
306        match req.request {
307            REQ_STORE_REQUEST => {
308                self.control_buf[0] =
309                    (req.direction as u8) | (req.request_type as u8) << 5 | (req.recipient as u8);
310                self.control_buf[1] = req.request;
311                self.control_buf[2..4].copy_from_slice(&req.value.to_le_bytes());
312                self.control_buf[4..6].copy_from_slice(&req.index.to_le_bytes());
313                self.control_buf[6..8].copy_from_slice(&req.length.to_le_bytes());
314
315                xfer.accept().expect("control_out REQ_STORE_REQUEST failed");
316            }
317            REQ_WRITE_BUFFER if xfer.data().len() <= self.control_buf.len() => {
318                assert_eq!(
319                    xfer.data().len(),
320                    req.length as usize,
321                    "xfer data len == req.length"
322                );
323
324                self.control_buf[0..xfer.data().len()].copy_from_slice(xfer.data());
325
326                xfer.accept().expect("control_out REQ_WRITE_BUFFER failed");
327            }
328            REQ_SET_BENCH_ENABLED => {
329                self.bench = req.value != 0;
330
331                xfer.accept()
332                    .expect("control_out REQ_SET_BENCH_ENABLED failed");
333            }
334            _ => xfer.reject().expect("control_out reject failed"),
335        }
336    }
337}