// qir_stdlib/output_recording.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
3
4use std::{
5    ffi::{c_char, c_double, CStr, CString},
6    fmt::Display,
7    io::{Read, Write},
8};
9
10use crate::strings::double_to_string;
11
/// Line terminator written after each recorded line: CRLF on Windows, LF on
/// every other platform.
pub const LINE_ENDING: &[u8] = if cfg!(windows) { b"\r\n" } else { b"\n" };
16
/// Sink for the text produced by the QIR output-recording and message functions.
///
/// Depending on `use_std_out`, bytes are either forwarded to stdout or kept in
/// an in-memory buffer until they are drained.
pub struct OutputRecorder {
    /// Bytes captured while stdout forwarding is switched off.
    buffer: Vec<u8>,
    /// When `true`, writes go straight to stdout instead of into `buffer`.
    use_std_out: bool,
}
23
24impl Write for OutputRecorder {
25    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
26        if self.use_std_out {
27            std::io::stdout().write(buf)
28        } else {
29            self.buffer.write(buf)
30        }
31    }
32
33    fn flush(&mut self) -> std::io::Result<()> {
34        if self.use_std_out {
35            std::io::stdout().flush()
36        } else {
37            self.buffer.flush()
38        }
39    }
40}
41
42impl Read for OutputRecorder {
43    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
44        self.buffer.as_slice().read(buf)
45    }
46}
47
48impl Default for OutputRecorder {
49    fn default() -> Self {
50        OutputRecorder {
51            buffer: Vec::new(),
52            use_std_out: true,
53        }
54    }
55}
56
57impl OutputRecorder {
58    /// Sets whether the output should be written to stdout
59    /// or stored in the buffer.
60    pub fn use_std_out(&mut self, use_std_out: bool) {
61        self.use_std_out = use_std_out;
62    }
63
64    /// Writes the newline char(s) to the output.
65    pub fn write_newline(&mut self) {
66        self.write_all(LINE_ENDING).expect("Failed to write output");
67    }
68
69    /// Drains the buffer and returns the contents.
70    pub fn drain(&mut self) -> std::vec::Drain<u8> {
71        self.buffer.drain(..)
72    }
73}
74
75thread_local! {
76    pub static OUTPUT: std::cell::RefCell<Box<OutputRecorder>> = std::cell::RefCell::new(Box::default());
77}
78
79/// Records a string to the output.
80/// # Errors
81/// Returns an error if the write fails.
82pub fn record_output_str(val: &str) -> std::io::Result<()> {
83    OUTPUT.with(|output| {
84        let mut output = output.borrow_mut();
85        output
86            .write_all(val.as_bytes())
87            .expect("Failed to write output");
88        output.write_newline();
89    });
90    Ok(())
91}
92
93/// Records a value to the output.
94/// # Errors
95/// Returns an error if the write fails.
96pub unsafe fn record_output(ty: &str, val: &dyn Display, tag: *mut c_char) -> std::io::Result<()> {
97    OUTPUT.with(|output| {
98        let mut output = output.borrow_mut();
99        output
100            .write_fmt(format_args!("OUTPUT\t{ty}\t{val}"))
101            .expect("Failed to write output");
102        if !tag.is_null() {
103            output.write_all(b"\t").expect("Failed to write output");
104            output
105                .write_all(CStr::from_ptr(tag).to_bytes())
106                .expect("Failed to write output");
107        }
108        output.write_newline();
109    });
110    Ok(())
111}
112
113/// Inserts a marker in the generated output that indicates the
114/// start of an array and how many array elements it has. The second
115/// parameter defines a string label for the array. Depending on
116/// the output schema, the label is included in the output or omitted.
117#[no_mangle]
118pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64, tag: *mut c_char) {
119    record_output("ARRAY", &val, tag).expect("Failed to write array output");
120}
121
122/// Inserts a marker in the generated output that indicates the
123/// start of a tuple and how many tuple elements it has. The second
124/// parameter defines a string label for the tuple. Depending on
125/// the output schema, the label is included in the output or omitted.
126#[no_mangle]
127pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64, tag: *mut c_char) {
128    record_output("TUPLE", &val, tag).expect("Failed to write tuple output");
129}
130
131#[no_mangle]
132pub unsafe extern "C" fn __quantum__rt__int_record_output(val: i64, tag: *mut c_char) {
133    record_output("INT", &val, tag).expect("Failed to write int output");
134}
135
136#[no_mangle]
137pub unsafe extern "C" fn __quantum__rt__double_record_output(val: c_double, tag: *mut c_char) {
138    record_output("DOUBLE", &double_to_string(val), tag).expect("Failed to write double output");
139}
140
141#[no_mangle]
142pub unsafe extern "C" fn __quantum__rt__bool_record_output(val: bool, tag: *mut c_char) {
143    record_output("BOOL", &val, tag).expect("Failed to write bool output");
144}
145
146#[no_mangle]
147pub unsafe extern "C" fn __quantum__rt__message_record_output(str: *const CString) {
148    record_output_str(&format!(
149        "INFO\t{}",
150        (*str)
151            .to_str()
152            .expect("Unable to convert input string")
153            .escape_default()
154    ))
155    .expect("Failed to write message output");
156}
157
158pub mod legacy {
159    use std::{ffi::c_double, ptr::null_mut};
160
161    use crate::strings::double_to_string;
162
163    use super::record_output_str;
164
165    #[allow(non_snake_case)]
166    pub extern "C" fn __quantum__rt__array_start_record_output() {
167        record_output_str("RESULT\tARRAY_START").expect("Failed to write array start output");
168    }
169
170    #[allow(non_snake_case)]
171    pub extern "C" fn __quantum__rt__array_end_record_output() {
172        record_output_str("RESULT\tARRAY_END").expect("Failed to write array end output");
173    }
174
175    #[allow(non_snake_case)]
176    pub extern "C" fn __quantum__rt__tuple_start_record_output() {
177        record_output_str("RESULT\tTUPLE_START").expect("Failed to write tuple start output");
178    }
179
180    #[allow(non_snake_case)]
181    pub extern "C" fn __quantum__rt__tuple_end_record_output() {
182        record_output_str("RESULT\tTUPLE_END").expect("Failed to write tuple end output");
183    }
184
185    #[allow(non_snake_case)]
186    pub extern "C" fn __quantum__rt__int_record_output(val: i64) {
187        record_output_str(&format!("RESULT\t{val}")).expect("Failed to write int output");
188    }
189
190    #[allow(non_snake_case)]
191    pub extern "C" fn __quantum__rt__double_record_output(val: c_double) {
192        record_output_str(&format!("RESULT\t{}", double_to_string(val)))
193            .expect("Failed to write double output");
194    }
195
196    #[allow(non_snake_case)]
197    pub extern "C" fn __quantum__rt__bool_record_output(val: bool) {
198        record_output_str(&format!("RESULT\t{val}")).expect("Failed to write bool output");
199    }
200
201    #[allow(non_snake_case)]
202    pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64) {
203        super::__quantum__rt__array_record_output(val, null_mut());
204    }
205
206    #[allow(non_snake_case)]
207    pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64) {
208        super::__quantum__rt__tuple_record_output(val, null_mut());
209    }
210}
211
#[cfg(test)]
mod tests {
    use std::ptr::null_mut;

