From 13bacc5c8f77fc98cf20d3d2c67d1f5dee1c84a6 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Wed, 17 Jul 2024 13:35:58 +0100 Subject: [PATCH] eliminate pmux in functional backend --- backends/functional/cxx.cc | 1 - backends/functional/cxx_runtime/sim.h | 19 ------------------- backends/functional/smtlib.cc | 6 ------ kernel/functionalir.cc | 11 +++++++++-- kernel/functionalir.h | 12 ------------ 5 files changed, 9 insertions(+), 40 deletions(-) diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index 920ccf363..0fa310164 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -136,7 +136,6 @@ template struct CxxPrintVisitor : public FunctionalIR::Abstra void logical_shift_right(Node, Node a, Node b) override { print("{} >> {}", a, b); } void arithmetic_shift_right(Node, Node a, Node b) override { print("{}.arithmetic_shift_right({})", a, b); } void mux(Node, Node a, Node b, Node s) override { print("{2}.any() ? {1} : {0}", a, b, s); } - void pmux(Node, Node a, Node b, Node s) override { print("{0}.pmux({1}, {2})", a, b, s); } void constant(Node, RTLIL::Const value) override { std::stringstream ss; bool multiple = value.size() > 32; diff --git a/backends/functional/cxx_runtime/sim.h b/backends/functional/cxx_runtime/sim.h index 950d4c5e1..ed1a25ed4 100644 --- a/backends/functional/cxx_runtime/sim.h +++ b/backends/functional/cxx_runtime/sim.h @@ -368,25 +368,6 @@ public: return ret; } - template - Signal pmux(Signal const &b, Signal const &s) const - { - bool found; - Signal ret; - - found = false; - ret = *this; - for(size_t i = 0; i < ns; i++){ - if(s._bits[i]){ - if(found) - return 0; - found = true; - ret = b.template slice(n * i); - } - } - return ret; - } - template Signal concat(Signal const& b) const { diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib.cc index c35a48dd8..a2bb6666c 100644 --- a/backends/functional/smtlib.cc +++ b/backends/functional/smtlib.cc @@ -174,12 +174,6 @@ struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor { SExpr logical_shift_right(Node, Node a, Node b) override { return list("bvlshr", n(a), extend(n(b), b.width(), a.width())); } SExpr arithmetic_shift_right(Node, Node a, Node b) override { return list("bvashr", n(a), extend(n(b), b.width(), a.width())); } SExpr mux(Node, Node a, Node b, Node s) override { return list("ite", to_bool(n(s)), n(b), n(a)); } - SExpr pmux(Node, Node a, Node b, Node s) override { - SExpr rv = n(a); - for(int i = 0; i < s.width(); i++) - rv = list("ite", to_bool(extract(n(s), i)), extract(n(b), a.width() * i, a.width()), rv); - return rv; - } SExpr constant(Node, RTLIL::Const value) override { return literal(value); } SExpr memory_read(Node, Node mem, Node addr) override { return list("select", n(mem), n(addr)); } SExpr memory_write(Node, Node mem, Node addr, Node data) override { return list("store", n(mem), n(addr), n(data)); } diff --git a/kernel/functionalir.cc b/kernel/functionalir.cc index ddd821ada..4b2ad8c0b 100644 --- a/kernel/functionalir.cc +++ b/kernel/functionalir.cc @@ -52,7 +52,6 @@ const char *FunctionalIR::fn_to_string(FunctionalIR::Fn fn) { case FunctionalIR::Fn::logical_shift_right: return "logical_shift_right"; case FunctionalIR::Fn::arithmetic_shift_right: return "arithmetic_shift_right"; case FunctionalIR::Fn::mux: return "mux"; - case FunctionalIR::Fn::pmux: return "pmux"; case FunctionalIR::Fn::constant: return "constant"; case FunctionalIR::Fn::input: return "input"; case FunctionalIR::Fn::state: return "state"; @@ -165,6 +164,14 @@ private: return factory.mux(y0, y1, factory.slice(s, sn - 1, 1)); } } + Node handle_pmux(Node a, Node b, Node s) { + // TODO : what to do about multiple b bits set ? + log_assert(b.width() == a.width() * s.width()); + Node y = a; + for(int i = 0; i < s.width(); i++) + y = factory.mux(y, factory.slice(b, a.width() * i, a.width()), factory.slice(s, i, 1)); + return y; + } public: Node handle(IdString cellType, dict parameters, dict inputs) { @@ -266,7 +273,7 @@ public: }else if(cellType == ID($mux)){ return factory.mux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S))); }else if(cellType == ID($pmux)){ - return factory.pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S))); + return handle_pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S))); }else if(cellType == ID($concat)){ Node a = inputs.at(ID(A)); Node b = inputs.at(ID(B)); diff --git a/kernel/functionalir.h b/kernel/functionalir.h index 0559eae89..d0ec1bce6 100644 --- a/kernel/functionalir.h +++ b/kernel/functionalir.h @@ -119,11 +119,6 @@ public: arithmetic_shift_right, // mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a mux, - // pmux(a: bit[N], b: bit[N*M], s: bit[M]): bit[N] - // required: no more than one bit in b is set - // if s[i] = 1 for any i, then returns b[i * N +: N] - // returns a if s == 0 - pmux, // constant(a: Const[N]): bit[N] = a constant, // input(a: IdString): any @@ -277,7 +272,6 @@ public: case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1)); break; case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break; case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break; - case Fn::pmux: return v.pmux(*this, arg(0), arg(1), arg(2)); break; case Fn::constant: return v.constant(*this, _ref.function().as_const()); break; case Fn::input: return v.input(*this, _ref.function().as_idstring()); break; case Fn::state: return v.state(*this, _ref.function().as_idstring()); break; @@ -320,7 +314,6 @@ public: virtual T logical_shift_right(Node self, Node a, Node b) = 0; virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0; virtual T mux(Node self, Node a, Node b, Node s) = 0; - virtual T pmux(Node self, Node a, Node b, Node s) = 0; virtual T constant(Node self, RTLIL::Const value) = 0; virtual T input(Node self, IdString name) = 0; virtual T state(Node self, IdString name) = 0; @@ -359,7 +352,6 @@ public: T logical_shift_right(Node self, Node, Node) override { return default_handler(self); } T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); } T mux(Node self, Node, Node, Node) override { return default_handler(self); } - T pmux(Node self, Node, Node, Node) override { return default_handler(self); } T constant(Node self, RTLIL::Const) override { return default_handler(self); } T input(Node self, IdString) override { return default_handler(self); } T state(Node self, IdString) override { return default_handler(self); } @@ -451,10 +443,6 @@ public: log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1)); return add(Fn::mux, a.sort(), {a, b, s}); } - Node pmux(Node a, Node b, Node s) { - log_assert(a.sort().is_signal() && b.sort().is_signal() && s.sort().is_signal() && a.sort().width() * s.sort().width() == b.sort().width()); - return add(Fn::pmux, a.sort(), {a, b, s}); - } Node memory_read(Node mem, Node addr) { log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width()); return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr});