1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
// https://www.nxp.com/docs/en/reference-manual/MCUBOOTRM.pdf
//
// - all fields in packets are little-endian
// - each command sent from host is replied to with response
// - optional data phase, either command or response (not both!)
//   RM uses "incoming" (host->MCU) and "outgoing" (host<-MCU) terminology
//
//
// 1) no data phase:
//   --> command
//   <-- generic response
//
//
// 2) command data phase:
//   --> command (has-data-phase flag set)
//   <-- initial generic response (must signal success status to proceed with data phase)
//   ==> initial command data packet
//   ⋮
//   ==> final command data packet
//   <-- final generic response (contains status for entire operation)
//
//  Device may abort data phase by sending final generic response early, with status abort
//
//
// 3) response data phase:
//   --> command
//   <-- initial non-generic response (must signal has-data to proceed with data phase)
//   <== initial response data packet
//    ⋮
//   <== final response data packet
//   <-- final generic response (contains status for entire operation)
//
//  Device may abort data phase early by sending zero-length packet
//  Host may abort data phase by sending generic response (?is this a thing?)

use super::Error as BootloaderError;
use crate::bootloader::{command, property};
use core::convert::{TryFrom, TryInto};

use hidapi::{HidDevice, HidResult};

/// The NXP bootloader protocol, spoken over a HID device.
///
/// Owns the underlying [`HidDevice`] and implements the command / response / data-phase
/// exchange on top of it. The main entry point is [`Protocol::call`], i.e.
/// `fn call(&Command) -> Result<Response>`.
pub struct Protocol {
    /// HID handle that all packets are written to and read from.
    device: HidDevice,
}

/// The NXP bootloader protocol error type.
///
/// These are failures of the packet exchange itself. The status code reported by the
/// bootloader inside a response is not represented here; it is carried in
/// `ResponsePacket::status`.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A response report with an empty payload was received (see `Protocol::read_packet`).
    #[error("receiver aborted data phase")]
    AbortDataPhase,
    /// A `ReceivedPacket` was converted to `Vec<u8>` but was not a data packet.
    #[error("expected data response packet")]
    ExpectedDataPacket,
    /// A `ReceivedPacket` was converted to `ResponsePacket` but was not a response packet.
    #[error("expected (non-data) response packet")]
    ExpectedResponsePacket,
    /// Error reported by the `hidapi` crate; converted automatically via `?`.
    #[error("error from underlying hidapi")]
    HidApi(#[from] hidapi::HidError),
    /// The first byte of a received report could not be converted to a `command::ReportId`.
    #[error("invalid HID report ID ({0})")]
    InvalidReportId(u8),
    /// The tag byte of a response could not be converted to a `command::ResponseTag`.
    #[error("unknown response tag ({0})")]
    UnknownResponseTag(u8),

    /// Catch-all for protocol failures that have no dedicated variant.
    #[error("unspecified protocol error")]
    Unspecified,
}

/// The NXP bootloader protocol result type.
///
/// The error side is the protocol-level [`Error`]. The bootloader's own status is kept
/// separate from it and travels in `ResponsePacket::status`.
pub type Result<T> = std::result::Result<T, Error>;

/// A parsed (non-data) response packet received from the device.
///
/// Built by `Protocol::read_packet` from a report carrying the response report ID.
#[derive(Debug)]
pub struct ResponsePacket {
    /// Response tag, parsed from the first payload byte.
    pub tag: command::ResponseTag,
    /// Whether bit 0 of the flags byte is set, i.e. the device announces a data phase.
    pub has_data: bool,
    /// Status reported by the device: `None` for status code 0, otherwise the decoded error.
    pub status: Option<BootloaderError>,
    // pub mirrored_command_header: [u8; 4],
    /// Response parameters that follow the status word (the status word itself is removed).
    pub parameters: Vec<u32>,
}

/// One packet read from the device, classified by its HID report ID.
///
/// Use the `TryFrom` conversions to `ResponsePacket` / `Vec<u8>` to insist on one kind.
pub enum ReceivedPacket {
    /// A parsed response packet (report ID `command::ReportId::Response`).
    Response(ResponsePacket),
    /// Raw payload bytes of a data-phase packet (report ID `command::ReportId::ResponseData`).
    Data(Vec<u8>),
}

impl TryFrom<ReceivedPacket> for ResponsePacket {
    type Error = Error;
    fn try_from(packet: ReceivedPacket) -> Result<Self> {
        if let ReceivedPacket::Response(packet) = packet {
            Ok(packet)
        } else {
            Err(Error::ExpectedResponsePacket)
        }
    }
}

