fix: set langchain template variables as dict (#193)
RogerHYang authored Feb 8, 2024
1 parent 5b0bf41 commit a7e2679
Showing 4 changed files with 41 additions and 26 deletions.
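In short: the LLM_PROMPT_TEMPLATE_VARIABLES span attribute previously carried only the list of template variable names; with this change it carries a JSON-serialized dict mapping each variable name to the value found in the run's inputs. A rough sketch of the before/after shape of that attribute value (the example inputs below are made up for illustration; in the instrumentation they come from the PromptTemplate's input_variables and from run.inputs):

import json

input_variables = ["context", "question"]                       # declared on the PromptTemplate
run_inputs = {"context": "retrieved docs", "question": "example question"}  # hypothetical run.inputs

# Before #193: the attribute value was the bare list of variable names.
before = input_variables  # ['context', 'question']

# After #193: the attribute value is a JSON dict of name -> runtime value.
after = json.dumps({name: run_inputs[name] for name in input_variables if run_inputs.get(name)})
# '{"context": "retrieved docs", "question": "example question"}'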
@@ -1,20 +1,17 @@
 from langchain.chains import LLMChain
-from langchain_community.llms import OpenAI
 from langchain_core.prompts import PromptTemplate
+from langchain_openai import OpenAI
 from openinference.instrumentation.langchain import LangChainInstrumentor
 from opentelemetry import trace as trace_api
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk import trace as trace_sdk
-from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
 
-resource = Resource(attributes={})
-tracer_provider = trace_sdk.TracerProvider(resource=resource)
-trace_api.set_tracer_provider(tracer_provider=tracer_provider)
-span_otlp_exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:6006/v1/traces")
-tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_otlp_exporter))
-span_console_exporter = ConsoleSpanExporter()
-tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_console_exporter))
+endpoint = "http://127.0.0.1:6006/v1/traces"
+tracer_provider = trace_sdk.TracerProvider()
+trace_api.set_tracer_provider(tracer_provider)
+tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint=endpoint)))
+tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
 
 LangChainInstrumentor().instrument()
 
@@ -3,16 +3,13 @@
 from opentelemetry import trace as trace_api
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk import trace as trace_sdk
-from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
 
-resource = Resource(attributes={})
-tracer_provider = trace_sdk.TracerProvider(resource=resource)
-trace_api.set_tracer_provider(tracer_provider=tracer_provider)
-span_otlp_exporter = OTLPSpanExporter(endpoint="http://127.0.0.1:6006/v1/traces")
-tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_otlp_exporter))
-span_console_exporter = ConsoleSpanExporter()
-tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter=span_console_exporter))
+endpoint = "http://127.0.0.1:6006/v1/traces"
+tracer_provider = trace_sdk.TracerProvider()
+trace_api.set_tracer_provider(tracer_provider)
+tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint=endpoint)))
+tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
 
 LangChainInstrumentor().instrument()
 
@@ -140,7 +140,7 @@ def _update_span(span: trace_api.Span, run: Run) -> None:
 _prompts(run.inputs),
 _input_messages(run.inputs),
 _output_messages(run.outputs),
-_prompt_template(run.serialized),
+_prompt_template(run),
 _invocation_parameters(run),
 _model_name(run.extra),
 _token_counts(run.outputs),
@@ -353,11 +353,12 @@ def _get_tool_call(tool_call: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Any]]:
 
 
 @stop_on_exception
-def _prompt_template(serialized: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Any]]:
+def _prompt_template(run: Run) -> Iterator[Tuple[str, Any]]:
     """
     A best-effort attempt to locate the PromptTemplate object among the
     keyword arguments of a serialized object, e.g. an LLMChain object.
     """
+    serialized: Optional[Mapping[str, Any]] = run.serialized
     if not serialized:
         return
     assert hasattr(serialized, "get"), f"expected Mapping, found {type(serialized)}"
@@ -383,7 +384,12 @@ def _prompt_template(serialized: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Any]]:
             assert isinstance(
                 input_variables, list
             ), f"expected list, found {type(input_variables)}"
-            yield LLM_PROMPT_TEMPLATE_VARIABLES, input_variables
+            template_variables = {}
+            for variable in input_variables:
+                if value := run.inputs.get(variable):
+                    template_variables[variable] = value
+            if template_variables:
+                yield LLM_PROMPT_TEMPLATE_VARIABLES, json.dumps(template_variables)
             break
 
 
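For reference, a minimal self-contained sketch of the variable-resolution logic the new _prompt_template body implements; the helper name and the literal attribute key below are illustrative stand-ins, not part of the package:

import json
from typing import Any, Dict, Iterator, List, Tuple

# Stand-in for SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES (illustrative key).
LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables"

def resolve_template_variables(
    input_variables: List[str], run_inputs: Dict[str, Any]
) -> Iterator[Tuple[str, str]]:
    # Mirror of the diff above: look each declared variable up in the run's
    # inputs and emit the collected values as a single JSON string.
    template_variables = {}
    for variable in input_variables:
        if value := run_inputs.get(variable):  # falsy/missing values are skipped
            template_variables[variable] = value
    if template_variables:
        yield LLM_PROMPT_TEMPLATE_VARIABLES, json.dumps(template_variables)

# Only variables actually present in run.inputs end up in the attribute.
print(dict(resolve_template_variables(["context", "question"], {"question": "hi"})))
# {'llm.prompt_template.variables': '{"question": "hi"}'}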
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import logging
 import random
 from contextlib import suppress
@@ -27,7 +28,6 @@
 )
 from opentelemetry import trace as trace_api
 from opentelemetry.sdk import trace as trace_sdk
-from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
@@ -55,6 +55,8 @@ def test_callback_llm(
     completion_usage: Dict[str, Any],
 ) -> None:
     question = randstr()
+    template = "{context}{question}"
+    prompt = PromptTemplate(input_variables=["context", "question"], template=template)
     output_messages: List[Dict[str, Any]] = (
         chat_completion_mock_stream[1] if is_stream else [{"role": randstr(), "content": randstr()}]
     )
@@ -82,7 +84,11 @@
         texts=documents,
         embeddings=FakeEmbeddings(size=2),
     )
-    rqa = RetrievalQA.from_chain_type(llm=chat_model, retriever=retriever)
+    rqa = RetrievalQA.from_chain_type(
+        llm=chat_model,
+        retriever=retriever,
+        chain_type_kwargs={"prompt": prompt},
+    )
     with suppress(openai.BadRequestError):
         if is_async:
             asyncio.run(rqa.ainvoke({"query": question}))
@@ -150,6 +156,15 @@ def test_callback_llm(
     assert llm_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == CHAIN.value
     assert llm_attributes.pop(INPUT_VALUE, None) is not None
     assert llm_attributes.pop(INPUT_MIME_TYPE, None) == JSON.value
+    assert llm_attributes.pop(LLM_PROMPT_TEMPLATE, None) == template
+    assert isinstance(
+        template_variables_json_string := llm_attributes.pop(LLM_PROMPT_TEMPLATE_VARIABLES, None),
+        str,
+    )
+    assert json.loads(template_variables_json_string) == {
+        "context": "\n\n".join(documents),
+        "question": question,
+    }
     if status_code == 200:
         assert llm_attributes.pop(OUTPUT_VALUE, None) == output_messages[0]["content"]
     elif status_code == 400:
@@ -275,10 +290,8 @@ def in_memory_span_exporter() -> InMemorySpanExporter:
 
 
 @pytest.fixture(scope="module")
 def tracer_provider(in_memory_span_exporter: InMemorySpanExporter) -> trace_api.TracerProvider:
-    resource = Resource(attributes={})
-    tracer_provider = trace_sdk.TracerProvider(resource=resource)
-    span_processor = SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
-    tracer_provider.add_span_processor(span_processor=span_processor)
+    tracer_provider = trace_sdk.TracerProvider()
+    tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter))
     return tracer_provider
 
 
@@ -362,6 +375,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
 LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
 LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
 LLM_PROMPTS = SpanAttributes.LLM_PROMPTS
+LLM_PROMPT_TEMPLATE = SpanAttributes.LLM_PROMPT_TEMPLATE
+LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
 LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
 LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
 LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL