qir_stdlib/
output_recording.rs1use std::{
5 ffi::{c_char, c_double, CStr, CString},
6 fmt::Display,
7 io::{Read, Write},
8};
9
10use crate::strings::double_to_string;
11
/// Line terminator written after every recorded output record (Windows).
#[cfg(windows)]
pub const LINE_ENDING: &[u8] = b"\r\n";
/// Line terminator written after every recorded output record (non-Windows).
#[cfg(not(windows))]
pub const LINE_ENDING: &[u8] = b"\n";

/// Sink for recorded program output.
///
/// Writes made through the [`Write`] impl go to the process's standard output
/// when `use_std_out` is `true` (the default), or are appended to an in-memory
/// buffer when it is `false`. Buffered bytes can be consumed with the [`Read`]
/// impl or with [`OutputRecorder::drain`].
#[derive(Debug)]
pub struct OutputRecorder {
    /// In-memory sink used while `use_std_out` is `false`.
    buffer: Vec<u8>,
    /// When `true`, writes are forwarded to stdout instead of `buffer`.
    use_std_out: bool,
}

impl Write for OutputRecorder {
    /// Forwards `buf` to stdout, or appends it to the in-memory buffer, depending on the mode.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if self.use_std_out {
            std::io::stdout().write(buf)
        } else {
            self.buffer.write(buf)
        }
    }

    /// Flushes stdout, or the in-memory buffer, depending on the mode.
    fn flush(&mut self) -> std::io::Result<()> {
        if self.use_std_out {
            std::io::stdout().flush()
        } else {
            self.buffer.flush()
        }
    }
}

impl Read for OutputRecorder {
    /// Copies buffered bytes into `buf` and removes them from the buffer.
    ///
    /// Always reads from the in-memory buffer, regardless of the current write mode.
    /// Returns `Ok(0)` once the buffer is empty.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut pending: &[u8] = &self.buffer;
        let count = pending.read(buf)?;
        // Consume what was handed out. Without this, every call would return the same
        // prefix and `read_to_end` would never observe EOF.
        self.buffer.drain(..count);
        Ok(count)
    }
}

impl Default for OutputRecorder {
    /// Creates an empty recorder that writes to stdout.
    fn default() -> Self {
        OutputRecorder {
            buffer: Vec::new(),
            use_std_out: true,
        }
    }
}

impl OutputRecorder {
    /// Selects where subsequent writes go: stdout when `true`, the in-memory buffer when `false`.
    pub fn use_std_out(&mut self, use_std_out: bool) {
        self.use_std_out = use_std_out;
    }

    /// Writes [`LINE_ENDING`] to the current destination.
    ///
    /// # Panics
    ///
    /// Panics if the underlying write fails.
    pub fn write_newline(&mut self) {
        self.write_all(LINE_ENDING).expect("Failed to write output");
    }

