runner/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![deny(clippy::all, clippy::pedantic)]
5#![allow(unused)]
6
7mod cli;
8pub use cli::main;
9
10pub use qir_backend::{
11    arrays::*, bigints::*, callables::*, exp::*, math::*, output_recording::*, range_support::*,
12    result_bool::*, strings::*, tuples::*, *,
13};
14
15use inkwell::{
16    OptimizationLevel,
17    attributes::AttributeLoc,
18    context::Context,
19    execution_engine::ExecutionEngine,
20    llvm_sys::{core::LLVMCreateMemoryBufferWithMemoryRange, ir_reader::LLVMParseIRInContext},
21    memory_buffer::MemoryBuffer,
22    module::Module,
23    passes::{PassBuilderOptions, PassManager},
24    targets::{CodeModel, InitializationConfig, RelocMode, Target, TargetMachine, TargetTriple},
25    values::FunctionValue,
26};
27use std::{
28    collections::HashMap,
29    ffi::{CStr, CString, OsStr, c_char},
30    io::{Read, Write},
31    iter::once,
32    path::Path,
33    ptr::{self, NonNull, null_mut},
34};
35
36/// # Errors
37///
38/// Will return `Err` if
39/// - `filename` does not exist or the user does not have permission to read it.
40/// - `filename` does not contain a valid bitcode module
41/// - `filename` does not have either a .ll or .bc as an extension
42/// - `entry_point` is not found in the QIR
43/// - Entry point has parameters or a non-void return type.
44pub fn run_file(
45    path: impl AsRef<Path>,
46    entry_point: Option<&str>,
47    shots: u32,
48    rng_seed: Option<u64>,
49    output_writer: &mut impl Write,
50) -> Result<(), String> {
51    if let Some(seed) = rng_seed {
52        qir_backend::set_rng_seed(seed);
53    }
54    let context = Context::create();
55    let module = load_file(path, &context)?;
56    run_module(&module, entry_point, shots, output_writer)
57}
58
59/// # Errors
60///
61/// Will return `Err` if
62/// - `bytes` does not contain a valid bitcode module
63/// - `entry_point` is not found in the QIR
64/// - Entry point has parameters or a non-void return type.
65pub fn run_bitcode(
66    bytes: &[u8],
67    entry_point: Option<&str>,
68    shots: u32,
69    output_writer: &mut impl Write,
70) -> Result<(), String> {
71    run_bytes(bytes, entry_point, shots, None, output_writer)
72}
73
74/// # Errors
75///
76/// Will return `Err` if
77/// - `bytes` does not contain a valid bitcode module or LLVM IR string
78/// - `entry_point` is not found in the QIR
79/// - Entry point has parameters or a non-void return type.
80pub fn run_bytes(
81    bytes: &[u8],
82    entry_point: Option<&str>,
83    shots: u32,
84    rng_seed: Option<u64>,
85    output_writer: &mut impl Write,
86) -> Result<(), String> {
87    if let Some(seed) = rng_seed {
88        qir_backend::set_rng_seed(seed);
89    }
90
91    let context = Context::create();
92
93    // To know if the bytes are bitcode, check for both the wrapped and non-wrapped magic bytes.
94    // See the definition for llvm::isBitCode at https://llvm.org/doxygen/namespacellvm.html#ae0ccf1c0633b02c90c21118d0c1c7ec4
95    // for reference.
96    let bytes_len = bytes.len();
97    if bytes_len < 4 {
98        return Err("byte array is too short".to_string());
99    }
100    let is_bitcode =
101        bytes[0..4] == [0xDE, 0xC0, 0x17, 0x0B] || bytes[0..4] == [0x42, 0x43, 0xC0, 0xDE];
102    let bytes = if is_bitcode {
103        bytes.to_vec()
104    } else {
105        // The bytes represent LLVM IR string, so we must ensure it is null-terminated.
106        // Note that we use the original bytes length to avoid including the null terminator in the IR parsing, which would cause it to fail.
107        bytes.iter().copied().chain(once(0_u8)).collect()
108    };
109
110    let buffer = MemoryBuffer::create_from_memory_range(&bytes[0..bytes_len], Default::default());
111    context
112        .create_module_from_ir(buffer)
113        .map_err(|e| format!("Failed to parse module from IR: {}", e.to_string()))
114        .and_then(|module| run_module(&module, entry_point, shots, output_writer))
115}
116
117fn run_module(
118    module: &Module,
119    entry_point: Option<&str>,
120    shots: u32,
121    output_writer: &mut impl Write,
122) -> Result<(), String> {
123    module
124        .verify()
125        .map_err(|e| format!("Failed to verify module: {}", e.to_string()))?;
126
127    Target::initialize_native(&InitializationConfig::default())?;
128    let default_triple = TargetMachine::get_default_triple();
129    let target = Target::from_triple(&default_triple).map_err(|e| e.to_string())?;
130    if !target.has_asm_backend() {
131        return Err("Target doesn't have an ASM backend.".to_owned());
132    }
133    if !target.has_target_machine() {
134        return Err("Target doesn't have a target machine.".to_owned());
135    }
136
137    run_basic_passes_on(module, &default_triple, &target)?;
138
139    inkwell::support::load_library_permanently(Path::new(""));
140
141    let execution_engine = module
142        .create_jit_execution_engine(OptimizationLevel::None)
143        .map_err(|e| e.to_string())?;
144
145    bind_functions(module, &execution_engine)?;
146
147    let entry_point = choose_entry_point(module_functions(module), entry_point)?;
148    // TODO: need a cleaner way to get the attr strings for metadata
149    let attrs: Vec<(String, String)> = entry_point
150        .attributes(AttributeLoc::Function)
151        .iter()
152        .map(|attr| {
153            (
154                attr.get_string_kind_id()
155                    .to_str()
156                    .expect("Invalid UTF8 data")
157                    .to_string(),
158                attr.get_string_value()
159                    .to_str()
160                    .expect("Invalid UTF8 data")
161                    .to_string(),
162            )
163        })
164        .collect();
165
166    for _ in 1..=shots {
167        output_writer
168            .write_all("START\n".as_bytes())
169            .expect("Failed to write output");
170        for attr in &attrs {
171            output_writer
172                .write_all(format!("METADATA\t{}", attr.0).as_bytes())
173                .expect("Failed to write output");
174            if !attr.1.is_empty() {
175                output_writer
176                    .write_all(format!("\t{}", attr.1).as_bytes())
177                    .expect("Failed to write output");
178            }
179            output_writer
180                .write_all(qir_stdlib::output_recording::LINE_ENDING)
181                .expect("Failed to write output");
182        }
183
184        __quantum__rt__initialize(null_mut());
185        unsafe { run_entry_point(&execution_engine, entry_point)? }
186
187        // Write the saved output records to the output_writer
188        OUTPUT.with(|output| {
189            let mut output = output.borrow_mut();
190            output_writer
191                .write_all(output.drain().as_slice())
192                .expect("Failed to write output");
193        });
194
195        // Write the end of the shot
196        output_writer
197            .write_all("END\t0".as_bytes())
198            .expect("Failed to write output");
199        output_writer
200            .write_all(qir_stdlib::output_recording::LINE_ENDING)
201            .expect("Failed to write output");
202    }
203    Ok(())
204}
205
206fn load_file(path: impl AsRef<Path>, context: &Context) -> Result<Module<'_>, String> {
207    let path = path.as_ref();
208    let extension = path.extension().and_then(OsStr::to_str);
209
210    match extension {
211        Some("ll") => MemoryBuffer::create_from_file(path)
212            .and_then(|buffer| context.create_module_from_ir(buffer))
213            .map_err(|e| e.to_string()),
214        Some("bc") => Module::parse_bitcode_from_path(path, context).map_err(|e| e.to_string()),
215        _ => Err(format!("Unsupported file extension '{extension:?}'.")),
216    }
217}
218
219unsafe fn run_entry_point(
220    execution_engine: &ExecutionEngine,
221    entry_point: FunctionValue,
222) -> Result<(), String> {
223    unsafe {
224        if entry_point.count_params() == 0 {
225            execution_engine.run_function(entry_point, &[]);
226            Ok(())
227        } else {
228            Err("Entry point has parameters or a non-void return type.".to_owned())
229        }
230    }
231}
232
233fn choose_entry_point<'ctx>(
234    functions: impl Iterator<Item = FunctionValue<'ctx>>,
235    name: Option<&str>,
236) -> Result<FunctionValue<'ctx>, String> {
237    let mut entry_points = functions
238        .filter(|f| is_entry_point(*f) && name.iter().all(|n| f.get_name().to_str() == Ok(n)));
239
240    let entry_point = entry_points
241        .next()
242        .ok_or_else(|| "No matching entry point found.".to_owned())?;
243
244    if entry_points.next().is_some() {
245        Err("Multiple matching entry points found.".to_owned())
246    } else {
247        Ok(entry_point)
248    }
249}
250
251fn module_functions<'ctx>(module: &Module<'ctx>) -> impl Iterator<Item = FunctionValue<'ctx>> {
252    struct FunctionValueIter<'ctx>(Option<FunctionValue<'ctx>>);
253
254    impl<'ctx> Iterator for FunctionValueIter<'ctx> {
255        type Item = FunctionValue<'ctx>;
256
257        fn next(&mut self) -> Option<Self::Item> {
258            let function = self.0;
259            self.0 = function.and_then(inkwell::values::FunctionValue::get_next_function);
260            function
261        }
262    }
263
264    FunctionValueIter(module.get_first_function())
265}
266
267fn is_entry_point(function: FunctionValue) -> bool {
268    function
269        .get_string_attribute(AttributeLoc::Function, "entry_point")
270        .is_some()
271        || function
272            .get_string_attribute(AttributeLoc::Function, "EntryPoint")
273            .is_some()
274}
275
276fn run_basic_passes_on(
277    module: &Module,
278    target_triple: &TargetTriple,
279    target: &Target,
280) -> Result<(), String> {
281    // Description of this syntax:
282    // https://github.com/llvm/llvm-project/blob/2ba08386156ef25913b1bee170d8fe95aaceb234/llvm/include/llvm/Passes/PassBuilder.h#L308-L347
283    const BASIC_PASS_PIPELINE: &str = "globaldce,strip-dead-prototypes";
284
285    // Boilerplate taken from here:
286    // https://github.com/TheDan64/inkwell/blob/5c9f7fcbb0a667f7391b94beb65f1a670ad13221/examples/kaleidoscope/main.rs#L86-L95
287    let target_machine = target
288        .create_target_machine(
289            target_triple,
290            "generic",
291            "",
292            OptimizationLevel::None,
293            RelocMode::Default,
294            CodeModel::Default,
295        )
296        .ok_or("Unable to create TargetMachine from Target")?;
297    module
298        .run_passes(
299            BASIC_PASS_PIPELINE,
300            &target_machine,
301            PassBuilderOptions::create(),
302        )
303        .map_err(|e| e.to_string())
304}
305
306#[allow(clippy::too_many_lines)]
307fn bind_functions(module: &Module, execution_engine: &ExecutionEngine) -> Result<(), String> {
308    let mut uses_legacy = vec![];
309    let mut declarations: HashMap<String, FunctionValue> = HashMap::default();
310    for func in module_functions(module).filter(|f| {
311        f.count_basic_blocks() == 0
312            && !f
313                .get_name()
314                .to_str()
315                .expect("Unable to coerce function name into str.")
316                .starts_with("llvm.")
317    }) {
318        declarations.insert(
319            func.get_name()
320                .to_str()
321                .expect("Unable to coerce function name into str.")
322                .to_owned(),
323            func,
324        );
325    }
326
327    macro_rules! bind {
328        ($func:ident, $param_count:expr) => {
329            if let Some(func) = declarations.get(stringify!($func)) {
330                if func.get_params().len() != $param_count {
331                    return Err(format!(
332                        "Function '{}' has mismatched parameters: expected {}, found {}",
333                        stringify!($func),
334                        $param_count,
335                        func.get_params().len()
336                    ));
337                }
338                execution_engine.add_global_mapping(func, $func as *const () as usize);
339                declarations.remove(stringify!($func));
340            }
341        };
342    }
343
344    macro_rules! legacy_output {
345        ($func:ident) => {
346            if let Some(func) = declarations.get(stringify!($func)) {
347                execution_engine.add_global_mapping(
348                    func,
349                    qir_backend::output_recording::legacy::$func as *const () as usize,
350                );
351                declarations.remove(stringify!($func));
352                Some(true)
353            } else {
354                None
355            }
356        };
357    }
358
359    macro_rules! bind_output_record {
360        ($func:ident) => {
361            if let Some(func) = declarations.get(stringify!($func)) {
362                if func.get_params().len() == 1 {
363                    execution_engine.add_global_mapping(
364                        func,
365                        qir_backend::output_recording::legacy::$func as *const () as usize,
366                    );
367                    declarations.remove(stringify!($func));
368                    Some(true)
369                } else {
370                    execution_engine.add_global_mapping(func, $func as *const () as usize);
371                    declarations.remove(stringify!($func));
372                    Some(false)
373                }
374            } else {
375                None
376            }
377        };
378    }
379
380    // Legacy output methods
381    uses_legacy.push(legacy_output!(__quantum__rt__array_end_record_output));
382    uses_legacy.push(legacy_output!(__quantum__rt__array_start_record_output));
383    uses_legacy.push(legacy_output!(__quantum__rt__tuple_end_record_output));
384    uses_legacy.push(legacy_output!(__quantum__rt__tuple_start_record_output));
385
386    bind!(__quantum__rt__initialize, 1);
387    bind!(__quantum__qis__arccos__body, 1);
388    bind!(__quantum__qis__arcsin__body, 1);
389    bind!(__quantum__qis__arctan__body, 1);
390    bind!(__quantum__qis__arctan2__body, 2);
391    bind!(__quantum__qis__assertmeasurementprobability__body, 6);
392    bind!(__quantum__qis__assertmeasurementprobability__ctl, 6);
393    bind!(__quantum__qis__barrier__body, 0);
394    bind!(__quantum__qis__ccx__body, 3);
395    bind!(__quantum__qis__cnot__body, 2);
396    bind!(__quantum__qis__cos__body, 1);
397    bind!(__quantum__qis__cosh__body, 1);
398    bind!(__quantum__qis__cx__body, 2);
399    bind!(__quantum__qis__cz__body, 2);
400    bind!(__quantum__qis__drawrandomdouble__body, 2);
401    bind!(__quantum__qis__drawrandomint__body, 2);
402    bind!(__quantum__qis__dumpmachine__body, 1);
403    bind!(__quantum__qis__exp__body, 3);
404    bind!(__quantum__qis__exp__adj, 3);
405    bind!(__quantum__qis__exp__ctl, 2);
406    bind!(__quantum__qis__exp__ctladj, 2);
407    bind!(__quantum__qis__h__body, 1);
408    bind!(__quantum__qis__h__ctl, 2);
409    bind!(__quantum__qis__ieeeremainder__body, 2);
410    bind!(__quantum__qis__infinity__body, 0);
411    bind!(__quantum__qis__isinf__body, 1);
412    bind!(__quantum__qis__isnan__body, 1);
413    bind!(__quantum__qis__isnegativeinfinity__body, 1);
414    bind!(__quantum__qis__log__body, 1);
415    bind!(__quantum__qis__measure__body, 2);
416    bind!(__quantum__qis__mresetz__body, 2);
417    bind!(__quantum__qis__mz__body, 2);
418    bind!(__quantum__qis__nan__body, 0);
419    bind!(__quantum__qis__r__adj, 3);
420    bind!(__quantum__qis__r__body, 3);
421    bind!(__quantum__qis__r__ctl, 2);
422    bind!(__quantum__qis__r__ctladj, 2);
423    bind!(__quantum__qis__read_result__body, 1);
424    bind!(__quantum__qis__reset__body, 1);
425    bind!(__quantum__qis__rx__body, 2);
426    bind!(__quantum__qis__rx__ctl, 2);
427    bind!(__quantum__qis__rxx__body, 3);
428    bind!(__quantum__qis__ry__body, 2);
429    bind!(__quantum__qis__ry__ctl, 2);
430    bind!(__quantum__qis__ryy__body, 3);
431    bind!(__quantum__qis__rz__body, 2);
432    bind!(__quantum__qis__rz__ctl, 2);
433    bind!(__quantum__qis__rzz__body, 3);
434    bind!(__quantum__qis__s__adj, 1);
435    bind!(__quantum__qis__s__body, 1);
436    bind!(__quantum__qis__s__ctl, 2);
437    bind!(__quantum__qis__s__ctladj, 2);
438    bind!(__quantum__qis__sx__body, 1);
439    bind!(__quantum__qis__sin__body, 1);
440    bind!(__quantum__qis__sinh__body, 1);
441    bind!(__quantum__qis__sqrt__body, 1);
442    bind!(__quantum__qis__swap__body, 2);
443    bind!(__quantum__qis__t__adj, 1);
444    bind!(__quantum__qis__t__body, 1);
445    bind!(__quantum__qis__t__ctl, 2);
446    bind!(__quantum__qis__t__ctladj, 2);
447    bind!(__quantum__qis__tan__body, 1);
448    bind!(__quantum__qis__tanh__body, 1);
449    bind!(__quantum__qis__x__body, 1);
450    bind!(__quantum__qis__x__ctl, 2);
451    bind!(__quantum__qis__y__body, 1);
452    bind!(__quantum__qis__y__ctl, 2);
453    bind!(__quantum__qis__z__body, 1);
454    bind!(__quantum__qis__z__ctl, 2);
455    bind!(__quantum__rt__array_concatenate, 2);
456    bind!(__quantum__rt__array_copy, 2);
457    bind!(__quantum__rt__array_create_1d, 2);
458
459    // New calls
460    bind!(__quantum__rt__array_record_output, 2);
461    bind!(__quantum__rt__tuple_record_output, 2);
462
463    // calls with unlabeled signature variants
464    uses_legacy.push(bind_output_record!(__quantum__rt__bool_record_output));
465    uses_legacy.push(bind_output_record!(__quantum__rt__double_record_output));
466    uses_legacy.push(bind_output_record!(__quantum__rt__int_record_output));
467
468    // results need special handling as they aren't in the std lib
469    uses_legacy.push(
470        if let Some(func) = declarations.get("__quantum__rt__result_record_output") {
471            if func.get_params().len() == 1 {
472                execution_engine.add_global_mapping(
473                    func,
474                    qir_backend::legacy_output::__quantum__rt__result_record_output as *const ()
475                        as usize,
476                );
477                declarations.remove("__quantum__rt__result_record_output");
478                Some(true)
479            } else {
480                execution_engine.add_global_mapping(
481                    func,
482                    __quantum__rt__result_record_output as *const () as usize,
483                );
484                declarations.remove("__quantum__rt__result_record_output");
485                Some(false)
486            }
487        } else {
488            None
489        },
490    );
491
492    // calls to __quantum__qis__m__body may use either dynamic or static results, so bind to the right
493    // implementation based on number of arguments.
494    if let Some(func) = declarations.get("__quantum__qis__m__body") {
495        if func.get_params().len() == 2 {
496            execution_engine.add_global_mapping(
497                func,
498                qir_backend::__quantum__qis__mz__body as *const () as usize,
499            );
500        } else if func.get_params().len() == 1 {
501            execution_engine.add_global_mapping(
502                func,
503                qir_backend::__quantum__qis__m__body as *const () as usize,
504            );
505        } else {
506            return Err(format!(
507                "Function '__quantum__qis__m__body' has mismatched parameters: expected 1 or 2, found {}",
508                func.get_params().len()
509            ));
510        }
511        declarations.remove("__quantum__qis__m__body");
512    }
513
514    bind!(__quantum__rt__array_get_element_ptr_1d, 2);
515    bind!(__quantum__rt__array_get_size_1d, 1);
516    bind!(quantum__rt__array_slice_1d, 3);
517    bind!(__quantum__rt__array_update_alias_count, 2);
518    bind!(__quantum__rt__array_update_reference_count, 2);
519    bind!(__quantum__rt__bigint_add, 2);
520    bind!(__quantum__rt__bigint_bitand, 2);
521    bind!(__quantum__rt__bigint_bitnot, 1);
522    bind!(__quantum__rt__bigint_bitor, 2);
523    bind!(__quantum__rt__bigint_bitxor, 2);
524    bind!(__quantum__rt__bigint_create_array, 2);
525    bind!(__quantum__rt__bigint_create_i64, 1);
526    bind!(__quantum__rt__bigint_divide, 2);
527    bind!(__quantum__rt__bigint_equal, 2);
528    bind!(__quantum__rt__bigint_get_data, 1);
529    bind!(__quantum__rt__bigint_get_length, 1);
530    bind!(__quantum__rt__bigint_greater, 2);
531    bind!(__quantum__rt__bigint_greater_eq, 2);
532    bind!(__quantum__rt__bigint_modulus, 2);
533    bind!(__quantum__rt__bigint_multiply, 2);
534    bind!(__quantum__rt__bigint_negate, 1);
535    bind!(__quantum__rt__bigint_power, 2);
536    bind!(__quantum__rt__bigint_shiftleft, 2);
537    bind!(__quantum__rt__bigint_shiftright, 2);
538    bind!(__quantum__rt__bigint_subtract, 2);
539    bind!(__quantum__rt__bigint_to_string, 1);
540    bind!(__quantum__rt__bigint_update_reference_count, 2);
541    bind!(__quantum__rt__bool_to_string, 1);
542    bind!(__quantum__rt__callable_copy, 2);
543    bind!(__quantum__rt__callable_create, 3);
544    bind!(__quantum__rt__callable_invoke, 3);
545    bind!(__quantum__rt__callable_make_adjoint, 1);
546    bind!(__quantum__rt__callable_make_controlled, 1);
547    bind!(__quantum__rt__callable_update_alias_count, 2);
548    bind!(__quantum__rt__callable_update_reference_count, 2);
549    bind!(__quantum__rt__capture_update_alias_count, 2);
550    bind!(__quantum__rt__capture_update_reference_count, 2);
551    bind!(__quantum__rt__double_to_string, 1);
552    bind!(__quantum__rt__fail, 1);
553    bind!(__quantum__rt__int_to_string, 1);
554    bind!(__quantum__rt__memory_allocate, 1);
555    bind!(__quantum__rt__message, 1);
556    bind!(__quantum__rt__pauli_to_string, 1);
557    bind!(__quantum__rt__qubit_allocate, 0);
558    bind!(__quantum__rt__qubit_allocate_array, 1);
559    bind!(__quantum__rt__qubit_release, 1);
560    bind!(__quantum__rt__qubit_release_array, 1);
561    bind!(__quantum__rt__qubit_to_string, 1);
562    bind!(__quantum__rt__read_result, 1);
563    bind!(__quantum__rt__result_equal, 2);
564    bind!(quantum__rt__range_to_string, 1);
565    bind!(__quantum__rt__result_get_one, 0);
566    bind!(__quantum__rt__result_get_zero, 0);
567    bind!(__quantum__rt__result_to_string, 1);
568    bind!(__quantum__rt__result_update_reference_count, 2);
569    bind!(__quantum__rt__string_concatenate, 2);
570    bind!(__quantum__rt__string_create, 1);
571    bind!(__quantum__rt__string_equal, 2);
572    bind!(__quantum__rt__string_get_data, 1);
573    bind!(__quantum__rt__string_get_length, 1);
574    bind!(__quantum__rt__string_update_reference_count, 2);
575    bind!(__quantum__rt__tuple_copy, 2);
576    bind!(__quantum__rt__tuple_create, 1);
577    bind!(__quantum__rt__tuple_update_alias_count, 2);
578    bind!(__quantum__rt__tuple_update_reference_count, 2);
579
580    if !(uses_legacy.iter().filter_map(|&b| b).all(|b| b)
581        || uses_legacy.iter().filter_map(|&b| b).all(|b| !b))
582    {
583        Err("Use of legacy and current output recording functions in the same program is not supported".to_string())
584    } else if declarations.is_empty() {
585        Ok(())
586    } else {
587        let keys = declarations.keys().collect::<Vec<_>>();
588        let (first, rest) = keys
589            .split_first()
590            .expect("Declarations list should be non-empty.");
591        Err(format!(
592            "Failed to link some declared functions: {}",
593            rest.iter().fold((*first).clone(), |mut accum, f| {
594                accum.push_str(", ");
595                accum.push_str(f);
596                accum
597            })
598        ))
599    }
600}