mirror of
https://github.com/TrueLayer/reqwest-middleware.git
synced 2024-12-26 19:06:31 +00:00
fix: context not propagated if request_span is disabled (#39)
This commit is contained in:
parent
f928a7b2d6
commit
dc44128c7f
4 changed files with 113 additions and 43 deletions
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
|
@ -16,6 +16,7 @@ jobs:
|
|||
- opentelemetry_0_14
|
||||
- opentelemetry_0_15
|
||||
- opentelemetry_0_16
|
||||
- opentelemetry_0_17
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
@ -59,6 +60,7 @@ jobs:
|
|||
- opentelemetry_0_14
|
||||
- opentelemetry_0_15
|
||||
- opentelemetry_0_16
|
||||
- opentelemetry_0_17
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
@ -85,6 +87,7 @@ jobs:
|
|||
- opentelemetry_0_14
|
||||
- opentelemetry_0_15
|
||||
- opentelemetry_0_16
|
||||
- opentelemetry_0_17
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
|
|
@ -41,3 +41,5 @@ tracing-opentelemetry_0_17_pkg = { package = "tracing-opentelemetry",version = "
|
|||
[dev-dependencies]
|
||||
wiremock = "0.5"
|
||||
tokio = { version = "1", features = ["macros"] }
|
||||
tracing_subscriber_0_2 = { package = "tracing-subscriber", version = "0.2" }
|
||||
tracing_subscriber_0_3 = { package = "tracing-subscriber", version = "0.3" }
|
||||
|
|
|
@ -2,6 +2,7 @@ use reqwest::header::{HeaderMap, HeaderValue};
|
|||
use reqwest::{Request, Response, StatusCode as RequestStatusCode};
|
||||
use reqwest_middleware::{Error, Middleware, Next, Result};
|
||||
use task_local_extensions::Extensions;
|
||||
use tracing::Instrument;
|
||||
|
||||
/// Middleware for tracing requests using the current Opentelemetry Context.
|
||||
pub struct TracingMiddleware;
|
||||
|
@ -38,56 +39,60 @@ impl Middleware for TracingMiddleware {
|
|||
)
|
||||
};
|
||||
|
||||
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
|
||||
// Spans added by downstream consumers will be part of the same trace.
|
||||
#[cfg(any(
|
||||
feature = "opentelemetry_0_13",
|
||||
feature = "opentelemetry_0_14",
|
||||
feature = "opentelemetry_0_15",
|
||||
feature = "opentelemetry_0_16",
|
||||
feature = "opentelemetry_0_17",
|
||||
))]
|
||||
let req = crate::otel::inject_opentelemetry_context_into_request(&request_span, req);
|
||||
async {
|
||||
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
|
||||
// Spans added by downstream consumers will be part of the same trace.
|
||||
#[cfg(any(
|
||||
feature = "opentelemetry_0_13",
|
||||
feature = "opentelemetry_0_14",
|
||||
feature = "opentelemetry_0_15",
|
||||
feature = "opentelemetry_0_16",
|
||||
feature = "opentelemetry_0_17",
|
||||
))]
|
||||
let req = crate::otel::inject_opentelemetry_context_into_request(req);
|
||||
|
||||
// Run the request
|
||||
let outcome = next.run(req, extensions).await;
|
||||
match &outcome {
|
||||
Ok(response) => {
|
||||
// The request ran successfully
|
||||
let span_status = get_span_status(response.status());
|
||||
let status_code = response.status().as_u16() as i64;
|
||||
let user_agent = get_header_value("user_agent", response.headers());
|
||||
if let Some(span_status) = span_status {
|
||||
request_span.record("otel.status_code", &span_status);
|
||||
// Run the request
|
||||
let outcome = next.run(req, extensions).await;
|
||||
match &outcome {
|
||||
Ok(response) => {
|
||||
// The request ran successfully
|
||||
let span_status = get_span_status(response.status());
|
||||
let status_code = response.status().as_u16() as i64;
|
||||
let user_agent = get_header_value("user_agent", response.headers());
|
||||
if let Some(span_status) = span_status {
|
||||
request_span.record("otel.status_code", &span_status);
|
||||
}
|
||||
request_span.record("http.status_code", &status_code);
|
||||
request_span.record("http.user_agent", &user_agent.as_str());
|
||||
}
|
||||
request_span.record("http.status_code", &status_code);
|
||||
request_span.record("http.user_agent", &user_agent.as_str());
|
||||
}
|
||||
Err(e) => {
|
||||
// The request didn't run successfully
|
||||
let error_message = e.to_string();
|
||||
let error_cause_chain = format!("{:?}", e);
|
||||
request_span.record("otel.status_code", &"ERROR");
|
||||
request_span.record("error.message", &error_message.as_str());
|
||||
request_span.record("error.cause_chain", &error_cause_chain.as_str());
|
||||
if let Error::Reqwest(e) = e {
|
||||
request_span.record(
|
||||
"http.status_code",
|
||||
&e.status()
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "".to_string())
|
||||
.as_str(),
|
||||
);
|
||||
Err(e) => {
|
||||
// The request didn't run successfully
|
||||
let error_message = e.to_string();
|
||||
let error_cause_chain = format!("{:?}", e);
|
||||
request_span.record("otel.status_code", &"ERROR");
|
||||
request_span.record("error.message", &error_message.as_str());
|
||||
request_span.record("error.cause_chain", &error_cause_chain.as_str());
|
||||
if let Error::Reqwest(e) = e {
|
||||
request_span.record(
|
||||
"http.status_code",
|
||||
&e.status()
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "".to_string())
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
outcome
|
||||
}
|
||||
outcome
|
||||
.instrument(request_span.clone())
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
fn get_header_value(key: &str, headers: &HeaderMap) -> String {
|
||||
let header_default = &HeaderValue::from_static("");
|
||||
format!("{:?}", headers.get(key).unwrap_or(header_default)).replace("\"", "")
|
||||
format!("{:?}", headers.get(key).unwrap_or(header_default)).replace('"', "")
|
||||
}
|
||||
|
||||
/// HTTP Mapping <https://github.com/open-telemetry/opentelemetry-specification/blob/c4b7f4307de79009c97b3a98563e91fee39b7ba3/work_in_progress/opencensus/HTTP.md#status>
|
||||
|
|
|
@ -38,9 +38,8 @@ use opentelemetry::propagation::Injector;
|
|||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
/// Injects the given OpenTelemetry Context into a reqwest::Request headers to allow propagation downstream.
|
||||
pub fn inject_opentelemetry_context_into_request(span: &Span, request: Request) -> Request {
|
||||
let context = span.context();
|
||||
let mut request = request;
|
||||
pub fn inject_opentelemetry_context_into_request(mut request: Request) -> Request {
|
||||
let context = Span::current().context();
|
||||
|
||||
global::get_text_map_propagator(|injector| {
|
||||
injector.inject_context(&context, &mut RequestCarrier::new(&mut request))
|
||||
|
@ -72,3 +71,64 @@ impl<'a> Injector for RequestCarrier<'a> {
|
|||
self.request.headers_mut().insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::TracingMiddleware;
|
||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
use reqwest_middleware::ClientBuilder;
|
||||
use tracing::{info_span, Instrument, Level};
|
||||
#[cfg(any(
|
||||
feature = "opentelemetry_0_13",
|
||||
feature = "opentelemetry_0_14",
|
||||
feature = "opentelemetry_0_15"
|
||||
))]
|
||||
use tracing_subscriber_0_2::{filter, layer::SubscriberExt, Registry};
|
||||
#[cfg(not(any(
|
||||
feature = "opentelemetry_0_13",
|
||||
feature = "opentelemetry_0_14",
|
||||
feature = "opentelemetry_0_15"
|
||||
)))]
|
||||
use tracing_subscriber_0_3::{filter, layer::SubscriberExt, Registry};
|
||||
use wiremock::{matchers::any, Mock, MockServer, ResponseTemplate};
|
||||
|
||||
#[tokio::test]
|
||||
async fn tracing_middleware_propagates_otel_data_even_when_the_span_is_disabled() {
|
||||
let tracer = opentelemetry::sdk::export::trace::stdout::new_pipeline()
|
||||
.with_writer(std::io::sink())
|
||||
.install_simple();
|
||||
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
|
||||
let subscriber = Registry::default()
|
||||
.with(filter::Targets::new().with_target("reqwest_tracing::otel::test", Level::DEBUG))
|
||||
.with(telemetry);
|
||||
tracing::subscriber::set_global_default(subscriber).unwrap();
|
||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||
|
||||
// Mock server - sends all request headers back in the response
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(any())
|
||||
.respond_with(|req: &wiremock::Request| {
|
||||
req.headers
|
||||
.iter()
|
||||
.fold(ResponseTemplate::new(200), |resp, (k, v)| {
|
||||
resp.append_header(k.clone(), v.clone())
|
||||
})
|
||||
})
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let client = ClientBuilder::new(reqwest::Client::new())
|
||||
.with(TracingMiddleware)
|
||||
.build();
|
||||
|
||||
let resp = client
|
||||
.get(server.uri())
|
||||
.send()
|
||||
.instrument(info_span!("some_span"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(resp.headers().contains_key("traceparent"));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue