1#![deny(clippy::all, clippy::pedantic)]
5#![allow(unused)]
6
7mod cli;
8pub use cli::main;
9
10pub use qir_backend::{
11 arrays::*, bigints::*, callables::*, exp::*, math::*, output_recording::*, range_support::*,
12 result_bool::*, strings::*, tuples::*, *,
13};
14
15use inkwell::{
16 OptimizationLevel,
17 attributes::AttributeLoc,
18 context::Context,
19 execution_engine::ExecutionEngine,
20 memory_buffer::MemoryBuffer,
21 module::Module,
22 passes::{PassBuilderOptions, PassManager},
23 targets::{CodeModel, InitializationConfig, RelocMode, Target, TargetMachine, TargetTriple},
24 values::FunctionValue,
25};
26use std::{
27 collections::HashMap,
28 ffi::OsStr,
29 io::{Read, Write},
30 path::Path,
31 ptr::null_mut,
32};
33
34pub fn run_file(
43 path: impl AsRef<Path>,
44 entry_point: Option<&str>,
45 shots: u32,
46 rng_seed: Option<u64>,
47 output_writer: &mut impl Write,
48) -> Result<(), String> {
49 if let Some(seed) = rng_seed {
50 qir_backend::set_rng_seed(seed);
51 }
52 let context = Context::create();
53 let module = load_file(path, &context)?;
54 run_module(&module, entry_point, shots, output_writer)
55}
56
57pub fn run_bitcode(
64 bytes: &[u8],
65 entry_point: Option<&str>,
66 shots: u32,
67 output_writer: &mut impl Write,
68) -> Result<(), String> {
69 let context = Context::create();
70 let buffer = MemoryBuffer::create_from_memory_range(bytes, "");
71 let module = Module::parse_bitcode_from_buffer(&buffer, &context).map_err(|e| e.to_string())?;
72 run_module(&module, entry_point, shots, output_writer)
73}
74
75fn run_module(
76 module: &Module,
77 entry_point: Option<&str>,
78 shots: u32,
79 output_writer: &mut impl Write,
80) -> Result<(), String> {
81 module
82 .verify()
83 .map_err(|e| format!("Failed to verify module: {}", e.to_string()))?;
84
85 Target::initialize_native(&InitializationConfig::default())?;
86 let default_triple = TargetMachine::get_default_triple();
87 let target = Target::from_triple(&default_triple).map_err(|e| e.to_string())?;
88 if !target.has_asm_backend() {
89 return Err("Target doesn't have an ASM backend.".to_owned());
90 }
91 if !target.has_target_machine() {
92 return Err("Target doesn't have a target machine.".to_owned());
93 }
94
95 run_basic_passes_on(module, &default_triple, &target)?;
96
97 inkwell::support::load_library_permanently(Path::new(""));
98
99 let execution_engine = module
100 .create_jit_execution_engine(OptimizationLevel::None)
101 .map_err(|e| e.to_string())?;
102
103 bind_functions(module, &execution_engine)?;
104
105 let entry_point = choose_entry_point(module_functions(module), entry_point)?;
106 let attrs: Vec<(String, String)> = entry_point
108 .attributes(AttributeLoc::Function)
109 .iter()
110 .map(|attr| {
111 (
112 attr.get_string_kind_id()
113 .to_str()
114 .expect("Invalid UTF8 data")
115 .to_string(),
116 attr.get_string_value()
117 .to_str()
118 .expect("Invalid UTF8 data")
119 .to_string(),
120 )
121 })
122 .collect();
123
124 for _ in 1..=shots {
125 output_writer
126 .write_all("START\n".as_bytes())
127 .expect("Failed to write output");
128 for attr in &attrs {
129 output_writer
130 .write_all(format!("METADATA\t{}", attr.0).as_bytes())
131 .expect("Failed to write output");
132 if !attr.1.is_empty() {
133 output_writer
134 .write_all(format!("\t{}", attr.1).as_bytes())
135 .expect("Failed to write output");
136 }
137 output_writer
138 .write_all(qir_stdlib::output_recording::LINE_ENDING)
139 .expect("Failed to write output");
140 }
141
142 __quantum__rt__initialize(null_mut());
143 unsafe { run_entry_point(&execution_engine, entry_point)? }
144
145 OUTPUT.with(|output| {
147 let mut output = output.borrow_mut();
148 output_writer
149 .write_all(output.drain().as_slice())
150 .expect("Failed to write output");
151 });
152
153 output_writer
155 .write_all("END\t0".as_bytes())
156 .expect("Failed to write output");
157 output_writer
158 .write_all(qir_stdlib::output_recording::LINE_ENDING)
159 .expect("Failed to write output");
160 }
161 Ok(())
162}
163
164fn load_file(path: impl AsRef<Path>, context: &Context) -> Result<Module, String> {
165 let path = path.as_ref();
166 let extension = path.extension().and_then(OsStr::to_str);
167
168 match extension {
169 Some("ll") => MemoryBuffer::create_from_file(path)
170 .and_then(|buffer| context.create_module_from_ir(buffer))
171 .map_err(|e| e.to_string()),
172 Some("bc") => Module::parse_bitcode_from_path(path, context).map_err(|e| e.to_string()),
173 _ => Err(format!("Unsupported file extension '{extension:?}'.")),
174 }
175}
176
177unsafe fn run_entry_point(
178 execution_engine: &ExecutionEngine,
179 entry_point: FunctionValue,
180) -> Result<(), String> {
181 unsafe {
182 if entry_point.count_params() == 0 {
183 execution_engine.run_function(entry_point, &[]);
184 Ok(())
185 } else {
186 Err("Entry point has parameters or a non-void return type.".to_owned())
187 }
188 }
189}
190
191fn choose_entry_point<'ctx>(
192 functions: impl Iterator<Item = FunctionValue<'ctx>>,
193 name: Option<&str>,
194) -> Result<FunctionValue<'ctx>, String> {
195 let mut entry_points = functions
196 .filter(|f| is_entry_point(*f) && name.iter().all(|n| f.get_name().to_str() == Ok(n)));
197
198 let entry_point = entry_points
199 .next()
200 .ok_or_else(|| "No matching entry point found.".to_owned())?;
201
202 if entry_points.next().is_some() {
203 Err("Multiple matching entry points found.".to_owned())
204 } else {
205 Ok(entry_point)
206 }
207}
208
209fn module_functions<'ctx>(module: &Module<'ctx>) -> impl Iterator<Item = FunctionValue<'ctx>> {
210 struct FunctionValueIter<'ctx>(Option<FunctionValue<'ctx>>);
211
212 impl<'ctx> Iterator for FunctionValueIter<'ctx> {
213 type Item = FunctionValue<'ctx>;
214
215 fn next(&mut self) -> Option<Self::Item> {
216 let function = self.0;
217 self.0 = function.and_then(inkwell::values::FunctionValue::get_next_function);
218 function
219 }
220 }
221
222 FunctionValueIter(module.get_first_function())
223}
224
225fn is_entry_point(function: FunctionValue) -> bool {
226 function
227 .get_string_attribute(AttributeLoc::Function, "entry_point")
228 .is_some()
229 || function
230 .get_string_attribute(AttributeLoc::Function, "EntryPoint")
231 .is_some()
232}
233
234fn run_basic_passes_on(
235 module: &Module,
236 target_triple: &TargetTriple,
237 target: &Target,
238) -> Result<(), String> {
239 const BASIC_PASS_PIPELINE: &str = "globaldce,strip-dead-prototypes";
242
243 let target_machine = target
246 .create_target_machine(
247 target_triple,
248 "generic",
249 "",
250 OptimizationLevel::None,
251 RelocMode::Default,
252 CodeModel::Default,
253 )
254 .ok_or("Unable to create TargetMachine from Target")?;
255 module
256 .run_passes(
257 BASIC_PASS_PIPELINE,
258 &target_machine,
259 PassBuilderOptions::create(),
260 )
261 .map_err(|e| e.to_string())
262}
263
264#[allow(clippy::too_many_lines)]
265fn bind_functions(module: &Module, execution_engine: &ExecutionEngine) -> Result<(), String> {
266 let mut uses_legacy = vec![];
267 let mut declarations: HashMap<String, FunctionValue> = HashMap::default();
268 for func in module_functions(module).filter(|f| {
269 f.count_basic_blocks() == 0
270 && !f
271 .get_name()
272 .to_str()
273 .expect("Unable to coerce function name into str.")
274 .starts_with("llvm.")
275 }) {
276 declarations.insert(
277 func.get_name()
278 .to_str()
279 .expect("Unable to coerce function name into str.")
280 .to_owned(),
281 func,
282 );
283 }
284
285 macro_rules! bind {
286 ($func:ident, $param_count:expr) => {
287 if let Some(func) = declarations.get(stringify!($func)) {
288 if func.get_params().len() != $param_count {
289 return Err(format!(
290 "Function '{}' has mismatched parameters: expected {}, found {}",
291 stringify!($func),
292 $param_count,
293 func.get_params().len()
294 ));
295 }
296 execution_engine.add_global_mapping(func, $func as usize);
297 declarations.remove(stringify!($func));
298 }
299 };
300 }
301
302 macro_rules! legacy_output {
303 ($func:ident) => {
304 if let Some(func) = declarations.get(stringify!($func)) {
305 execution_engine.add_global_mapping(
306 func,
307 qir_backend::output_recording::legacy::$func as usize,
308 );
309 declarations.remove(stringify!($func));
310 Some(true)
311 } else {
312 None
313 }
314 };
315 }
316
317 macro_rules! bind_output_record {
318 ($func:ident) => {
319 if let Some(func) = declarations.get(stringify!($func)) {
320 if func.get_params().len() == 1 {
321 execution_engine.add_global_mapping(
322 func,
323 qir_backend::output_recording::legacy::$func as usize,
324 );
325 declarations.remove(stringify!($func));
326 Some(true)
327 } else {
328 execution_engine.add_global_mapping(func, $func as usize);
329 declarations.remove(stringify!($func));
330 Some(false)
331 }
332 } else {
333 None
334 }
335 };
336 }
337
338 uses_legacy.push(legacy_output!(__quantum__rt__array_end_record_output));
340 uses_legacy.push(legacy_output!(__quantum__rt__array_start_record_output));
341 uses_legacy.push(legacy_output!(__quantum__rt__tuple_end_record_output));
342 uses_legacy.push(legacy_output!(__quantum__rt__tuple_start_record_output));
343
344 bind!(__quantum__rt__initialize, 1);
345 bind!(__quantum__qis__arccos__body, 1);
346 bind!(__quantum__qis__arcsin__body, 1);
347 bind!(__quantum__qis__arctan__body, 1);
348 bind!(__quantum__qis__arctan2__body, 2);
349 bind!(__quantum__qis__assertmeasurementprobability__body, 6);
350 bind!(__quantum__qis__assertmeasurementprobability__ctl, 6);
351 bind!(__quantum__qis__barrier__body, 0);
352 bind!(__quantum__qis__ccx__body, 3);
353 bind!(__quantum__qis__cnot__body, 2);
354 bind!(__quantum__qis__cos__body, 1);
355 bind!(__quantum__qis__cosh__body, 1);
356 bind!(__quantum__qis__cx__body, 2);
357 bind!(__quantum__qis__cz__body, 2);
358 bind!(__quantum__qis__drawrandomdouble__body, 2);
359 bind!(__quantum__qis__drawrandomint__body, 2);
360 bind!(__quantum__qis__dumpmachine__body, 1);
361 bind!(__quantum__qis__exp__body, 3);
362 bind!(__quantum__qis__exp__adj, 3);
363 bind!(__quantum__qis__exp__ctl, 2);
364 bind!(__quantum__qis__exp__ctladj, 2);
365 bind!(__quantum__qis__h__body, 1);
366 bind!(__quantum__qis__h__ctl, 2);
367 bind!(__quantum__qis__ieeeremainder__body, 2);
368 bind!(__quantum__qis__infinity__body, 0);
369 bind!(__quantum__qis__isinf__body, 1);
370 bind!(__quantum__qis__isnan__body, 1);
371 bind!(__quantum__qis__isnegativeinfinity__body, 1);
372 bind!(__quantum__qis__log__body, 1);
373 bind!(__quantum__qis__measure__body, 2);
374 bind!(__quantum__qis__mresetz__body, 2);
375 bind!(__quantum__qis__mz__body, 2);
376 bind!(__quantum__qis__nan__body, 0);
377 bind!(__quantum__qis__r__adj, 3);
378 bind!(__quantum__qis__r__body, 3);
379 bind!(__quantum__qis__r__ctl, 2);
380 bind!(__quantum__qis__r__ctladj, 2);
381 bind!(__quantum__qis__read_result__body, 1);
382 bind!(__quantum__qis__reset__body, 1);
383 bind!(__quantum__qis__rx__body, 2);
384 bind!(__quantum__qis__rx__ctl, 2);
385 bind!(__quantum__qis__rxx__body, 3);
386 bind!(__quantum__qis__ry__body, 2);
387 bind!(__quantum__qis__ry__ctl, 2);
388 bind!(__quantum__qis__ryy__body, 3);
389 bind!(__quantum__qis__rz__body, 2);
390 bind!(__quantum__qis__rz__ctl, 2);
391 bind!(__quantum__qis__rzz__body, 3);
392 bind!(__quantum__qis__s__adj, 1);
393 bind!(__quantum__qis__s__body, 1);
394 bind!(__quantum__qis__s__ctl, 2);
395 bind!(__quantum__qis__s__ctladj, 2);
396 bind!(__quantum__qis__sx__body, 1);
397 bind!(__quantum__qis__sin__body, 1);
398 bind!(__quantum__qis__sinh__body, 1);
399 bind!(__quantum__qis__sqrt__body, 1);
400 bind!(__quantum__qis__swap__body, 2);
401 bind!(__quantum__qis__t__adj, 1);
402 bind!(__quantum__qis__t__body, 1);
403 bind!(__quantum__qis__t__ctl, 2);
404 bind!(__quantum__qis__t__ctladj, 2);
405 bind!(__quantum__qis__tan__body, 1);
406 bind!(__quantum__qis__tanh__body, 1);
407 bind!(__quantum__qis__x__body, 1);
408 bind!(__quantum__qis__x__ctl, 2);
409 bind!(__quantum__qis__y__body, 1);
410 bind!(__quantum__qis__y__ctl, 2);
411 bind!(__quantum__qis__z__body, 1);
412 bind!(__quantum__qis__z__ctl, 2);
413 bind!(__quantum__rt__array_concatenate, 2);
414 bind!(__quantum__rt__array_copy, 2);
415 bind!(__quantum__rt__array_create_1d, 2);
416
417 bind!(__quantum__rt__array_record_output, 2);
419 bind!(__quantum__rt__tuple_record_output, 2);
420
421 uses_legacy.push(bind_output_record!(__quantum__rt__bool_record_output));
423 uses_legacy.push(bind_output_record!(__quantum__rt__double_record_output));
424 uses_legacy.push(bind_output_record!(__quantum__rt__int_record_output));
425
426 uses_legacy.push(
428 if let Some(func) = declarations.get("__quantum__rt__result_record_output") {
429 if func.get_params().len() == 1 {
430 execution_engine.add_global_mapping(
431 func,
432 qir_backend::legacy_output::__quantum__rt__result_record_output as usize,
433 );
434 declarations.remove("__quantum__rt__result_record_output");
435 Some(true)
436 } else {
437 execution_engine
438 .add_global_mapping(func, __quantum__rt__result_record_output as usize);
439 declarations.remove("__quantum__rt__result_record_output");
440 Some(false)
441 }
442 } else {
443 None
444 },
445 );
446
447 if let Some(func) = declarations.get("__quantum__qis__m__body") {
450 if func.get_params().len() == 2 {
451 execution_engine
452 .add_global_mapping(func, qir_backend::__quantum__qis__mz__body as usize);
453 } else if func.get_params().len() == 1 {
454 execution_engine
455 .add_global_mapping(func, qir_backend::__quantum__qis__m__body as usize);
456 } else {
457 return Err(format!(
458 "Function '__quantum__qis__m__body' has mismatched parameters: expected 1 or 2, found {}",
459 func.get_params().len()
460 ));
461 }
462 declarations.remove("__quantum__qis__m__body");
463 }
464
465 bind!(__quantum__rt__array_get_element_ptr_1d, 2);
466 bind!(__quantum__rt__array_get_size_1d, 1);
467 bind!(quantum__rt__array_slice_1d, 3);
468 bind!(__quantum__rt__array_update_alias_count, 2);
469 bind!(__quantum__rt__array_update_reference_count, 2);
470 bind!(__quantum__rt__bigint_add, 2);
471 bind!(__quantum__rt__bigint_bitand, 2);
472 bind!(__quantum__rt__bigint_bitnot, 1);
473 bind!(__quantum__rt__bigint_bitor, 2);
474 bind!(__quantum__rt__bigint_bitxor, 2);
475 bind!(__quantum__rt__bigint_create_array, 2);
476 bind!(__quantum__rt__bigint_create_i64, 1);
477 bind!(__quantum__rt__bigint_divide, 2);
478 bind!(__quantum__rt__bigint_equal, 2);
479 bind!(__quantum__rt__bigint_get_data, 1);
480 bind!(__quantum__rt__bigint_get_length, 1);
481 bind!(__quantum__rt__bigint_greater, 2);
482 bind!(__quantum__rt__bigint_greater_eq, 2);
483 bind!(__quantum__rt__bigint_modulus, 2);
484 bind!(__quantum__rt__bigint_multiply, 2);
485 bind!(__quantum__rt__bigint_negate, 1);
486 bind!(__quantum__rt__bigint_power, 2);
487 bind!(__quantum__rt__bigint_shiftleft, 2);
488 bind!(__quantum__rt__bigint_shiftright, 2);
489 bind!(__quantum__rt__bigint_subtract, 2);
490 bind!(__quantum__rt__bigint_to_string, 1);
491 bind!(__quantum__rt__bigint_update_reference_count, 2);
492 bind!(__quantum__rt__bool_to_string, 1);
493 bind!(__quantum__rt__callable_copy, 2);
494 bind!(__quantum__rt__callable_create, 3);
495 bind!(__quantum__rt__callable_invoke, 3);
496 bind!(__quantum__rt__callable_make_adjoint, 1);
497 bind!(__quantum__rt__callable_make_controlled, 1);
498 bind!(__quantum__rt__callable_update_alias_count, 2);
499 bind!(__quantum__rt__callable_update_reference_count, 2);
500 bind!(__quantum__rt__capture_update_alias_count, 2);
501 bind!(__quantum__rt__capture_update_reference_count, 2);
502 bind!(__quantum__rt__double_to_string, 1);
503 bind!(__quantum__rt__fail, 1);
504 bind!(__quantum__rt__int_to_string, 1);
505 bind!(__quantum__rt__memory_allocate, 1);
506 bind!(__quantum__rt__message, 1);
507 bind!(__quantum__rt__pauli_to_string, 1);
508 bind!(__quantum__rt__qubit_allocate, 0);
509 bind!(__quantum__rt__qubit_allocate_array, 1);
510 bind!(__quantum__rt__qubit_release, 1);
511 bind!(__quantum__rt__qubit_release_array, 1);
512 bind!(__quantum__rt__qubit_to_string, 1);
513 bind!(__quantum__rt__read_result, 1);
514 bind!(__quantum__rt__result_equal, 2);
515 bind!(quantum__rt__range_to_string, 1);
516 bind!(__quantum__rt__result_get_one, 0);
517 bind!(__quantum__rt__result_get_zero, 0);
518 bind!(__quantum__rt__result_to_string, 1);
519 bind!(__quantum__rt__result_update_reference_count, 2);
520 bind!(__quantum__rt__string_concatenate, 2);
521 bind!(__quantum__rt__string_create, 1);
522 bind!(__quantum__rt__string_equal, 2);
523 bind!(__quantum__rt__string_get_data, 1);
524 bind!(__quantum__rt__string_get_length, 1);
525 bind!(__quantum__rt__string_update_reference_count, 2);
526 bind!(__quantum__rt__tuple_copy, 2);
527 bind!(__quantum__rt__tuple_create, 1);
528 bind!(__quantum__rt__tuple_update_alias_count, 2);
529 bind!(__quantum__rt__tuple_update_reference_count, 2);
530
531 if !(uses_legacy.iter().filter_map(|&b| b).all(|b| b)
532 || uses_legacy.iter().filter_map(|&b| b).all(|b| !b))
533 {
534 Err("Use of legacy and current output recording functions in the same program is not supported".to_string())
535 } else if declarations.is_empty() {
536 Ok(())
537 } else {
538 let keys = declarations.keys().collect::<Vec<_>>();
539 let (first, rest) = keys
540 .split_first()
541 .expect("Declarations list should be non-empty.");
542 Err(format!(
543 "Failed to link some declared functions: {}",
544 rest.iter().fold((*first).to_string(), |mut accum, f| {
545 accum.push_str(", ");
546 accum.push_str(f);
547 accum
548 })
549 ))
550 }
551}