diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 4394a7fe997c..1ed0008c8dba 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -2839,11 +2839,17 @@ void TypeAnalyzer::visitExtractValueInst(ExtractValueInst &I) { int off = (int)ai.getLimitedValue(); int size = dl.getTypeSizeInBits(I.getType()) / 8; - if (direction & DOWN) - updateAnalysis(&I, - getAnalysis(I.getOperand(0)) - .ShiftIndices(dl, off, size, /*addOffset*/ 0), - &I); + if (direction & DOWN) { + auto res = getAnalysis(I.getOperand(0)) + .ShiftIndices(dl, off, size, /*addOffset*/ 0); + // If no type info was propagated from the operand (e.g., the operand is + // an opaque call result with no known type info), fall back to using the + // LLVM type of the result to generate type info. For extractvalue, the + // LLVM type of the extracted field is authoritative. + if (!res.isKnown()) + res = defaultTypeTreeForLLVM(I.getType(), &I, /*intIsPointer*/ false); + updateAnalysis(&I, res, &I); + } if (direction & UP) updateAnalysis(I.getOperand(0), diff --git a/enzyme/test/TypeAnalysis/extractvalue_opaque_call.ll b/enzyme/test/TypeAnalysis/extractvalue_opaque_call.ll new file mode 100644 index 000000000000..19ca5d661b95 --- /dev/null +++ b/enzyme/test/TypeAnalysis/extractvalue_opaque_call.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=caller -o /dev/null | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=caller -S -o /dev/null | FileCheck %s + +; Test that extractvalue from a call returning a struct with float arrays +; properly propagates type information even when the aggregate operand has +; no known type (e.g., from an opaque function call result). + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +%struct.c_v = type { [2 x float], [2 x [3 x float]] } + +declare %struct.c_v @pre_work() + +define void @caller() { +entry: + %cv = call %struct.c_v @pre_work() + %f2 = extractvalue %struct.c_v %cv, 0 + %f2x3 = extractvalue %struct.c_v %cv, 1 + %f0 = extractvalue [2 x float] %f2, 0 + %f1 = extractvalue [2 x float] %f2, 1 + %f00 = extractvalue [2 x [3 x float]] %f2x3, 0, 0 + ret void +} + +; CHECK-LABEL: caller +; CHECK: %f2 = extractvalue %struct.c_v %cv, 0: {[-1]:Float@float} +; CHECK: %f2x3 = extractvalue %struct.c_v %cv, 1: {[-1]:Float@float} +; CHECK: %f0 = extractvalue [2 x float] %f2, 0: {[-1]:Float@float} +; CHECK: %f1 = extractvalue [2 x float] %f2, 1: {[-1]:Float@float} +; CHECK: %f00 = extractvalue [2 x [3 x float]] %f2x3, 0, 0: {[-1]:Float@float}