    /// Removes and yields every byte currently held in the in-memory buffer.
    pub fn drain(&mut self) -> std::vec::Drain<'_, u8> {
        self.buffer.drain(..)
    }
}
74
75thread_local! {
76 pub static OUTPUT: std::cell::RefCell<Box<OutputRecorder>> = std::cell::RefCell::new(Box::default());
77}
78
79pub fn record_output_str(val: &str) -> std::io::Result<()> {
83 OUTPUT.with(|output| {
84 let mut output = output.borrow_mut();
85 output
86 .write_all(val.as_bytes())
87 .expect("Failed to write output");
88 output.write_newline();
89 });
90 Ok(())
91}
92
93pub unsafe fn record_output(ty: &str, val: &dyn Display, tag: *mut c_char) -> std::io::Result<()> {
97 OUTPUT.with(|output| {
98 let mut output = output.borrow_mut();
99 output
100 .write_fmt(format_args!("OUTPUT\t{ty}\t{val}"))
101 .expect("Failed to write output");
102 if !tag.is_null() {
103 output.write_all(b"\t").expect("Failed to write output");
104 output
105 .write_all(CStr::from_ptr(tag).to_bytes())
106 .expect("Failed to write output");
107 }
108 output.write_newline();
109 });
110 Ok(())
111}
112
113#[no_mangle]
118pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64, tag: *mut c_char) {
119 record_output("ARRAY", &val, tag).expect("Failed to write array output");
120}
121
122#[no_mangle]
127pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64, tag: *mut c_char) {
128 record_output("TUPLE", &val, tag).expect("Failed to write tuple output");
129}
130
131#[no_mangle]
132pub unsafe extern "C" fn __quantum__rt__int_record_output(val: i64, tag: *mut c_char) {
133 record_output("INT", &val, tag).expect("Failed to write int output");
134}
135
136#[no_mangle]
137pub unsafe extern "C" fn __quantum__rt__double_record_output(val: c_double, tag: *mut c_char) {
138 record_output("DOUBLE", &double_to_string(val), tag).expect("Failed to write double output");
139}
140
141#[no_mangle]
142pub unsafe extern "C" fn __quantum__rt__bool_record_output(val: bool, tag: *mut c_char) {
143 record_output("BOOL", &val, tag).expect("Failed to write bool output");
144}
145
146#[no_mangle]
147pub unsafe extern "C" fn __quantum__rt__message_record_output(str: *const CString) {
148 record_output_str(&format!(
149 "INFO\t{}",
150 (*str)
151 .to_str()
152 .expect("Unable to convert input string")
153 .escape_default()
154 ))
155 .expect("Failed to write message output");
156}
157
158pub mod legacy {
159 use std::{ffi::c_double, ptr::null_mut};
160
161 use crate::strings::double_to_string;
162
163 use super::record_output_str;
164
165 #[allow(non_snake_case)]
166 pub extern "C" fn __quantum__rt__array_start_record_output() {
167 record_output_str("RESULT\tARRAY_START").expect("Failed to write array start output");
168 }
169
170 #[allow(non_snake_case)]
171 pub extern "C" fn __quantum__rt__array_end_record_output() {
172 record_output_str("RESULT\tARRAY_END").expect("Failed to write array end output");
173 }
174
175 #[allow(non_snake_case)]
176 pub extern "C" fn __quantum__rt__tuple_start_record_output() {
177 record_output_str("RESULT\tTUPLE_START").expect("Failed to write tuple start output");
178 }
179
180 #[allow(non_snake_case)]
181 pub extern "C" fn __quantum__rt__tuple_end_record_output() {
182 record_output_str("RESULT\tTUPLE_END").expect("Failed to write tuple end output");
183 }
184
185 #[allow(non_snake_case)]
186 pub extern "C" fn __quantum__rt__int_record_output(val: i64) {
187 record_output_str(&format!("RESULT\t{val}")).expect("Failed to write int output");
188 }
189
190 #[allow(non_snake_case)]
191 pub extern "C" fn __quantum__rt__double_record_output(val: c_double) {
192 record_output_str(&format!("RESULT\t{}", double_to_string(val)))
193 .expect("Failed to write double output");
194 }
195
196 #[allow(non_snake_case)]
197 pub extern "C" fn __quantum__rt__bool_record_output(val: bool) {
198 record_output_str(&format!("RESULT\t{val}")).expect("Failed to write bool output");
199 }
200
201 #[allow(non_snake_case)]
202 pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64) {
203 super::__quantum__rt__array_record_output(val, null_mut());
204 }
205
206 #[allow(non_snake_case)]
207 pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64) {
208 super::__quantum__rt__tuple_record_output(val, null_mut());
209 }
210}
211
#[cfg(test)]
mod tests {
    use std::ptr::null_mut;

    use super::*;

