qir_stdlib/
output_recording.rs1use std::{
5 ffi::{CStr, CString, c_char, c_double},
6 fmt::Display,
7 io::{Read, Write},
8};
9
10use crate::strings::double_to_string;
11
/// Line terminator appended to every output record (`\r\n` on Windows, `\n` elsewhere).
#[cfg(windows)]
pub const LINE_ENDING: &[u8] = b"\r\n";
/// Line terminator appended to every output record (`\r\n` on Windows, `\n` elsewhere).
#[cfg(not(windows))]
pub const LINE_ENDING: &[u8] = b"\n";

/// Sink for recorded program output.
///
/// Writes go either straight to the process's standard output (the default) or into an
/// in-memory buffer, depending on the flag set with [`OutputRecorder::use_std_out`].
/// Buffered bytes can be consumed with [`OutputRecorder::drain`] or through the [`Read`]
/// implementation.
pub struct OutputRecorder {
    /// Bytes captured while `use_std_out` is `false`.
    buffer: Vec<u8>,
    /// When `true`, writes bypass `buffer` and go to stdout.
    use_std_out: bool,
}

impl Write for OutputRecorder {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if self.use_std_out {
            std::io::stdout().write(buf)
        } else {
            self.buffer.write(buf)
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        if self.use_std_out {
            std::io::stdout().flush()
        } else {
            self.buffer.flush()
        }
    }
}

impl Read for OutputRecorder {
    /// Moves up to `buf.len()` bytes from the front of the in-memory buffer into `buf`.
    ///
    /// Returned bytes are removed from the buffer, so successive reads make progress and
    /// `Ok(0)` signals that the buffer is empty. Only buffered output is readable; anything
    /// written while `use_std_out` was `true` went to stdout and is not available here.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let count = self.buffer.len().min(buf.len());
        buf[..count].copy_from_slice(&self.buffer[..count]);
        // Consume what was just handed out; otherwise every call would yield the same
        // prefix and `read_to_end` would never observe EOF.
        self.buffer.drain(..count);
        Ok(count)
    }
}

impl Default for OutputRecorder {
    /// Creates an empty recorder that writes to stdout.
    fn default() -> Self {
        OutputRecorder {
            buffer: Vec::new(),
            use_std_out: true,
        }
    }
}

impl OutputRecorder {
    /// Selects the sink for subsequent writes: stdout when `true`, the in-memory buffer when
    /// `false`. Bytes already in the buffer are left untouched.
    pub fn use_std_out(&mut self, use_std_out: bool) {
        self.use_std_out = use_std_out;
    }

    /// Writes [`LINE_ENDING`] to the current sink.
    ///
    /// # Panics
    ///
    /// Panics if the underlying write fails.
    pub fn write_newline(&mut self) {
        self.write_all(LINE_ENDING).expect("Failed to write output");
    }

