qir_stdlib/
callables.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::{
5    arrays::{
6        __quantum__rt__array_concatenate, __quantum__rt__array_update_reference_count, QirArray,
7    },
8    tuples::{__quantum__rt__tuple_copy, __quantum__rt__tuple_update_reference_count},
9    update_counts,
10};
11use std::{cell::RefCell, mem::ManuallyDrop, rc::Rc};
12
/// A QIR callable value: a table of entry points, an optional capture-count table, the captured
/// values, and the functor state (adjoint / controlled) that has been applied to it.
///
/// `Clone` copies the raw pointers and the current functor state as-is; it does not touch any
/// reference or alias counts of the capture tuple.
#[derive(Clone, Debug)]
pub struct Callable {
    /// Entry points selected by functor state: index 0 = plain, 1 = adjoint, 2 = controlled,
    /// 3 = controlled adjoint. Each is invoked as `fn(cap_tuple, args, result)`.
    func_table: *mut *mut u8,
    /// Optional (may be null) table of capture-count callbacks: entry 0 updates the capture
    /// tuple's reference count, entry 1 its alias count. Individual entries may also be null.
    mem_table: *mut *mut u8,
    /// Captured values, passed as the first argument to every `func_table` entry.
    cap_tuple: *mut u8,
    /// Whether an odd number of adjoint functors has been applied.
    is_adj: RefCell<bool>,
    /// Number of times the controlled functor has been applied.
    ctls_count: RefCell<u32>,
}
21
22#[unsafe(no_mangle)]
23pub extern "C" fn __quantum__rt__callable_create(
24    func_table: *mut *mut u8,
25    mem_table: *mut *mut u8,
26    cap_tuple: *mut u8,
27) -> *const Callable {
28    Rc::into_raw(Rc::new(Callable {
29        func_table,
30        mem_table,
31        cap_tuple,
32        is_adj: RefCell::new(false),
33        ctls_count: RefCell::new(0),
34    }))
35}
36
37#[unsafe(no_mangle)]
38#[allow(clippy::cast_ptr_alignment)]
39pub unsafe extern "C" fn __quantum__rt__callable_invoke(
40    callable: *const Callable,
41    args_tup: *mut u8,
42    res_tup: *mut u8,
43) {
44    unsafe {
45        let call = &*callable;
46        let index = usize::from(*call.is_adj.borrow())
47            + (if *call.ctls_count.borrow() > 0 { 2 } else { 0 });
48
49        // Collect any nested controls into a single control list.
50        let mut args_copy: *mut *const Vec<u8> = std::ptr::null_mut();
51        if !args_tup.is_null() {
52            // Copy the tuple so we can potentially edit it.
53            args_copy = __quantum__rt__tuple_copy(args_tup.cast::<*const Vec<u8>>(), true);
54
55            if *call.ctls_count.borrow() > 0 {
56                // If there are any controls, increment the reference count on the control list. This is just
57                // to balance the decrement that will happen in the loop and at the end of invoking the callable
58                // to ensure the original, non-owned list does not get incorrectly cleaned up.
59                __quantum__rt__array_update_reference_count(
60                    *args_copy.cast::<*const QirArray>(),
61                    1,
62                );
63
64                let mut ctl_count = *call.ctls_count.borrow();
65                while ctl_count > 1 {
66                    let ctls = *args_copy.cast::<*const QirArray>();
67                    let inner_tuple = *args_copy
68                        .cast::<*const QirArray>()
69                        .wrapping_add(1)
70                        .cast::<*mut *const Vec<u8>>();
71                    let inner_ctls = *inner_tuple.cast::<*const QirArray>();
72                    let new_ctls = __quantum__rt__array_concatenate(ctls, inner_ctls);
73                    let new_args = __quantum__rt__tuple_copy(inner_tuple, true);
74                    *new_args.cast::<*const QirArray>() = new_ctls;
75
76                    // Decrementing the reference count is either the extra count added above or the new
77                    // list created when performing concatenate above. In the latter case, the concatenated
78                    // list will get cleaned up, preventing memory from leaking.
79                    __quantum__rt__array_update_reference_count(
80                        *args_copy.cast::<*const QirArray>(),
81                        -1,
82                    );
83                    // Decrement the count on the copy to clean it up as well, since we created a new copy
84                    // with the updated controls list.
85                    __quantum__rt__tuple_update_reference_count(args_copy, -1);
86                    args_copy = new_args;
87                    ctl_count -= 1;
88                }
89            }
90        }
91
92        (*call
93            .func_table
94            .wrapping_add(index)
95            .cast::<extern "C" fn(*mut u8, *mut u8, *mut u8)>())(
96            call.cap_tuple,
97            args_copy.cast::<u8>(),
98            res_tup,
99        );
100        if *call.ctls_count.borrow() > 0 {
101            __quantum__rt__array_update_reference_count(*args_copy.cast::<*const QirArray>(), -1);
102        }
103        if !args_copy.is_null() {
104            __quantum__rt__tuple_update_reference_count(args_copy, -1);
105        }
106    }
107}
108
109#[unsafe(no_mangle)]
110pub unsafe extern "C" fn __quantum__rt__callable_copy(
111    callable: *const Callable,
112    force: bool,
113) -> *const Callable {
114    unsafe {
115        let rc = ManuallyDrop::new(Rc::from_raw(callable));
116        if force || Rc::weak_count(&rc) > 0 {
117            let copy = rc.as_ref().clone();
118            Rc::into_raw(Rc::new(copy))
119        } else {
120            let _ = Rc::into_raw(Rc::clone(&rc));
121            callable
122        }
123    }
124}
125
126#[unsafe(no_mangle)]
127pub unsafe extern "C" fn __quantum__rt__callable_make_adjoint(callable: *const Callable) {
128    unsafe {
129        let call = &*callable;
130        call.is_adj.replace_with(|&mut old| !old);
131    }
132}
133
134#[unsafe(no_mangle)]
135pub unsafe extern "C" fn __quantum__rt__callable_make_controlled(callable: *const Callable) {
136    unsafe {
137        let call = &*callable;
138        call.ctls_count.replace_with(|&mut old| old + 1);
139    }
140}
141
142#[unsafe(no_mangle)]
143pub unsafe extern "C" fn __quantum__rt__callable_update_reference_count(
144    callable: *const Callable,
145    update: i32,
146) {
147    unsafe {
148        update_counts(callable, update, false);
149    }
150}
151
152#[unsafe(no_mangle)]
153pub unsafe extern "C" fn __quantum__rt__callable_update_alias_count(
154    callable: *const Callable,
155    update: i32,
156) {
157    unsafe {
158        update_counts(callable, update, true);
159    }
160}
161
162#[unsafe(no_mangle)]
163pub unsafe extern "C" fn __quantum__rt__capture_update_reference_count(
164    callable: *const Callable,
165    update: i32,
166) {
167    unsafe {
168        let call = &*callable;
169        if !call.mem_table.is_null() && !(*(call.mem_table)).is_null() {
170            (*call.mem_table.cast::<extern "C" fn(*mut u8, i32)>())(call.cap_tuple, update);
171        }
172    }
173}
174
175#[unsafe(no_mangle)]
176pub unsafe extern "C" fn __quantum__rt__capture_update_alias_count(
177    callable: *const Callable,
178    update: i32,
179) {
180    unsafe {
181        let call = &*callable;
182        if !call.mem_table.is_null() && !(*(call.mem_table.wrapping_add(1))).is_null() {
183            (*call
184                .mem_table
185                .wrapping_add(1)
186                .cast::<extern "C" fn(*mut u8, i32)>())(call.cap_tuple, update);
187        }
188    }
189}