    #[test]
    fn test_output_int_untagged() {
        let val: i64 = 42;
        assert_untagged_output_match("INT", &val, "OUTPUT\tINT\t42");
    }
    #[test]
    fn test_output_double_untagged() {
        let val: f64 = 42.4533;
        let double_str = double_to_string(val);
        assert_untagged_output_match("DOUBLE", &double_str, "OUTPUT\tDOUBLE\t42.4533");
    }
    #[test]
    // The literal carries more precision than an f64 can hold and rounds to exactly 42.0,
    // so this exercises formatting of a whole-valued double.
    #[allow(clippy::excessive_precision)]
    fn test_output_double_whole_untagged() {
        let val: c_double = 42.000_000_000_000_001;
        let double_str = double_to_string(val);
        assert_untagged_output_match("DOUBLE", &double_str, "OUTPUT\tDOUBLE\t42.0");
    }
    #[test]
    fn test_output_bool_true_untagged() {
        let val: bool = true;
        assert_untagged_output_match("BOOL", &val, "OUTPUT\tBOOL\ttrue");
    }
    #[test]
    fn test_output_bool_false_untagged() {
        let val: bool = false;
        assert_untagged_output_match("BOOL", &val, "OUTPUT\tBOOL\tfalse");
    }
    #[test]
    fn test_output_tuple_untagged() {
        let val: i64 = 42;
        assert_untagged_output_match("TUPLE", &val, "OUTPUT\tTUPLE\t42");
    }
    #[test]
    fn test_output_array_untagged() {
        let val: i64 = 42;
        assert_untagged_output_match("ARRAY", &val, "OUTPUT\tARRAY\t42");
    }
    #[test]
    fn test_output_bool_true_tagged_from_cstring() {
        let val: bool = true;
        // Keep the CString owned so it is freed even if the assertion panics.
        let tag = CString::new("YEEHAW").unwrap();
        assert_output_match(
            "BOOL",
            &val,
            tag.as_ptr() as *mut c_char,
            "OUTPUT\tBOOL\ttrue\tYEEHAW",
        );
    }
    #[test]
    fn test_output_bool_true_tagged_not_from_cstring() {
        let val: bool = true;
        let mut tag: [c_char; 3] = [0x68, 0x69, 0];
        assert_output_match("BOOL", &val, tag.as_mut_ptr(), "OUTPUT\tBOOL\ttrue\thi");
    }
    #[test]
    fn test_record_output_str_appends_line_ending() {
        let actual = capture_output(|| {
            record_output_str("RESULT\tTUPLE_END").expect("Failed to write output");
        });
        assert_eq!(actual, get_string_with_line_ending("RESULT\tTUPLE_END"));
    }
    #[test]
    fn test_legacy_int_record_output() {
        let actual = capture_output(|| legacy::__quantum__rt__int_record_output(7));
        assert_eq!(actual, get_string_with_line_ending("RESULT\t7"));
    }
    #[test]
    fn test_legacy_array_record_output_is_untagged() {
        // SAFETY: the legacy entry point forwards only a null tag.
        let actual = capture_output(|| unsafe { legacy::__quantum__rt__array_record_output(3) });
        assert_eq!(actual, get_string_with_line_ending("OUTPUT\tARRAY\t3"));
    }
    #[test]
    fn test_message_record_output_escapes_text() {
        let message = CString::new("a\tb").unwrap();
        // SAFETY: `message` is a live CString for the whole call.
        let actual =
            capture_output(|| unsafe { __quantum__rt__message_record_output(&message) });
        assert_eq!(actual, get_string_with_line_ending("INFO\ta\\tb"));
    }

    /// Runs `record` with the thread-local recorder in buffer mode and returns what it wrote.
    /// Stdout mode is restored afterwards.
    fn capture_output(record: impl FnOnce()) -> String {
        OUTPUT.with(|output| output.borrow_mut().use_std_out(false));
        record();
        let actual = OUTPUT.with(|output| {
            let mut output = output.borrow_mut();
            let output = output.drain();
            get_byte_vec_as_string(output.as_slice())
        });
        OUTPUT.with(|output| output.borrow_mut().use_std_out(true));
        actual
    }
    fn assert_untagged_output_match(ty: &str, val: &dyn Display, expected_str: &str) {
        assert_output_match(ty, val, null_mut(), expected_str);
    }
    fn assert_output_match(ty: &str, val: &dyn Display, tag: *mut c_char, expected_str: &str) {
        let actual = capture_output(|| {
            // SAFETY: `tag` is either null or points to a NUL-terminated buffer that the
            // calling test keeps alive for the duration of this call.
            unsafe { record_output(ty, &val, tag) }.expect("Failed to write output");
        });
        assert_eq!(actual, get_string_with_line_ending(expected_str));
    }
    fn get_string_with_line_ending(value: &str) -> String {
        let ending = get_byte_vec_as_string(LINE_ENDING);
        value.to_owned() + ending.as_str()
    }
    fn get_byte_vec_as_string(out: &[u8]) -> String {
        let val = std::str::from_utf8(out).unwrap();
        val.to_string()
    }
}