    /// Removes and yields every byte currently held in the in-memory buffer.
    pub fn drain(&mut self) -> std::vec::Drain<'_, u8> {
        self.buffer.drain(..)
    }
}
74
75thread_local! {
76 pub static OUTPUT: std::cell::RefCell<Box<OutputRecorder>> = std::cell::RefCell::new(Box::default());
77}
78
79pub fn record_output_str(val: &str) -> std::io::Result<()> {
83 OUTPUT.with(|output| {
84 let mut output = output.borrow_mut();
85 output
86 .write_all(val.as_bytes())
87 .expect("Failed to write output");
88 output.write_newline();
89 });
90 Ok(())
91}
92
93pub unsafe fn record_output(ty: &str, val: &dyn Display, tag: *mut c_char) -> std::io::Result<()> {
97 unsafe {
98 OUTPUT.with(|output| {
99 let mut output = output.borrow_mut();
100 output
101 .write_fmt(format_args!("OUTPUT\t{ty}\t{val}"))
102 .expect("Failed to write output");
103 if !tag.is_null() {
104 output.write_all(b"\t").expect("Failed to write output");
105 output
106 .write_all(CStr::from_ptr(tag).to_bytes())
107 .expect("Failed to write output");
108 }
109 output.write_newline();
110 });
111 Ok(())
112 }
113}
114
115#[unsafe(no_mangle)]
120pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64, tag: *mut c_char) {
121 unsafe {
122 record_output("ARRAY", &val, tag).expect("Failed to write array output");
123 }
124}
125
126#[unsafe(no_mangle)]
131pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64, tag: *mut c_char) {
132 unsafe {
133 record_output("TUPLE", &val, tag).expect("Failed to write tuple output");
134 }
135}
136
137#[unsafe(no_mangle)]
138pub unsafe extern "C" fn __quantum__rt__int_record_output(val: i64, tag: *mut c_char) {
139 unsafe {
140 record_output("INT", &val, tag).expect("Failed to write int output");
141 }
142}
143
144#[unsafe(no_mangle)]
145pub unsafe extern "C" fn __quantum__rt__double_record_output(val: c_double, tag: *mut c_char) {
146 unsafe {
147 record_output("DOUBLE", &double_to_string(val), tag)
148 .expect("Failed to write double output");
149 }
150}
151
152#[unsafe(no_mangle)]
153pub unsafe extern "C" fn __quantum__rt__bool_record_output(val: bool, tag: *mut c_char) {
154 unsafe {
155 record_output("BOOL", &val, tag).expect("Failed to write bool output");
156 }
157}
158
159#[unsafe(no_mangle)]
160pub unsafe extern "C" fn __quantum__rt__message_record_output(str: *const CString) {
161 unsafe {
162 record_output_str(&format!(
163 "INFO\t{}",
164 (*str)
165 .to_str()
166 .expect("Unable to convert input string")
167 .escape_default()
168 ))
169 .expect("Failed to write message output");
170 }
171}
172
173pub mod legacy {
174 use std::{ffi::c_double, ptr::null_mut};
175
176 use crate::strings::double_to_string;
177
178 use super::record_output_str;
179
180 #[allow(non_snake_case)]
181 pub extern "C" fn __quantum__rt__array_start_record_output() {
182 record_output_str("RESULT\tARRAY_START").expect("Failed to write array start output");
183 }
184
185 #[allow(non_snake_case)]
186 pub extern "C" fn __quantum__rt__array_end_record_output() {
187 record_output_str("RESULT\tARRAY_END").expect("Failed to write array end output");
188 }
189
190 #[allow(non_snake_case)]
191 pub extern "C" fn __quantum__rt__tuple_start_record_output() {
192 record_output_str("RESULT\tTUPLE_START").expect("Failed to write tuple start output");
193 }
194
195 #[allow(non_snake_case)]
196 pub extern "C" fn __quantum__rt__tuple_end_record_output() {
197 record_output_str("RESULT\tTUPLE_END").expect("Failed to write tuple end output");
198 }
199
200 #[allow(non_snake_case)]
201 pub extern "C" fn __quantum__rt__int_record_output(val: i64) {
202 record_output_str(&format!("RESULT\t{val}")).expect("Failed to write int output");
203 }
204
205 #[allow(non_snake_case)]
206 pub extern "C" fn __quantum__rt__double_record_output(val: c_double) {
207 record_output_str(&format!("RESULT\t{}", double_to_string(val)))
208 .expect("Failed to write double output");
209 }
210
211 #[allow(non_snake_case)]
212 pub extern "C" fn __quantum__rt__bool_record_output(val: bool) {
213 record_output_str(&format!("RESULT\t{val}")).expect("Failed to write bool output");
214 }
215
216 #[allow(non_snake_case)]
217 pub unsafe extern "C" fn __quantum__rt__array_record_output(val: i64) {
218 unsafe {
219 super::__quantum__rt__array_record_output(val, null_mut());
220 }
221 }
222
223 #[allow(non_snake_case)]
224 pub unsafe extern "C" fn __quantum__rt__tuple_record_output(val: i64) {
225 unsafe {
226 super::__quantum__rt__tuple_record_output(val, null_mut());
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use std::ptr::null_mut;

    use super::*;

    #[test]
    fn test_output_int_untagged() {
        assert_untagged_output_match("INT", &42_i64, "OUTPUT\tINT\t42");
    }

    #[test]
    fn test_output_double_untagged() {
        let val: f64 = 42.4533;
        assert_untagged_output_match("DOUBLE", &double_to_string(val), "OUTPUT\tDOUBLE\t42.4533");
    }

    #[test]
    fn test_output_double_whole_untagged() {
        // The trailing digit is below f64 precision at this magnitude, so `val` is exactly 42.0.
        let val: c_double = 42.000_000_000_000_001;
        assert_untagged_output_match("DOUBLE", &double_to_string(val), "OUTPUT\tDOUBLE\t42.0");
    }

    #[test]
    fn test_output_bool_true_untagged() {
        assert_untagged_output_match("BOOL", &true, "OUTPUT\tBOOL\ttrue");
    }

    #[test]
    fn test_output_bool_false_untagged() {
        assert_untagged_output_match("BOOL", &false, "OUTPUT\tBOOL\tfalse");
    }

    #[test]
    fn test_output_tuple_untagged() {
        assert_untagged_output_match("TUPLE", &42_i64, "OUTPUT\tTUPLE\t42");
    }

    #[test]
    fn test_output_array_untagged() {
        assert_untagged_output_match("ARRAY", &42_i64, "OUTPUT\tARRAY\t42");
    }

    #[test]
    fn test_output_bool_true_tagged_from_cstring() {
        let tag = CString::new("YEEHAW").unwrap().into_raw();
        assert_output_match("BOOL", &true, tag, "OUTPUT\tBOOL\ttrue\tYEEHAW");
        // SAFETY: `tag` came from `CString::into_raw` above and is reclaimed exactly once.
        drop(unsafe { CString::from_raw(tag) });
    }

    #[test]
    fn test_output_bool_true_tagged_not_from_cstring() {
        // "hi" followed by a nul terminator, built by hand rather than via `CString`.
        let mut tag: [c_char; 3] = [0x68, 0x69, 0x00];
        assert_output_match("BOOL", &true, tag.as_mut_ptr(), "OUTPUT\tBOOL\ttrue\thi");
    }

    /// Asserts that recording `val` without a tag produces `expected_str` plus a line ending.
    fn assert_untagged_output_match(ty: &str, val: &dyn Display, expected_str: &str) {
        assert_output_match(ty, val, null_mut(), expected_str);
    }

    /// Captures one `record_output` call in the buffer and compares it with `expected_str`
    /// plus the platform line ending. Stdout forwarding is restored before the comparison.
    fn assert_output_match(ty: &str, val: &dyn Display, tag: *mut c_char, expected_str: &str) {
        OUTPUT.with(|output| output.borrow_mut().use_std_out(false));
        // SAFETY: `tag` is either null or a nul-terminated buffer owned by the calling test.
        unsafe { record_output(ty, val, tag) }.expect("Failed to write output");

        let actual = OUTPUT.with(|output| {
            let captured: Vec<u8> = output.borrow_mut().drain().collect();
            utf8_to_string(&captured)
        });

        OUTPUT.with(|output| output.borrow_mut().use_std_out(true));
        assert_eq!(actual, with_line_ending(expected_str));
    }

    /// Returns `value` with the platform line ending appended.
    fn with_line_ending(value: &str) -> String {
        let mut owned = value.to_owned();
        owned.push_str(&utf8_to_string(LINE_ENDING));
        owned
    }

    /// Decodes `bytes` as UTF-8, panicking on invalid input.
    fn utf8_to_string(bytes: &[u8]) -> String {
        std::str::from_utf8(bytes).unwrap().to_owned()
    }
}