From dc44128c7ff80438983f73fdfc8b62179468e729 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 21 Apr 2022 18:39:06 +0100 Subject: [PATCH] fix: context not propagated if request_span is disabled (#39) --- .github/workflows/ci.yml | 3 ++ reqwest-tracing/Cargo.toml | 2 + reqwest-tracing/src/middleware.rs | 85 ++++++++++++++++--------------- reqwest-tracing/src/otel.rs | 66 ++++++++++++++++++++++-- 4 files changed, 113 insertions(+), 43 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0ebbdee..8af77ad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,7 @@ jobs: - opentelemetry_0_14 - opentelemetry_0_15 - opentelemetry_0_16 + - opentelemetry_0_17 steps: - name: Checkout repository uses: actions/checkout@v2 @@ -59,6 +60,7 @@ jobs: - opentelemetry_0_14 - opentelemetry_0_15 - opentelemetry_0_16 + - opentelemetry_0_17 steps: - name: Checkout repository uses: actions/checkout@v2 @@ -85,6 +87,7 @@ jobs: - opentelemetry_0_14 - opentelemetry_0_15 - opentelemetry_0_16 + - opentelemetry_0_17 steps: - name: Checkout repository uses: actions/checkout@v2 diff --git a/reqwest-tracing/Cargo.toml b/reqwest-tracing/Cargo.toml index a8499d3..c9b16f8 100644 --- a/reqwest-tracing/Cargo.toml +++ b/reqwest-tracing/Cargo.toml @@ -41,3 +41,5 @@ tracing-opentelemetry_0_17_pkg = { package = "tracing-opentelemetry",version = " [dev-dependencies] wiremock = "0.5" tokio = { version = "1", features = ["macros"] } +tracing_subscriber_0_2 = { package = "tracing-subscriber", version = "0.2" } +tracing_subscriber_0_3 = { package = "tracing-subscriber", version = "0.3" } diff --git a/reqwest-tracing/src/middleware.rs b/reqwest-tracing/src/middleware.rs index 2d31f56..5e8b684 100644 --- a/reqwest-tracing/src/middleware.rs +++ b/reqwest-tracing/src/middleware.rs @@ -2,6 +2,7 @@ use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Request, Response, StatusCode as RequestStatusCode}; use reqwest_middleware::{Error, Middleware, Next, Result}; use task_local_extensions::Extensions; +use tracing::Instrument; /// Middleware for tracing requests using the current Opentelemetry Context. pub struct TracingMiddleware; @@ -38,56 +39,60 @@ impl Middleware for TracingMiddleware { ) }; - // Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. - // Spans added by downstream consumers will be part of the same trace. - #[cfg(any( - feature = "opentelemetry_0_13", - feature = "opentelemetry_0_14", - feature = "opentelemetry_0_15", - feature = "opentelemetry_0_16", - feature = "opentelemetry_0_17", - ))] - let req = crate::otel::inject_opentelemetry_context_into_request(&request_span, req); + async { + // Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. + // Spans added by downstream consumers will be part of the same trace. + #[cfg(any( + feature = "opentelemetry_0_13", + feature = "opentelemetry_0_14", + feature = "opentelemetry_0_15", + feature = "opentelemetry_0_16", + feature = "opentelemetry_0_17", + ))] + let req = crate::otel::inject_opentelemetry_context_into_request(req); - // Run the request - let outcome = next.run(req, extensions).await; - match &outcome { - Ok(response) => { - // The request ran successfully - let span_status = get_span_status(response.status()); - let status_code = response.status().as_u16() as i64; - let user_agent = get_header_value("user_agent", response.headers()); - if let Some(span_status) = span_status { - request_span.record("otel.status_code", &span_status); + // Run the request + let outcome = next.run(req, extensions).await; + match &outcome { + Ok(response) => { + // The request ran successfully + let span_status = get_span_status(response.status()); + let status_code = response.status().as_u16() as i64; + let user_agent = get_header_value("user_agent", response.headers()); + if let Some(span_status) = span_status { + request_span.record("otel.status_code", &span_status); + } + request_span.record("http.status_code", &status_code); + request_span.record("http.user_agent", &user_agent.as_str()); } - request_span.record("http.status_code", &status_code); - request_span.record("http.user_agent", &user_agent.as_str()); - } - Err(e) => { - // The request didn't run successfully - let error_message = e.to_string(); - let error_cause_chain = format!("{:?}", e); - request_span.record("otel.status_code", &"ERROR"); - request_span.record("error.message", &error_message.as_str()); - request_span.record("error.cause_chain", &error_cause_chain.as_str()); - if let Error::Reqwest(e) = e { - request_span.record( - "http.status_code", - &e.status() - .map(|s| s.to_string()) - .unwrap_or_else(|| "".to_string()) - .as_str(), - ); + Err(e) => { + // The request didn't run successfully + let error_message = e.to_string(); + let error_cause_chain = format!("{:?}", e); + request_span.record("otel.status_code", &"ERROR"); + request_span.record("error.message", &error_message.as_str()); + request_span.record("error.cause_chain", &error_cause_chain.as_str()); + if let Error::Reqwest(e) = e { + request_span.record( + "http.status_code", + &e.status() + .map(|s| s.to_string()) + .unwrap_or_else(|| "".to_string()) + .as_str(), + ); + } } } + outcome } - outcome + .instrument(request_span.clone()) + .await } } fn get_header_value(key: &str, headers: &HeaderMap) -> String { let header_default = &HeaderValue::from_static(""); - format!("{:?}", headers.get(key).unwrap_or(header_default)).replace("\"", "") + format!("{:?}", headers.get(key).unwrap_or(header_default)).replace('"', "") } /// HTTP Mapping diff --git a/reqwest-tracing/src/otel.rs b/reqwest-tracing/src/otel.rs index 41d48c2..e7c518c 100644 --- a/reqwest-tracing/src/otel.rs +++ b/reqwest-tracing/src/otel.rs @@ -38,9 +38,8 @@ use opentelemetry::propagation::Injector; use tracing_opentelemetry::OpenTelemetrySpanExt; /// Injects the given OpenTelemetry Context into a reqwest::Request headers to allow propagation downstream. -pub fn inject_opentelemetry_context_into_request(span: &Span, request: Request) -> Request { - let context = span.context(); - let mut request = request; +pub fn inject_opentelemetry_context_into_request(mut request: Request) -> Request { + let context = Span::current().context(); global::get_text_map_propagator(|injector| { injector.inject_context(&context, &mut RequestCarrier::new(&mut request)) @@ -72,3 +71,64 @@ impl<'a> Injector for RequestCarrier<'a> { self.request.headers_mut().insert(header_name, header_value); } } + +#[cfg(test)] +mod test { + use super::*; + use crate::TracingMiddleware; + use opentelemetry::sdk::propagation::TraceContextPropagator; + use reqwest_middleware::ClientBuilder; + use tracing::{info_span, Instrument, Level}; + #[cfg(any( + feature = "opentelemetry_0_13", + feature = "opentelemetry_0_14", + feature = "opentelemetry_0_15" + ))] + use tracing_subscriber_0_2::{filter, layer::SubscriberExt, Registry}; + #[cfg(not(any( + feature = "opentelemetry_0_13", + feature = "opentelemetry_0_14", + feature = "opentelemetry_0_15" + )))] + use tracing_subscriber_0_3::{filter, layer::SubscriberExt, Registry}; + use wiremock::{matchers::any, Mock, MockServer, ResponseTemplate}; + + #[tokio::test] + async fn tracing_middleware_propagates_otel_data_even_when_the_span_is_disabled() { + let tracer = opentelemetry::sdk::export::trace::stdout::new_pipeline() + .with_writer(std::io::sink()) + .install_simple(); + let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); + let subscriber = Registry::default() + .with(filter::Targets::new().with_target("reqwest_tracing::otel::test", Level::DEBUG)) + .with(telemetry); + tracing::subscriber::set_global_default(subscriber).unwrap(); + global::set_text_map_propagator(TraceContextPropagator::new()); + + // Mock server - sends all request headers back in the response + let server = MockServer::start().await; + Mock::given(any()) + .respond_with(|req: &wiremock::Request| { + req.headers + .iter() + .fold(ResponseTemplate::new(200), |resp, (k, v)| { + resp.append_header(k.clone(), v.clone()) + }) + }) + .mount(&server) + .await; + + let client = ClientBuilder::new(reqwest::Client::new()) + .with(TracingMiddleware) + .build(); + + let resp = client + .get(server.uri()) + .send() + .instrument(info_span!("some_span")) + .await + .unwrap(); + + assert!(resp.headers().contains_key("traceparent")); + } +}