qir_stdlib/
output_recording.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use std::{
5    ffi::{CStr, CString, c_char, c_double},
6    fmt::Display,
7    io::{Read, Write},
8};
9
10use crate::strings::double_to_string;
11
/// Line terminator appended after every recorded output line.
#[cfg(windows)]
pub const LINE_ENDING: &[u8] = b"\r\n";
/// Line terminator appended after every recorded output line.
#[cfg(not(windows))]
pub const LINE_ENDING: &[u8] = b"\n";

/// Holds output messages from calls to the QIR
/// output recording functions and message calls.
///
/// Writes are forwarded to the process's standard output (the default) or
/// accumulated in an in-memory buffer, as selected by
/// [`OutputRecorder::use_std_out`].
pub struct OutputRecorder {
    /// Bytes captured while `use_std_out` is `false`.
    buffer: Vec<u8>,
    /// When `true`, writes go to stdout instead of `buffer`.
    use_std_out: bool,
}

impl Write for OutputRecorder {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if self.use_std_out {
            std::io::stdout().write(buf)
        } else {
            self.buffer.write(buf)
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        if self.use_std_out {
            std::io::stdout().flush()
        } else {
            self.buffer.flush()
        }
    }
}

impl Read for OutputRecorder {
    /// Moves up to `buf.len()` bytes from the front of the internal buffer into
    /// `buf` and removes them from the buffer, returning how many were moved.
    ///
    /// Consuming the bytes is what lets repeated reads (and helpers such as
    /// `read_to_end`) make progress and eventually report EOF with `Ok(0)`.
    /// Reading through a temporary slice over the buffer would hand back the
    /// same bytes forever. Only the internal buffer is read; anything that was
    /// written to stdout is not available here.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let count = buf.len().min(self.buffer.len());
        buf[..count].copy_from_slice(&self.buffer[..count]);
        self.buffer.drain(..count);
        Ok(count)
    }
}

impl Default for OutputRecorder {
    /// Creates an empty recorder that writes to stdout.
    fn default() -> Self {
        OutputRecorder {
            buffer: Vec::new(),
            use_std_out: true,
        }
    }
}

impl OutputRecorder {
    /// Sets whether subsequent output should be written to stdout (`true`)
    /// or stored in the internal buffer (`false`). Bytes already buffered are
    /// left in place.
    pub fn use_std_out(&mut self, use_std_out: bool) {
        self.use_std_out = use_std_out;
    }

    /// Writes the platform newline char(s) ([`LINE_ENDING`]) to the output.
    ///
    /// # Panics
    /// Panics if the write fails.
    pub fn write_newline(&mut self) {
        self.write_all(LINE_ENDING).expect("Failed to write output");
    }

