qir_stdlib/callables.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::{
    arrays::{
        QirArray, __quantum__rt__array_concatenate, __quantum__rt__array_update_reference_count,
    },
    tuples::{__quantum__rt__tuple_copy, __quantum__rt__tuple_update_reference_count},
    update_counts,
};
use std::{cell::RefCell, mem::ManuallyDrop, rc::Rc};

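/// A QIR callable value. `func_table` points to a four-entry table of
/// function pointers (the body, adjoint, controlled, and controlled-adjoint
/// specializations, in that order), `mem_table` points to an optional
/// two-entry table of memory-management routines for the capture tuple
/// (reference-count and alias-count updates), and `cap_tuple` is the tuple of
/// captured values passed as the first argument on every invocation. The
/// `is_adj` and `ctls_count` cells record functor applications made via
/// `__quantum__rt__callable_make_adjoint` and
/// `__quantum__rt__callable_make_controlled`.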
#[derive(Clone)]
pub struct Callable {
    func_table: *mut *mut u8,
    mem_table: *mut *mut u8,
    cap_tuple: *mut u8,
    is_adj: RefCell<bool>,
    ctls_count: RefCell<u32>,
}

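/// Creates a callable from the given function table, memory-management table,
/// and capture tuple, returning a reference-counted pointer to it. Release the
/// reference with `__quantum__rt__callable_update_reference_count`.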
#[no_mangle]
pub extern "C" fn __quantum__rt__callable_create(
    func_table: *mut *mut u8,
    mem_table: *mut *mut u8,
    cap_tuple: *mut u8,
) -> *const Callable {
    Rc::into_raw(Rc::new(Callable {
        func_table,
        mem_table,
        cap_tuple,
        is_adj: RefCell::new(false),
        ctls_count: RefCell::new(0),
    }))
}

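/// Invokes the callable with the given argument tuple, writing any output to
/// the result tuple. The specialization to call is selected from the function
/// table based on the pending adjoint and controlled functors, and control
/// lists accumulated through repeated `__quantum__rt__callable_make_controlled`
/// calls are first flattened into a single control list.
///
/// # Safety
///
/// `callable` must be a valid, non-null pointer returned by
/// `__quantum__rt__callable_create`. `args_tup` and `res_tup` must be valid
/// tuple pointers, or null when the callable takes or returns no values.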
#[no_mangle]
#[allow(clippy::cast_ptr_alignment)]
pub unsafe extern "C" fn __quantum__rt__callable_invoke(
    callable: *const Callable,
    args_tup: *mut u8,
    res_tup: *mut u8,
) {
    let call = &*callable;
    let index =
        usize::from(*call.is_adj.borrow()) + (if *call.ctls_count.borrow() > 0 { 2 } else { 0 });

    // Collect any nested controls into a single control list.
    let mut args_copy: *mut *const Vec<u8> = std::ptr::null_mut();
    if !args_tup.is_null() {
        // Copy the tuple so we can potentially edit it.
        args_copy = __quantum__rt__tuple_copy(args_tup.cast::<*const Vec<u8>>(), true);

        if *call.ctls_count.borrow() > 0 {
            // If there are any controls, increment the reference count on the control list.
            // This balances the decrements performed in the loop below and after the callable
            // is invoked, ensuring the original, non-owned list is not incorrectly cleaned up.
            __quantum__rt__array_update_reference_count(*args_copy.cast::<*const QirArray>(), 1);

            let mut ctl_count = *call.ctls_count.borrow();
            while ctl_count > 1 {
                let ctls = *args_copy.cast::<*const QirArray>();
                let inner_tuple = *args_copy
                    .cast::<*const QirArray>()
                    .wrapping_add(1)
                    .cast::<*mut *const Vec<u8>>();
                let inner_ctls = *inner_tuple.cast::<*const QirArray>();
                let new_ctls = __quantum__rt__array_concatenate(ctls, inner_ctls);
                let new_args = __quantum__rt__tuple_copy(inner_tuple, true);
                *new_args.cast::<*const QirArray>() = new_ctls;

                // This decrement releases either the extra count added above (on the first
                // iteration) or the concatenated list created by the previous iteration, so
                // intermediate control lists are cleaned up rather than leaked.
                __quantum__rt__array_update_reference_count(
                    *args_copy.cast::<*const QirArray>(),
                    -1,
                );
                // Decrement the count on the tuple copy to clean it up as well, since we
                // created a new copy with the updated controls list.
                __quantum__rt__tuple_update_reference_count(args_copy, -1);
                args_copy = new_args;
                ctl_count -= 1;
            }
        }
    }

    (*call
        .func_table
        .wrapping_add(index)
        .cast::<extern "C" fn(*mut u8, *mut u8, *mut u8)>())(
        call.cap_tuple,
        args_copy.cast::<u8>(),
        res_tup,
    );
    if *call.ctls_count.borrow() > 0 {
        __quantum__rt__array_update_reference_count(*args_copy.cast::<*const QirArray>(), -1);
    }
    if !args_copy.is_null() {
        __quantum__rt__tuple_update_reference_count(args_copy, -1);
    }
}

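/// Copies the callable. A deep copy is made only when `force` is true or the
/// callable is aliased (the alias count is tracked via the `Rc` weak count);
/// otherwise the existing instance is returned with its reference count
/// incremented.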
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_copy(
    callable: *const Callable,
    force: bool,
) -> *const Callable {
    let rc = ManuallyDrop::new(Rc::from_raw(callable));
    if force || Rc::weak_count(&rc) > 0 {
        let copy = rc.as_ref().clone();
        Rc::into_raw(Rc::new(copy))
    } else {
        let _ = Rc::into_raw(Rc::clone(&rc));
        callable
    }
}

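/// Applies the adjoint functor to the callable by toggling its adjoint flag,
/// so a second application restores the original behavior.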
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_make_adjoint(callable: *const Callable) {
    let call = &*callable;
    call.is_adj.replace_with(|&mut old| !old);
}

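/// Applies the controlled functor to the callable by incrementing its control
/// count. Each application means the argument tuple passed to
/// `__quantum__rt__callable_invoke` is expected to wrap the remaining
/// arguments with one more control list.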
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_make_controlled(callable: *const Callable) {
    let call = &*callable;
    call.ctls_count.replace_with(|&mut old| old + 1);
}

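/// Adjusts the reference count of the callable itself by `update`, dropping it
/// when the count reaches zero. The capture tuple is managed separately via
/// `__quantum__rt__capture_update_reference_count`.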
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_update_reference_count(
    callable: *const Callable,
    update: i32,
) {
    update_counts(callable, update, false);
}

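/// Adjusts the alias count of the callable by `update`; a nonzero alias count
/// forces `__quantum__rt__callable_copy` to perform a deep copy.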
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_update_alias_count(
    callable: *const Callable,
    update: i32,
) {
    update_counts(callable, update, true);
}

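/// Forwards a reference-count update for the captured values to the first
/// entry of the memory-management table, if one was provided at creation.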
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__capture_update_reference_count(
    callable: *const Callable,
    update: i32,
) {
    let call = &*callable;
    if !call.mem_table.is_null() && !(*(call.mem_table)).is_null() {
        (*call.mem_table.cast::<extern "C" fn(*mut u8, i32)>())(call.cap_tuple, update);
    }
}

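/// Forwards an alias-count update for the captured values to the second entry
/// of the memory-management table, if one was provided at creation.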
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__capture_update_alias_count(
    callable: *const Callable,
    update: i32,
) {
    let call = &*callable;
    if !call.mem_table.is_null() && !(*(call.mem_table.wrapping_add(1))).is_null() {
        (*call
            .mem_table
            .wrapping_add(1)
            .cast::<extern "C" fn(*mut u8, i32)>())(call.cap_tuple, update);
    }
}
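
// What follows is a minimal usage sketch, not part of the original source: it
// shows how a QIR generator would be expected to drive this API by building a
// four-entry function table, wrapping it in a callable, and invoking it. The
// `increment_body` wrapper and the i64 capture layout are illustrative
// assumptions, not a runtime contract.
#[cfg(test)]
mod example {
    use super::*;

    // Illustrative body implementation: treats the capture tuple as a pointer
    // to an i64 counter and increments it. The argument and result tuples are
    // unused, so null is passed for both at the invoke site.
    extern "C" fn increment_body(cap_tuple: *mut u8, _args: *mut u8, _res: *mut u8) {
        unsafe { *cap_tuple.cast::<i64>() += 1 };
    }

    #[test]
    fn invoke_simple_callable() {
        let mut counter: i64 = 0;
        let body: extern "C" fn(*mut u8, *mut u8, *mut u8) = increment_body;
        // Table order is body, adjoint, controlled, controlled-adjoint; only
        // the body specialization is provided here.
        let mut func_table: [*mut u8; 4] = [
            body as *mut u8,
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            std::ptr::null_mut(),
        ];
        let callable = __quantum__rt__callable_create(
            func_table.as_mut_ptr(),
            std::ptr::null_mut(), // no memory-management table for the capture
            (&mut counter as *mut i64).cast::<u8>(),
        );
        unsafe {
            __quantum__rt__callable_invoke(callable, std::ptr::null_mut(), std::ptr::null_mut());
            __quantum__rt__callable_update_reference_count(callable, -1);
        }
        assert_eq!(counter, 1);
    }
}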