impl TryFrom<ReceivedPacket> for Vec<u8> {
    type Error = Error;
    fn try_from(packet: ReceivedPacket) -> Result<Self> {
        if let ReceivedPacket::Data(data) = packet {
            Ok(data)
        } else {
            Err(Error::ExpectedDataPacket)
        }
    }
}

/// Timeout handed to `HidDevice::read_timeout` for every packet read by
/// `Protocol::read_packet`.
///
/// NOTE(review): hidapi documents this argument as milliseconds, which would make this
/// two seconds — confirm against the hidapi version in use.
pub const READ_TIMEOUT: i32 = 2000;

impl Protocol {
    /// Queries a single bootloader property and returns its raw `u32` values.
    ///
    /// Issues a `GetProperty` command through [`Protocol::call`].
    ///
    /// # Panics
    ///
    /// Panics if the call fails (the result is unwrapped with `expect`), and hits
    /// `todo!()` if the response is anything other than `Response::GetProperty`.
    /// As written, no code path in this function produces an `Err`.
    pub fn property(
        &self,
        property: property::Property,
    ) -> core::result::Result<Vec<u32>, crate::bootloader::Error> {
        let request = command::Command::GetProperty(property);
        match self.call(&request).expect("success") {
            command::Response::GetProperty(values) => Ok(values),
            _ => todo!(),
        }
    }

    /// Executes `command` without progress reporting.
    ///
    /// Thin wrapper around [`Protocol::call_progress`] with no progress callback.
    pub fn call(&self, command: &command::Command) -> Result<command::Response> {
        let no_progress: Option<&dyn Fn(usize)> = None;
        self.call_progress(command, no_progress)
    }

    /// Sends `command` to the device and drives the complete exchange, including any
    /// command or response data phase, until the final response has been read.
    ///
    /// `progress`, if given, is only used while uploading an SB file (`ReceiveSbFile`).
    /// It is invoked before each chunk is written, with the cumulative number of payload
    /// bytes up to and including that chunk, so the last reported value equals the data
    /// length. Other commands do not report progress.
    ///
    /// # Errors
    ///
    /// Returns an error if a HID transfer fails, if a received packet cannot be parsed, or
    /// if a packet of the wrong kind (data vs. response) arrives.
    ///
    /// # Panics
    ///
    /// Panics if the device reports a non-zero status, if a response does not carry the
    /// expected tag, flags, parameter count or mirrored command header, and (via `todo!()`)
    /// for commands whose handling is not implemented yet.
    pub fn call_progress<'a>(
        &self,
        command: &command::Command,
        progress: Option<&'a dyn Fn(usize)>,
    ) -> Result<command::Response> {
        // construct command packet
        let command_packet = command.hid_packet();

        // send command packet
        self.write(command_packet.as_slice())?;
        trace!("--> {}", hex_str!(&command_packet));

        let initial_response = self.read_packet()?;

        // parse initial response packet
        match (command.clone(), command.data_phase()) {
            // case 1: no data phases
            (command, command::DataPhase::None) => {
                // we expect a non-data packet, not signaling additional data packets, with
                // successful status, mirroring our command header
                let packet = ResponsePacket::try_from(initial_response)?;

                assert!(!packet.has_data);
                if let Some(status) = packet.status {
                    panic!("{:?}", status);
                }

                use command::Command::*;
                match command {
                    Reset
                    | EraseFlash {
                        address: _,
                        length: _,
                    }
                    | EraseFlashAll
                    | ConfigureMemory { .. }
                    | Keystore(command::KeystoreOperation::Enroll)
                    | Keystore(command::KeystoreOperation::GenerateKey { key: _, len: _ })
                    | Keystore(command::KeystoreOperation::WriteNonVolatile)
                    | Keystore(command::KeystoreOperation::ReadNonVolatile) => {
                        assert_eq!(packet.tag, command::ResponseTag::Generic);
                        // generic responses carry 2 parameters, status and mirrored command
                        // header; the status was already split off when parsing the packet,
                        // so exactly one parameter remains
                        assert_eq!(packet.parameters.len(), 1);
                        assert_eq!(
                            packet.parameters[0].to_le_bytes()[..2],
                            command.header()[..2]
                        );

                        Ok(command::Response::Generic)
                    }
                    GetProperty(_) => {
                        assert_eq!(packet.tag, command::ResponseTag::GetProperty);
                        assert!(!packet.parameters.is_empty());
                        Ok(command::Response::GetProperty(packet.parameters))
                    }
                    _ => todo!(),
                }
            }

            // case 2: command data phases
            (command, command::DataPhase::CommandData(data)) => {
                let packet = ResponsePacket::try_from(initial_response)?;

                // for SetKey, LHS is true, whereas for WriteMemory, it is not (unexpectedly?)
                // assert_eq!(packet.has_data, command.data_phase().has_command_data());
                assert!(packet.status.is_none());

                // match by reference: nothing needs to be moved out of the command, and
                // `command.header()` is still needed afterwards
                match &command {
                    command::Command::Keystore(command::KeystoreOperation::SetKey { .. })
                    | command::Command::WriteMemory { .. }
                    | command::Command::WriteMemoryWords { .. } => {
                        self.send_command_data(&data[..], None)?;

                        let packet = ResponsePacket::try_from(self.read_packet()?)?;
                        Self::check_final_generic_response(&packet, &command);

                        Ok(command::Response::Generic)
                    }
                    command::Command::ReceiveSbFile { data } => {
                        self.send_command_data(&data[..], progress)?;

                        // a "short" response signals an aborted data phase; the response
                        // carrying the actual status follows in the next packet
                        let final_packet = match self.read_packet() {
                            Err(Error::AbortDataPhase) => {
                                debug!("device aborted data phase, reading final response");
                                self.read_packet()?
                            }
                            other => other?,
                        };
                        let packet = ResponsePacket::try_from(final_packet)?;
                        Self::check_final_generic_response(&packet, &command);

                        Ok(command::Response::Generic)
                    }
                    _ => todo!(),
                }
            }

            // case 3: response data phases
            (command::Command::Keystore(command::KeystoreOperation::ReadKeystore), _) => {
                let _packet = ResponsePacket::try_from(initial_response)?;

                let data = self.read_response_data(3 * 512)?;

                let packet = ResponsePacket::try_from(self.read_packet()?)?;
                assert_eq!(packet.parameters[0].to_le_bytes()[0], command.header()[0]);

                debug!("read {} in total", data.len());
                Ok(command::Response::Data(data))
            }

            (command::Command::ReadMemory { address: _, length }, _) => {
                let packet = ResponsePacket::try_from(initial_response)?;
                // assert_eq!([0x03, 0x00, 0x0C, 0x00], &initial_generic_response[..4]);
                assert!(packet.has_data);
                assert!(packet.status.is_none());
                assert_eq!(packet.tag, command::ResponseTag::ReadMemory);

                // ReadMemory response: 2 parameters, status and then number of bytes to be
                // sent in data phase (status already split off, so one parameter remains)
                assert_eq!(packet.parameters.len(), 1);
                assert_eq!(packet.parameters[0] as usize, length);

                let data = self.read_response_data(length)?;

                let packet = ResponsePacket::try_from(self.read_packet()?)?;
                assert!(!packet.has_data);
                assert!(packet.status.is_none());

                assert_eq!(packet.tag, command::ResponseTag::Generic);
                // generic response: only the mirrored command header remains after the status
                assert_eq!(packet.parameters.len(), 1);
                // it seems the device "forgets" about the parameters the original command
                // contained (address + length)
                // ooorrr, Table 4-11 ("The Command tag parameter identifies the response to the command sent by the host.")
                // just means that the command tag is set
                assert_eq!(
                    packet.parameters[0].to_le_bytes()[..2],
                    command.header()[..2]
                );

                Ok(command::Response::ReadMemory(data))
            }
            _ => todo!(),
        }
    }

    /// Streams `data` to the device as command-data reports of at most 32 payload bytes,
    /// each zero-padded to a fixed 4 + 32 byte report.
    ///
    /// `progress`, if given, is called before each write with the cumulative payload size.
    fn send_command_data(&self, data: &[u8], progress: Option<&dyn Fn(usize)>) -> Result<()> {
        // todo: can we use bigger chunks?
        const CHUNK_SIZE: usize = 32;

        // TODO: somewhere in here, should "peek" a read to see if device sent an abort
        // (i.e. a generic response). Presumably the device just ignores our data if it is
        // unhappy, so we'd find out after the fact. Although maybe sending might block?
        let mut position: usize = 0;
        for chunk in data.chunks(CHUNK_SIZE) {
            // advance by the real chunk size so a short final chunk does not overshoot
            position += chunk.len();
            if let Some(progress) = progress {
                progress(position);
            }

            let mut data_packet = vec![
                command::ReportId::CommandData as u8,
                0,
                // cannot truncate: chunks are at most CHUNK_SIZE (32) bytes long
                chunk.len() as u8,
                0,
            ];
            data_packet.extend_from_slice(chunk);
            data_packet.resize(4 + CHUNK_SIZE, 0);
            trace!("--> {}", hex_str!(&data_packet, 4));
            self.write(data_packet.as_slice())?;
        }
        Ok(())
    }

    /// Collects response-data packets until `length` bytes have been received.
    ///
    /// Panics if the device sends more than `length` bytes in total.
    fn read_response_data(&self, length: usize) -> Result<Vec<u8>> {
        let mut data = Vec::new();
        while data.len() < length {
            let partial_data: Vec<u8> = self.read_packet()?.try_into()?;
            assert!(data.len() + partial_data.len() <= length);
            data.extend_from_slice(&partial_data);
        }
        Ok(data)
    }

    /// Checks the final generic response that concludes a command data phase.
    ///
    /// Panics if the response announces more data, reports a non-zero status, is not a
    /// generic response, or does not mirror the tag byte of `command`.
    fn check_final_generic_response(packet: &ResponsePacket, command: &command::Command) {
        assert!(!packet.has_data);
        if let Some(status) = &packet.status {
            panic!("unexpected status {:?}", status);
        }

        assert_eq!(packet.tag, command::ResponseTag::Generic);
        // only the mirrored command header remains once the status has been split off
        assert_eq!(packet.parameters.len(), 1);
        // it seems the device "forgets" about the parameters the original command
        // contained (address + length)
        // ooorrr, Table 4-11 ("The Command tag parameter identifies the response to the command sent by the host.")
        // just means that the command tag is set
        //
        // UPDATE: doesn't even reflect the second byte (has-data flag)
        // e.g.: we send: 15010003, we get back: 15000000
        assert_eq!(packet.parameters[0].to_le_bytes()[0], command.header()[0]);
    }

    pub fn read_packet(&self) -> Result<ReceivedPacket> {
        // read data with timeout
        let mut data = Vec::new();
        data.resize(256, 0);
        let read = self.device.read_timeout(&mut data, READ_TIMEOUT)?;
        data.resize(read, 0);

        let report_id = command::ReportId::try_from(data[0]).map_err(Error::InvalidReportId)?;

        // the device often sends "extra junk"; we split this off early
        let expected_packet_len = u16::from_le_bytes(data[2..4].try_into().unwrap()) as usize;
        data.resize(4 + expected_packet_len, 0);
        trace!("--> {} ({}B)", hex_str!(&data, 4), data.len());

        let response_packet = data.split_off(4);

        // now handle the response packet
        Ok(match report_id {
            command::ReportId::Response => {
                // NB: this can be  "short" answer (just `03 00 00 00`), which means an
                // "AbortDataPhase".
                // In this case, need to pull naother response to get the error.
                if response_packet.is_empty() {
                    return Err(Error::AbortDataPhase);
                }
                let tag = command::ResponseTag::try_from(response_packet[0])
                    .map_err(Error::UnknownResponseTag)?;
                let has_data = (response_packet[1] & 1) != 0;
                let expected_param_count = response_packet[3] as usize;

                let mut parameters: Vec<u32> = response_packet[4..]
                    .chunks(4)
                    .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
                    .collect();
                assert_eq!(expected_param_count, parameters.len());

                // first parameter is always status
                let status_code = parameters.remove(0);
                let status = match status_code {
                    0 => None,
                    code => Some(BootloaderError::from(code)),
                };

                // NB: this is only true for Generic responses
                // // second parameter is always mirrored command header
                // let mirrored_command_header = parameters.remove(0).to_le_bytes();

                ReceivedPacket::Response(ResponsePacket {
                    tag,
                    has_data,
                    status,
                    // mirrored_command_header,
                    parameters,
                })
            }
            command::ReportId::ResponseData => ReceivedPacket::Data(response_packet),
            _ => todo!(),
        })
    }

    pub fn write(&self, data: &[u8]) -> Result<()> {
        let sent = self.device.write(data)?;
        let all = data.len();
        if sent >= all {
            Ok(())
        } else {
            Err(hidapi::HidError::IncompleteSendError { sent, all }.into())
        }
    }

    pub fn read_timeout(&self, timeout: usize) -> HidResult<Vec<u8>> {
        let mut data = Vec::new();
        data.resize(256, 0);
        let read = self.device.read_timeout(&mut data, timeout as i32)?;
        data.resize(read, 0);
        Ok(data)
    }
}

impl Protocol {
    pub fn new(device: HidDevice) -> Self {
        Self { device }
    }
}

/// Debug output lists the HID device's descriptor strings under the struct name `Device`.
impl std::fmt::Debug for Protocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // each lookup result is printed as-is, so a failed lookup shows up in the output
        let manufacturer = self.device.get_manufacturer_string();
        let product = self.device.get_product_string();
        let serial_number = self.device.get_serial_number_string();

        let mut builder = f.debug_struct("Device");
        builder.field("manufacturer", &manufacturer);
        builder.field("product", &product);
        builder.field("serial number", &serial_number);
        builder.finish()
    }
}