    /// Drains the internal buffer and returns its contents. Output that was
    /// sent to stdout is not included.
    pub fn drain(&mut self) -> std::vec::Drain<'_, u8> {
        self.buffer.drain(..)
    }
}
74
75thread_local! {
76    pub static OUTPUT: std::cell::RefCell<Box<OutputRecorder>> = std::cell::RefCell::new(Box::default());
77}
78
79/// Records a string to the output.
80/// # Errors
81/// Returns an error if the write fails.
82pub fn record_output_str(val: &str) -> std::io::Result<()> {
83    OUTPUT.with(|output| {
84        let mut output = output.borrow_mut();
85        output
86            .write_all(val.as_bytes())
87            .expect("Failed to write output");
88        output.write_newline();
89    });
90    Ok(())
91}
92
93/// Records a value to the output.
94/// # Errors
95/// Returns an error if the write fails.
96pub unsafe fn record_output(ty: &str, val: &dyn Display, tag: *mut c_char) -> std::io::Result<()> {
97    unsafe {
98        OUTPUT.with(|output| {
99            let mut output = output.borrow_mut();
100            output
101                .write_fmt(format_args!("OUTPUT\t{ty}\t{val}"))
102                .expect("Failed to write output");
103            if !tag.is_null() {
104                output.write_all(b"\t").expect("Failed to write output");
105                output
106                    .write_all(CStr::from_ptr(tag).to_bytes())
107                    .expect("Failed to write output");
108            }
109            output.write_newline();
110        });
111        Ok(())
112    }
113}
114
115/// Inserts a marker in the generated output that indicates the
116/// start of an array and how many array elements it has. The second
117/// parameter defines a string label for the array. Depending on
118/// the output schema, the label is included in the output or omitted.
119#[unsafe(no_mangle)]
120pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64, tag: *mut c_char) {
121    unsafe {
122        record_output("ARRAY", &val, tag).expect("Failed to write array output");
123    }
124}
125
126/// Inserts a marker in the generated output that indicates the
127/// start of a tuple and how many tuple elements it has. The second
128/// parameter defines a string label for the tuple. Depending on
129/// the output schema, the label is included in the output or omitted.
130#[unsafe(no_mangle)]
131pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64, tag: *mut c_char) {
132    unsafe {
133        record_output("TUPLE", &val, tag).expect("Failed to write tuple output");
134    }
135}
136
137#[unsafe(no_mangle)]
138pub unsafe extern "C" fn __quantum__rt__int_record_output(val: i64, tag: *mut c_char) {
139    unsafe {
140        record_output("INT", &val, tag).expect("Failed to write int output");
141    }
142}
143
144#[unsafe(no_mangle)]
145pub unsafe extern "C" fn __quantum__rt__double_record_output(val: c_double, tag: *mut c_char) {
146    unsafe {
147        record_output("DOUBLE", &double_to_string(val), tag)
148            .expect("Failed to write double output");
149    }
150}
151
152#[unsafe(no_mangle)]
153pub unsafe extern "C" fn __quantum__rt__bool_record_output(val: bool, tag: *mut c_char) {
154    unsafe {
155        record_output("BOOL", &val, tag).expect("Failed to write bool output");
156    }
157}
158
159#[unsafe(no_mangle)]
160pub unsafe extern "C" fn __quantum__rt__message_record_output(str: *const CString) {
161    unsafe {
162        record_output_str(&format!(
163            "INFO\t{}",
164            (*str)
165                .to_str()
166                .expect("Unable to convert input string")
167                .escape_default()
168        ))
169        .expect("Failed to write message output");
170    }
171}
172
173pub mod legacy {
174    use std::{ffi::c_double, ptr::null_mut};
175
176    use crate::strings::double_to_string;
177
178    use super::record_output_str;
179
180    #[allow(non_snake_case)]
181    pub extern "C" fn __quantum__rt__array_start_record_output() {
182        record_output_str("RESULT\tARRAY_START").expect("Failed to write array start output");
183    }
184
185    #[allow(non_snake_case)]
186    pub extern "C" fn __quantum__rt__array_end_record_output() {
187        record_output_str("RESULT\tARRAY_END").expect("Failed to write array end output");
188    }
189
190    #[allow(non_snake_case)]
191    pub extern "C" fn __quantum__rt__tuple_start_record_output() {
192        record_output_str("RESULT\tTUPLE_START").expect("Failed to write tuple start output");
193    }
194
195    #[allow(non_snake_case)]
196    pub extern "C" fn __quantum__rt__tuple_end_record_output() {
197        record_output_str("RESULT\tTUPLE_END").expect("Failed to write tuple end output");
198    }
199
200    #[allow(non_snake_case)]
201    pub extern "C" fn __quantum__rt__int_record_output(val: i64) {
202        record_output_str(&format!("RESULT\t{val}")).expect("Failed to write int output");
203    }
204
205    #[allow(non_snake_case)]
206    pub extern "C" fn __quantum__rt__double_record_output(val: c_double) {
207        record_output_str(&format!("RESULT\t{}", double_to_string(val)))
208            .expect("Failed to write double output");
209    }
210
211    #[allow(non_snake_case)]
212    pub extern "C" fn __quantum__rt__bool_record_output(val: bool) {
213        record_output_str(&format!("RESULT\t{val}")).expect("Failed to write bool output");
214    }
215
216    #[allow(non_snake_case)]
217    pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64) {
218        unsafe {
219            super::__quantum__rt__array_record_output(val, null_mut());
220        }
221    }
222
223    #[allow(non_snake_case)]
224    pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64) {
225        unsafe {
226            super::__quantum__rt__tuple_record_output(val, null_mut());
227        }
228    }
229}
230
#[cfg(test)]
mod tests {
    use std::ptr::null_mut;