    use super::*;

    #[test]
    fn test_output_int_untagged() {
        assert_untagged_output_match("INT", &42_i64, "OUTPUT\tINT\t42");
    }

    #[test]
    fn test_output_double_untagged() {
        let rendered = double_to_string(42.4533);
        assert_untagged_output_match("DOUBLE", &rendered, "OUTPUT\tDOUBLE\t42.4533");
    }

    #[test]
    fn test_output_double_whole_untagged() {
        // The literal is within half an ulp of 42, so as an f64 it is exactly 42.0.
        let rendered = double_to_string(42.000_000_000_000_001);
        assert_untagged_output_match("DOUBLE", &rendered, "OUTPUT\tDOUBLE\t42.0");
    }

    #[test]
    fn test_output_bool_true_untagged() {
        assert_untagged_output_match("BOOL", &true, "OUTPUT\tBOOL\ttrue");
    }

    #[test]
    fn test_output_bool_false_untagged() {
        assert_untagged_output_match("BOOL", &false, "OUTPUT\tBOOL\tfalse");
    }

    #[test]
    fn test_output_tuple_untagged() {
        assert_untagged_output_match("TUPLE", &42_i64, "OUTPUT\tTUPLE\t42");
    }

    #[test]
    fn test_output_array_untagged() {
        assert_untagged_output_match("ARRAY", &42_i64, "OUTPUT\tARRAY\t42");
    }

    #[test]
    fn test_output_bool_true_tagged_from_cstring() {
        let raw = CString::new("YEEHAW").unwrap().into_raw();
        assert_output_match("BOOL", &true, raw, "OUTPUT\tBOOL\ttrue\tYEEHAW");
        // Reclaim ownership so the allocation is freed.
        drop(unsafe { CString::from_raw(raw) });
    }

    #[test]
    fn test_output_bool_true_tagged_not_from_cstring() {
        // The tag lives in a stack buffer, not a CString allocation. If recording
        // wrongly took ownership via `CString::from_raw`, this would likely crash.
        let mut tag: [c_char; 3] = [0x68, 0x69, 0];
        assert_output_match("BOOL", &true, tag.as_mut_ptr(), "OUTPUT\tBOOL\ttrue\thi");
    }

    fn assert_untagged_output_match(ty: &str, val: &dyn Display, expected_str: &str) {
        assert_output_match(ty, val, null_mut(), expected_str);
    }

    fn assert_output_match(ty: &str, val: &dyn Display, tag: *mut c_char, expected_str: &str) {
        // Capture into the buffer instead of stdout while the record is written.
        OUTPUT.with(|cell| cell.borrow_mut().use_std_out(false));
        // SAFETY: `tag` is null or points at a NUL-terminated buffer owned by the caller.
        unsafe { record_output(ty, val, tag) }.expect("Failed to write output");
        let captured: Vec<u8> = OUTPUT.with(|cell| cell.borrow_mut().drain().collect());
        OUTPUT.with(|cell| cell.borrow_mut().use_std_out(true));
        assert_eq!(bytes_to_string(&captured), with_line_ending(expected_str));
    }

    fn with_line_ending(value: &str) -> String {
        format!("{value}{}", bytes_to_string(LINE_ENDING))
    }

    fn bytes_to_string(out: &[u8]) -> String {
        String::from_utf8(out.to_vec()).unwrap()
    }
}