runner/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![deny(clippy::all, clippy::pedantic)]
5#![allow(unused)]
6
7mod cli;
8pub use cli::main;
9
10pub use qir_backend::{
11    arrays::*, bigints::*, callables::*, exp::*, math::*, output_recording::*, range_support::*,
12    result_bool::*, strings::*, tuples::*, *,
13};
14
15use inkwell::{
16    OptimizationLevel,
17    attributes::AttributeLoc,
18    context::Context,
19    execution_engine::ExecutionEngine,
20    memory_buffer::MemoryBuffer,
21    module::Module,
22    passes::{PassBuilderOptions, PassManager},
23    targets::{CodeModel, InitializationConfig, RelocMode, Target, TargetMachine, TargetTriple},
24    values::FunctionValue,
25};
26use std::{
27    collections::HashMap,
28    ffi::OsStr,
29    io::{Read, Write},
30    path::Path,
31    ptr::null_mut,
32};
33
34/// # Errors
35///
36/// Will return `Err` if
37/// - `filename` does not exist or the user does not have permission to read it.
38/// - `filename` does not contain a valid bitcode module
39/// - `filename` does not have either a .ll or .bc as an extension
40/// - `entry_point` is not found in the QIR
41/// - Entry point has parameters or a non-void return type.
42pub fn run_file(
43    path: impl AsRef<Path>,
44    entry_point: Option<&str>,
45    shots: u32,
46    rng_seed: Option<u64>,
47    output_writer: &mut impl Write,
48) -> Result<(), String> {
49    if let Some(seed) = rng_seed {
50        qir_backend::set_rng_seed(seed);
51    }
52    let context = Context::create();
53    let module = load_file(path, &context)?;
54    run_module(&module, entry_point, shots, output_writer)
55}
56
57/// # Errors
58///
59/// Will return `Err` if
60/// - `bytes` does not contain a valid bitcode module
61/// - `entry_point` is not found in the QIR
62/// - Entry point has parameters or a non-void return type.
63pub fn run_bitcode(
64    bytes: &[u8],
65    entry_point: Option<&str>,
66    shots: u32,
67    output_writer: &mut impl Write,
68) -> Result<(), String> {
69    let context = Context::create();
70    let buffer = MemoryBuffer::create_from_memory_range(bytes, "");
71    let module = Module::parse_bitcode_from_buffer(&buffer, &context).map_err(|e| e.to_string())?;
72    run_module(&module, entry_point, shots, output_writer)
73}
74
75fn run_module(
76    module: &Module,
77    entry_point: Option<&str>,
78    shots: u32,
79    output_writer: &mut impl Write,
80) -> Result<(), String> {
81    module
82        .verify()
83        .map_err(|e| format!("Failed to verify module: {}", e.to_string()))?;
84
85    Target::initialize_native(&InitializationConfig::default())?;
86    let default_triple = TargetMachine::get_default_triple();
87    let target = Target::from_triple(&default_triple).map_err(|e| e.to_string())?;
88    if !target.has_asm_backend() {
89        return Err("Target doesn't have an ASM backend.".to_owned());
90    }
91    if !target.has_target_machine() {
92        return Err("Target doesn't have a target machine.".to_owned());
93    }
94
95    run_basic_passes_on(module, &default_triple, &target)?;
96
97    inkwell::support::load_library_permanently(Path::new(""));
98
99    let execution_engine = module
100        .create_jit_execution_engine(OptimizationLevel::None)
101        .map_err(|e| e.to_string())?;
102
103    bind_functions(module, &execution_engine)?;
104
105    let entry_point = choose_entry_point(module_functions(module), entry_point)?;
106    // TODO: need a cleaner way to get the attr strings for metadata
107    let attrs: Vec<(String, String)> = entry_point
108        .attributes(AttributeLoc::Function)
109        .iter()
110        .map(|attr| {
111            (
112                attr.get_string_kind_id()
113                    .to_str()
114                    .expect("Invalid UTF8 data")
115                    .to_string(),
116                attr.get_string_value()
117                    .to_str()
118                    .expect("Invalid UTF8 data")
119                    .to_string(),
120            )
121        })
122        .collect();
123
124    for _ in 1..=shots {
125        output_writer
126            .write_all("START\n".as_bytes())
127            .expect("Failed to write output");
128        for attr in &attrs {
129            output_writer
130                .write_all(format!("METADATA\t{}", attr.0).as_bytes())
131                .expect("Failed to write output");
132            if !attr.1.is_empty() {
133                output_writer
134                    .write_all(format!("\t{}", attr.1).as_bytes())
135                    .expect("Failed to write output");
136            }
137            output_writer
138                .write_all(qir_stdlib::output_recording::LINE_ENDING)
139                .expect("Failed to write output");
140        }
141
142        __quantum__rt__initialize(null_mut());
143        unsafe { run_entry_point(&execution_engine, entry_point)? }
144
145        // Write the saved output records to the output_writer
146        OUTPUT.with(|output| {
147            let mut output = output.borrow_mut();
148            output_writer
149                .write_all(output.drain().as_slice())
150                .expect("Failed to write output");
151        });
152
153        // Write the end of the shot
154        output_writer
155            .write_all("END\t0".as_bytes())
156            .expect("Failed to write output");
157        output_writer
158            .write_all(qir_stdlib::output_recording::LINE_ENDING)
159            .expect("Failed to write output");
160    }
161    Ok(())
162}
163
164fn load_file(path: impl AsRef<Path>, context: &Context) -> Result<Module, String> {
165    let path = path.as_ref();
166    let extension = path.extension().and_then(OsStr::to_str);
167
168    match extension {
169        Some("ll") => MemoryBuffer::create_from_file(path)
170            .and_then(|buffer| context.create_module_from_ir(buffer))
171            .map_err(|e| e.to_string()),
172        Some("bc") => Module::parse_bitcode_from_path(path, context).map_err(|e| e.to_string()),
173        _ => Err(format!("Unsupported file extension '{extension:?}'.")),
174    }
175}
176
177unsafe fn run_entry_point(
178    execution_engine: &ExecutionEngine,
179    entry_point: FunctionValue,
180) -> Result<(), String> {
181    unsafe {
182        if entry_point.count_params() == 0 {
183            execution_engine.run_function(entry_point, &[]);
184            Ok(())
185        } else {
186            Err("Entry point has parameters or a non-void return type.".to_owned())
187        }
188    }
189}
190
191fn choose_entry_point<'ctx>(
192    functions: impl Iterator<Item = FunctionValue<'ctx>>,
193    name: Option<&str>,
194) -> Result<FunctionValue<'ctx>, String> {
195    let mut entry_points = functions
196        .filter(|f| is_entry_point(*f) && name.iter().all(|n| f.get_name().to_str() == Ok(n)));
197
198    let entry_point = entry_points
199        .next()
200        .ok_or_else(|| "No matching entry point found.".to_owned())?;
201
202    if entry_points.next().is_some() {
203        Err("Multiple matching entry points found.".to_owned())
204    } else {
205        Ok(entry_point)
206    }
207}
208
209fn module_functions<'ctx>(module: &Module<'ctx>) -> impl Iterator<Item = FunctionValue<'ctx>> {
210    struct FunctionValueIter<'ctx>(Option<FunctionValue<'ctx>>);
211
212    impl<'ctx> Iterator for FunctionValueIter<'ctx> {
213        type Item = FunctionValue<'ctx>;
214
215        fn next(&mut self) -> Option<Self::Item> {
216            let function = self.0;
217            self.0 = function.and_then(inkwell::values::FunctionValue::get_next_function);
218            function
219        }
220    }
221
222    FunctionValueIter(module.get_first_function())
223}
224
225fn is_entry_point(function: FunctionValue) -> bool {
226    function
227        .get_string_attribute(AttributeLoc::Function, "entry_point")
228        .is_some()
229        || function
230            .get_string_attribute(AttributeLoc::Function, "EntryPoint")
231            .is_some()
232}
233
234fn run_basic_passes_on(
235    module: &Module,
236    target_triple: &TargetTriple,
237    target: &Target,
238) -> Result<(), String> {
239    // Description of this syntax:
240    // https://github.com/llvm/llvm-project/blob/2ba08386156ef25913b1bee170d8fe95aaceb234/llvm/include/llvm/Passes/PassBuilder.h#L308-L347
241    const BASIC_PASS_PIPELINE: &str = "globaldce,strip-dead-prototypes";
242
243    // Boilerplate taken from here:
244    // https://github.com/TheDan64/inkwell/blob/5c9f7fcbb0a667f7391b94beb65f1a670ad13221/examples/kaleidoscope/main.rs#L86-L95
245    let target_machine = target
246        .create_target_machine(
247            target_triple,
248            "generic",
249            "",
250            OptimizationLevel::None,
251            RelocMode::Default,
252            CodeModel::Default,
253        )
254        .ok_or("Unable to create TargetMachine from Target")?;
255    module
256        .run_passes(
257            BASIC_PASS_PIPELINE,
258            &target_machine,
259            PassBuilderOptions::create(),
260        )
261        .map_err(|e| e.to_string())
262}
263
264#[allow(clippy::too_many_lines)]
265fn bind_functions(module: &Module, execution_engine: &ExecutionEngine) -> Result<(), String> {
266    let mut uses_legacy = vec![];
267    let mut declarations: HashMap<String, FunctionValue> = HashMap::default();
268    for func in module_functions(module).filter(|f| {
269        f.count_basic_blocks() == 0
270            && !f
271                .get_name()
272                .to_str()
273                .expect("Unable to coerce function name into str.")
274                .starts_with("llvm.")
275    }) {
276        declarations.insert(
277            func.get_name()
278                .to_str()
279                .expect("Unable to coerce function name into str.")
280                .to_owned(),
281            func,
282        );
283    }
284
285    macro_rules! bind {
286        ($func:ident, $param_count:expr) => {
287            if let Some(func) = declarations.get(stringify!($func)) {
288                if func.get_params().len() != $param_count {
289                    return Err(format!(
290                        "Function '{}' has mismatched parameters: expected {}, found {}",
291                        stringify!($func),
292                        $param_count,
293                        func.get_params().len()
294                    ));
295                }
296                execution_engine.add_global_mapping(func, $func as usize);
297                declarations.remove(stringify!($func));
298            }
299        };
300    }
301
302    macro_rules! legacy_output {
303        ($func:ident) => {
304            if let Some(func) = declarations.get(stringify!($func)) {
305                execution_engine.add_global_mapping(
306                    func,
307                    qir_backend::output_recording::legacy::$func as usize,
308                );
309                declarations.remove(stringify!($func));
310                Some(true)
311            } else {
312                None
313            }
314        };
315    }
316
317    macro_rules! bind_output_record {
318        ($func:ident) => {
319            if let Some(func) = declarations.get(stringify!($func)) {
320                if func.get_params().len() == 1 {
321                    execution_engine.add_global_mapping(
322                        func,
323                        qir_backend::output_recording::legacy::$func as usize,
324                    );
325                    declarations.remove(stringify!($func));
326                    Some(true)
327                } else {
328                    execution_engine.add_global_mapping(func, $func as usize);
329                    declarations.remove(stringify!($func));
330                    Some(false)
331                }
332            } else {
333                None
334            }
335        };
336    }
337
338    // Legacy output methods
339    uses_legacy.push(legacy_output!(__quantum__rt__array_end_record_output));
340    uses_legacy.push(legacy_output!(__quantum__rt__array_start_record_output));
341    uses_legacy.push(legacy_output!(__quantum__rt__tuple_end_record_output));
342    uses_legacy.push(legacy_output!(__quantum__rt__tuple_start_record_output));
343
344    bind!(__quantum__rt__initialize, 1);
345    bind!(__quantum__qis__arccos__body, 1);
346    bind!(__quantum__qis__arcsin__body, 1);
347    bind!(__quantum__qis__arctan__body, 1);
348    bind!(__quantum__qis__arctan2__body, 2);
349    bind!(__quantum__qis__assertmeasurementprobability__body, 6);
350    bind!(__quantum__qis__assertmeasurementprobability__ctl, 6);
351    bind!(__quantum__qis__barrier__body, 0);
352    bind!(__quantum__qis__ccx__body, 3);
353    bind!(__quantum__qis__cnot__body, 2);
354    bind!(__quantum__qis__cos__body, 1);
355    bind!(__quantum__qis__cosh__body, 1);
356    bind!(__quantum__qis__cx__body, 2);
357    bind!(__quantum__qis__cz__body, 2);
358    bind!(__quantum__qis__drawrandomdouble__body, 2);
359    bind!(__quantum__qis__drawrandomint__body, 2);
360    bind!(__quantum__qis__dumpmachine__body, 1);
361    bind!(__quantum__qis__exp__body, 3);
362    bind!(__quantum__qis__exp__adj, 3);
363    bind!(__quantum__qis__exp__ctl, 2);
364    bind!(__quantum__qis__exp__ctladj, 2);
365    bind!(__quantum__qis__h__body, 1);
366    bind!(__quantum__qis__h__ctl, 2);
367    bind!(__quantum__qis__ieeeremainder__body, 2);
368    bind!(__quantum__qis__infinity__body, 0);
369    bind!(__quantum__qis__isinf__body, 1);
370    bind!(__quantum__qis__isnan__body, 1);
371    bind!(__quantum__qis__isnegativeinfinity__body, 1);
372    bind!(__quantum__qis__log__body, 1);
373    bind!(__quantum__qis__measure__body, 2);
374    bind!(__quantum__qis__mresetz__body, 2);
375    bind!(__quantum__qis__mz__body, 2);
376    bind!(__quantum__qis__nan__body, 0);
377    bind!(__quantum__qis__r__adj, 3);
378    bind!(__quantum__qis__r__body, 3);
379    bind!(__quantum__qis__r__ctl, 2);
380    bind!(__quantum__qis__r__ctladj, 2);
381    bind!(__quantum__qis__read_result__body, 1);
382    bind!(__quantum__qis__reset__body, 1);
383    bind!(__quantum__qis__rx__body, 2);
384    bind!(__quantum__qis__rx__ctl, 2);
385    bind!(__quantum__qis__rxx__body, 3);
386    bind!(__quantum__qis__ry__body, 2);
387    bind!(__quantum__qis__ry__ctl, 2);
388    bind!(__quantum__qis__ryy__body, 3);
389    bind!(__quantum__qis__rz__body, 2);
390    bind!(__quantum__qis__rz__ctl, 2);
391    bind!(__quantum__qis__rzz__body, 3);
392    bind!(__quantum__qis__s__adj, 1);
393    bind!(__quantum__qis__s__body, 1);
394    bind!(__quantum__qis__s__ctl, 2);
395    bind!(__quantum__qis__s__ctladj, 2);
396    bind!(__quantum__qis__sx__body, 1);
397    bind!(__quantum__qis__sin__body, 1);
398    bind!(__quantum__qis__sinh__body, 1);
399    bind!(__quantum__qis__sqrt__body, 1);
400    bind!(__quantum__qis__swap__body, 2);
401    bind!(__quantum__qis__t__adj, 1);
402    bind!(__quantum__qis__t__body, 1);
403    bind!(__quantum__qis__t__ctl, 2);
404    bind!(__quantum__qis__t__ctladj, 2);
405    bind!(__quantum__qis__tan__body, 1);
406    bind!(__quantum__qis__tanh__body, 1);
407    bind!(__quantum__qis__x__body, 1);
408    bind!(__quantum__qis__x__ctl, 2);
409    bind!(__quantum__qis__y__body, 1);
410    bind!(__quantum__qis__y__ctl, 2);
411    bind!(__quantum__qis__z__body, 1);
412    bind!(__quantum__qis__z__ctl, 2);
413    bind!(__quantum__rt__array_concatenate, 2);
414    bind!(__quantum__rt__array_copy, 2);
415    bind!(__quantum__rt__array_create_1d, 2);
416
417    // New calls
418    bind!(__quantum__rt__array_record_output, 2);
419    bind!(__quantum__rt__tuple_record_output, 2);
420
421    // calls with unlabeled signature variants
422    uses_legacy.push(bind_output_record!(__quantum__rt__bool_record_output));
423    uses_legacy.push(bind_output_record!(__quantum__rt__double_record_output));
424    uses_legacy.push(bind_output_record!(__quantum__rt__int_record_output));
425
426    // results need special handling as they aren't in the std lib
427    uses_legacy.push(
428        if let Some(func) = declarations.get("__quantum__rt__result_record_output") {
429            if func.get_params().len() == 1 {
430                execution_engine.add_global_mapping(
431                    func,
432                    qir_backend::legacy_output::__quantum__rt__result_record_output as usize,
433                );
434                declarations.remove("__quantum__rt__result_record_output");
435                Some(true)
436            } else {
437                execution_engine
438                    .add_global_mapping(func, __quantum__rt__result_record_output as usize);
439                declarations.remove("__quantum__rt__result_record_output");
440                Some(false)
441            }
442        } else {
443            None
444        },
445    );
446
447    // calls to __quantum__qis__m__body may use either dynamic or static results, so bind to the right
448    // implementation based on number of arguments.
449    if let Some(func) = declarations.get("__quantum__qis__m__body") {
450        if func.get_params().len() == 2 {
451            execution_engine
452                .add_global_mapping(func, qir_backend::__quantum__qis__mz__body as usize);
453        } else if func.get_params().len() == 1 {
454            execution_engine
455                .add_global_mapping(func, qir_backend::__quantum__qis__m__body as usize);
456        } else {
457            return Err(format!(
458                "Function '__quantum__qis__m__body' has mismatched parameters: expected 1 or 2, found {}",
459                func.get_params().len()
460            ));
461        }
462        declarations.remove("__quantum__qis__m__body");
463    }
464
465    bind!(__quantum__rt__array_get_element_ptr_1d, 2);
466    bind!(__quantum__rt__array_get_size_1d, 1);
467    bind!(quantum__rt__array_slice_1d, 3);
468    bind!(__quantum__rt__array_update_alias_count, 2);
469    bind!(__quantum__rt__array_update_reference_count, 2);
470    bind!(__quantum__rt__bigint_add, 2);
471    bind!(__quantum__rt__bigint_bitand, 2);
472    bind!(__quantum__rt__bigint_bitnot, 1);
473    bind!(__quantum__rt__bigint_bitor, 2);
474    bind!(__quantum__rt__bigint_bitxor, 2);
475    bind!(__quantum__rt__bigint_create_array, 2);
476    bind!(__quantum__rt__bigint_create_i64, 1);
477    bind!(__quantum__rt__bigint_divide, 2);
478    bind!(__quantum__rt__bigint_equal, 2);
479    bind!(__quantum__rt__bigint_get_data, 1);
480    bind!(__quantum__rt__bigint_get_length, 1);
481    bind!(__quantum__rt__bigint_greater, 2);
482    bind!(__quantum__rt__bigint_greater_eq, 2);
483    bind!(__quantum__rt__bigint_modulus, 2);
484    bind!(__quantum__rt__bigint_multiply, 2);
485    bind!(__quantum__rt__bigint_negate, 1);
486    bind!(__quantum__rt__bigint_power, 2);
487    bind!(__quantum__rt__bigint_shiftleft, 2);
488    bind!(__quantum__rt__bigint_shiftright, 2);
489    bind!(__quantum__rt__bigint_subtract, 2);
490    bind!(__quantum__rt__bigint_to_string, 1);
491    bind!(__quantum__rt__bigint_update_reference_count, 2);
492    bind!(__quantum__rt__bool_to_string, 1);
493    bind!(__quantum__rt__callable_copy, 2);
494    bind!(__quantum__rt__callable_create, 3);
495    bind!(__quantum__rt__callable_invoke, 3);
496    bind!(__quantum__rt__callable_make_adjoint, 1);
497    bind!(__quantum__rt__callable_make_controlled, 1);
498    bind!(__quantum__rt__callable_update_alias_count, 2);
499    bind!(__quantum__rt__callable_update_reference_count, 2);
500    bind!(__quantum__rt__capture_update_alias_count, 2);
501    bind!(__quantum__rt__capture_update_reference_count, 2);
502    bind!(__quantum__rt__double_to_string, 1);
503    bind!(__quantum__rt__fail, 1);
504    bind!(__quantum__rt__int_to_string, 1);
505    bind!(__quantum__rt__memory_allocate, 1);
506    bind!(__quantum__rt__message, 1);
507    bind!(__quantum__rt__pauli_to_string, 1);
508    bind!(__quantum__rt__qubit_allocate, 0);
509    bind!(__quantum__rt__qubit_allocate_array, 1);
510    bind!(__quantum__rt__qubit_release, 1);
511    bind!(__quantum__rt__qubit_release_array, 1);
512    bind!(__quantum__rt__qubit_to_string, 1);
513    bind!(__quantum__rt__read_result, 1);
514    bind!(__quantum__rt__result_equal, 2);
515    bind!(quantum__rt__range_to_string, 1);
516    bind!(__quantum__rt__result_get_one, 0);
517    bind!(__quantum__rt__result_get_zero, 0);
518    bind!(__quantum__rt__result_to_string, 1);
519    bind!(__quantum__rt__result_update_reference_count, 2);
520    bind!(__quantum__rt__string_concatenate, 2);
521    bind!(__quantum__rt__string_create, 1);
522    bind!(__quantum__rt__string_equal, 2);
523    bind!(__quantum__rt__string_get_data, 1);
524    bind!(__quantum__rt__string_get_length, 1);
525    bind!(__quantum__rt__string_update_reference_count, 2);
526    bind!(__quantum__rt__tuple_copy, 2);
527    bind!(__quantum__rt__tuple_create, 1);
528    bind!(__quantum__rt__tuple_update_alias_count, 2);
529    bind!(__quantum__rt__tuple_update_reference_count, 2);
530
531    if !(uses_legacy.iter().filter_map(|&b| b).all(|b| b)
532        || uses_legacy.iter().filter_map(|&b| b).all(|b| !b))
533    {
534        Err("Use of legacy and current output recording functions in the same program is not supported".to_string())
535    } else if declarations.is_empty() {
536        Ok(())
537    } else {
538        let keys = declarations.keys().collect::<Vec<_>>();
539        let (first, rest) = keys
540            .split_first()
541            .expect("Declarations list should be non-empty.");
542        Err(format!(
543            "Failed to link some declared functions: {}",
544            rest.iter().fold((*first).to_string(), |mut accum, f| {
545                accum.push_str(", ");
546                accum.push_str(f);
547                accum
548            })
549        ))
550    }
551}