input() не работает в numba
В jit-функции numba я использую input(), и на этом все вылетает. Мне надо ускорить код, при этом input критически важен. Как мне это исправить. Мне надо сохранить ввод из консоли и оставить input.
P.S. Без numba все работает отлично, но медленно. P.P.S Это регистровая VR-машина
Код:
import numba, numpy
def smart_spl(x):
return x.split(';')[0]
def load(text):
strings = list(map(smart_spl, text.split('\n')))
data = list(map(lambda x: x.split(':'), strings))
data2 = list(map(lambda x: x[0], data))
data = list(map(lambda x: list(map(int, x[1:])), data))
data += data
return data, data2
# @numba.jit(cache=True)
def exec(data=None, data2=None, memory=None):
current_string = 0
while current_string < len(data2):
#print('DEBUG:', current_string, data2[current_string], data[current_string])
if data2[current_string] == 'ifgotofwd':
if memory[int(data[current_string][1])]:
current_string += int(data[current_string][0])
continue
if data2[current_string] == 'gotoback':
current_string -= int(data[current_string][0])
continue
elif data2[current_string] == 'define':
define(int(data[current_string][0]), memory=memory)
elif data2[current_string] == 'setval':
setval(int(data[current_string][0]), int(data[current_string][1]), memory=memory)
elif data2[current_string] == 'add':
add(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'mul':
mul(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'div':
div(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'sub':
sub(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'gre':
gre(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'eql':
eql(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'and':
and_op(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'or':
or_op(int(data[current_string][0]), int(data[current_string][1]),
int(data[current_string][2]), memory=memory)
elif data2[current_string] == 'not':
not_op(int(data[current_string][0]), int(data[current_string][1]), memory=memory)
elif data2[current_string] == 'print':
print_func(int(data[current_string][0]), memory=memory)
elif data2[current_string] == 'input':
# print('DEBUG:INPUT')
input_func(int(data[current_string][0]), memory=memory)
current_string += 1
@numba.jit(cache=True)
def input_func(adress, data=None, memory=None):
setval(adress, int(input()), memory=memory)
@numba.njit(cache=True)
def print_func(adress, data=None, memory=None):
print(int(memory[adress]))
@numba.njit(cache=True)
def graphic(data=None, memory=None):
pass
@numba.njit(cache=True)
def console():
pass
@numba.njit(cache=True)
def setpixel(x, y, r, g, b, data=None, memory=None):
pass
@numba.njit(cache=True)
def define(adress, size=0, data=None, memory=None):
memory[adress] = 0
@numba.njit(cache=True)
def setval(adress, value, data=None, memory=0):
memory[adress] = value
@numba.njit(cache=True)
def add(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = memory[adr1] + memory[adr2]
@numba.njit(cache=True)
def mul(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = memory[adr1] * memory[adr2]
@numba.njit(cache=True)
def sub(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = memory[adr1] - memory[adr2]
@numba.njit(cache=True)
def div(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = memory[adr1] // memory[adr2]
@numba.njit(cache=True)
def gre(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = int(memory[adr1] > memory[adr2])
@numba.njit(cache=True)
def eql(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = int(memory[adr1] == memory[adr2])
@numba.njit(cache=True)
def or_op(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = int(memory[adr1] or memory[adr2])
@numba.njit(cache=True)
def and_op(adr1, adr2, adr_out, data=None, memory=None):
memory[adr_out] = int(memory[adr1] and memory[adr2])
@numba.njit(cache=True)
def not_op(adr1, adr_out, data=None, memory=None):
memory[adr_out] = int(not bool(memory[adr1]))
if __name__ == '__main__':
text = open('hello.roasm').read()
data, data2 = load(text)
memory = numpy.zeros(1024)
exec(data, data2, memory)
Ошибка:
NumbaWarning:
Compilation is falling back to object mode WITH looplifting enabled because Function "input_func" failed type inference due to: Untyped global name 'input': Cannot determine Numba type of <class 'builtin_function_or_method'>
File "roasm.py", line 70:
def input_func(adress, data=None, memory=None):
setval(adress, int(input()), memory=memory)
^
@numba.jit(cache=True)
C:\Users\Vasya\PycharmProjects\RobinN\venv\lib\site-packages\numba\core\object_mode_passes.py:151: NumbaWarning: Function "input_func" was compiled in object mode without forceobj=True.
File "roasm.py", line 69:
@numba.jit(cache=True)
def input_func(adress, data=None, memory=None):
^
warnings.warn(errors.NumbaWarning(warn_msg,
C:\Users\Vasya\PycharmProjects\RobinN\venv\lib\site-packages\numba\core\object_mode_passes.py:161: NumbaDeprecationWarning:
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.
For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit
File "roasm.py", line 69:
@numba.jit(cache=True)
def input_func(adress, data=None, memory=None):
^
warnings.warn(errors.NumbaDeprecationWarning(msg,
Файл hello.roasm:
define:10:0:0; создаем переменную
setval:1:1:0;
setval:0:0:0;
input:10:0:0;
define:11:0:0;
define:21:0:0;
define:31:0:0;
define:20:0:0;
gre:10:20:11;
eql:0:1:21;
or:21:11:31;
not:31:11:0;
not:11:21:0;
ifgotofwd:4:11:0;
add:20:1:20;
print:20:0:0;
gotoback:8:0:0;
add:1:1:1;