Skip to content

Commit

Permalink
Port remaining two passes
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Aug 27, 2024
1 parent c088cc2 commit 144f8bd
Show file tree
Hide file tree
Showing 5 changed files with 791 additions and 59 deletions.
282 changes: 282 additions & 0 deletions ptx/src/pass/extract_globals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
use super::*;

pub(super) fn run<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
match statement {
Statement::Variable(
var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfe { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfi { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
let fn_name: String =
[ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Brev { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Activemask { arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::IncrementWrap,
semantics,
scope,
space,
..
},
arguments,
}) => {
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_inc",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::DecrementWrap,
semantics,
scope,
space,
..
},
arguments,
}) => {
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_dec",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::FloatAdd,
semantics,
scope,
space,
..
},
arguments,
}) => {
let scalar_type = match data.type_ {
ptx_parser::Type::Scalar(scalar) => scalar,
_ => return Err(error_unreachable()),
};
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_add_",
scalar_to_ptx_name(scalar_type),
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
s => local.push(s),
}
}
Ok((local, global))
}

fn instruction_to_fn_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
inst: ast::Instruction<SpirvWord>,
fn_name: String,
) -> Result<ExpandedStatement, TranslateError> {
let mut arguments = Vec::new();
ast::visit_map(inst, &mut |operand,
type_space: Option<(
&ast::Type,
ast::StateSpace,
)>,
is_dst,
_| {
let (typ, space) = match type_space {
Some((typ, space)) => (typ.clone(), space),
None => return Err(error_unreachable()),
};
arguments.push((operand, is_dst, typ, space));
Ok(SpirvWord(0))
})?;
let return_arguments_count = arguments
.iter()
.position(|(desc, is_dst, _, _)| !is_dst)
.unwrap_or(arguments.len());
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
let fn_id = register_external_fn_call(
id_defs,
ptx_impl_imports,
fn_name,
return_arguments
.iter()
.map(|(_, _, typ, state)| (typ, *state)),
input_arguments
.iter()
.map(|(_, _, typ, state)| (typ, *state)),
)?;
Ok(Statement::Instruction(ast::Instruction::Call {
data: ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, _, typ, state)| (typ.clone(), *state))
.collect::<Vec<_>>(),
input_arguments: input_arguments
.iter()
.map(|(_, _, typ, state)| (typ.clone(), *state))
.collect::<Vec<_>>(),
},
arguments: ast::CallArgs {
return_arguments: return_arguments
.iter()
.map(|(name, _, _, _)| *name)
.collect::<Vec<_>>(),
func: fn_id,
input_arguments: input_arguments
.iter()
.map(|(name, _, _, _)| *name)
.collect::<Vec<_>>(),
},
}))
}

fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
match this {
ast::ScalarType::B8 => "b8",
ast::ScalarType::B16 => "b16",
ast::ScalarType::B32 => "b32",
ast::ScalarType::B64 => "b64",
ast::ScalarType::B128 => "b128",
ast::ScalarType::U8 => "u8",
ast::ScalarType::U16 => "u16",
ast::ScalarType::U16x2 => "u16x2",
ast::ScalarType::U32 => "u32",
ast::ScalarType::U64 => "u64",
ast::ScalarType::S8 => "s8",
ast::ScalarType::S16 => "s16",
ast::ScalarType::S16x2 => "s16x2",
ast::ScalarType::S32 => "s32",
ast::ScalarType::S64 => "s64",
ast::ScalarType::F16 => "f16",
ast::ScalarType::F16x2 => "f16x2",
ast::ScalarType::F32 => "f32",
ast::ScalarType::F64 => "f64",
ast::ScalarType::BF16 => "bf16",
ast::ScalarType::BF16x2 => "bf16x2",
ast::ScalarType::Pred => "pred",
}
}

fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
match this {
ast::AtomSemantics::Relaxed => "relaxed",
ast::AtomSemantics::Acquire => "acquire",
ast::AtomSemantics::Release => "release",
ast::AtomSemantics::AcqRel => "acq_rel",
}
}

fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
match this {
ast::MemScope::Cta => "cta",
ast::MemScope::Gpu => "gpu",
ast::MemScope::Sys => "sys",
ast::MemScope::Cluster => "cluster",
}
}

fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
match this {
ast::StateSpace::Generic => "generic",
ast::StateSpace::Global => "global",
ast::StateSpace::Shared => "shared",
ast::StateSpace::Reg => "reg",
ast::StateSpace::Const => "const",
ast::StateSpace::Local => "local",
ast::StateSpace::Param => "param",
ast::StateSpace::Sreg => "sreg",
ast::StateSpace::SharedCluster => "shared_cluster",
ast::StateSpace::ParamEntry => "param_entry",
ast::StateSpace::SharedCta => "shared_cta",
ast::StateSpace::ParamFunc => "param_func",
}
}
53 changes: 0 additions & 53 deletions ptx/src/pass/fix_special_registers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,56 +128,3 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
}
}
}

fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
let func_decl = ast::MethodDeclaration::<SpirvWord> {
return_arguments,
name: ast::MethodName::Func(fn_id),
input_arguments,
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
};
entry.insert(Directive::Method(func));
Ok(fn_id)
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
ast::MethodName::Func(fn_id) => Ok(fn_id),
ast::MethodName::Kernel(_) => Err(error_unreachable()),
},
_ => Err(error_unreachable()),
},
}
}

fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<SpirvWord>> {
args.map(|(typ, space)| ast::Variable {
align: None,
v_type: typ.clone(),
state_space: space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}
Loading

0 comments on commit 144f8bd

Please sign in to comment.