    use super::*;

    #[test]
    fn test_output_int_untagged() {
        expect_untagged("INT", &42_i64, "OUTPUT\tINT\t42");
    }

    #[test]
    fn test_output_double_untagged() {
        let rendered = double_to_string(42.4533);
        expect_untagged("DOUBLE", &rendered, "OUTPUT\tDOUBLE\t42.4533");
    }

    #[test]
    fn test_output_double_whole_untagged() {
        // This literal rounds to exactly 42.0 as an f64, so it must be
        // rendered as a whole number with a trailing `.0`.
        let rendered = double_to_string(42.000_000_000_000_001);
        expect_untagged("DOUBLE", &rendered, "OUTPUT\tDOUBLE\t42.0");
    }

    #[test]
    fn test_output_bool_true_untagged() {
        expect_untagged("BOOL", &true, "OUTPUT\tBOOL\ttrue");
    }

    #[test]
    fn test_output_bool_false_untagged() {
        expect_untagged("BOOL", &false, "OUTPUT\tBOOL\tfalse");
    }

    #[test]
    fn test_output_tuple_untagged() {
        expect_untagged("TUPLE", &42_i64, "OUTPUT\tTUPLE\t42");
    }

    #[test]
    fn test_output_array_untagged() {
        expect_untagged("ARRAY", &42_i64, "OUTPUT\tARRAY\t42");
    }

    #[test]
    fn test_output_bool_true_tagged_from_cstring() {
        let tag = CString::new("YEEHAW").unwrap().into_raw();
        expect_output("BOOL", &true, tag, "OUTPUT\tBOOL\ttrue\tYEEHAW");
        // Take back ownership of the string handed out by `into_raw` so it is freed.
        drop(unsafe { CString::from_raw(tag) });
    }

    #[test]
    fn test_output_bool_true_tagged_not_from_cstring() {
        // A NUL-terminated "hi" that lives on the stack. `record_output` must
        // only borrow the tag. If it ever took ownership (e.g. via
        // `CString::from_raw`), it would try to free stack memory, which is
        // likely (though not guaranteed) to crash this test.
        let mut tag: [c_char; 3] = [0x68, 0x69, 0];
        expect_output("BOOL", &true, tag.as_mut_ptr(), "OUTPUT\tBOOL\ttrue\thi");
    }

    /// Asserts the recorded line for a value recorded without a tag.
    fn expect_untagged(ty: &str, val: &dyn Display, expected_line: &str) {
        expect_output(ty, val, null_mut(), expected_line);
    }

    /// Records one value into this thread's buffer (not stdout) and asserts
    /// that exactly `expected_line` followed by [`LINE_ENDING`] was produced.
    fn expect_output(ty: &str, val: &dyn Display, tag: *mut c_char, expected_line: &str) {
        OUTPUT.with(|output| output.borrow_mut().use_std_out(false));
        let written = unsafe { record_output(ty, val, tag) };
        written.expect("Failed to write output");

        let captured: Vec<u8> = OUTPUT.with(|output| output.borrow_mut().drain().collect());
        OUTPUT.with(|output| output.borrow_mut().use_std_out(true));

        let actual = std::str::from_utf8(&captured).unwrap().to_owned();
        let ending = std::str::from_utf8(LINE_ENDING).unwrap();
        assert_eq!(actual, format!("{expected_line}{ending}"));
    }
}