Graph Neural Networks and Generative AI

Graph neural networks (GNNs) have been foundational to many AI applications across industries, from drug discovery to social networks to product recommendations. But the recent surge of innovation in generative AI has led many ML teams to ask how they can incorporate GNNs into their generative AI applications. Jure Leskovec, Stanford professor and co-founder of Kumo.AI, joined us on AI Explained to explore the intersection of graph neural networks, knowledge graphs, and generative AI, and how organizations can bring GNNs into their generative AI initiatives. Watch the webinar on demand now, and check out three key takeaways below.

But really, watch the webinar. You don’t want to miss this discussion!

Businesses of all types can leverage GNNs

Many businesses outside of social networking mistakenly believe they don't have graphs. In reality, most organizations do, because their data resides in relational databases, where rows in different tables are linked to one another through primary and foreign keys. Today, ML teams typically join these tables by hand to build training data for ML tasks. This process doesn't fully exploit the rich connections in the data, and it often leads MLOps teams to take inconsistent approaches, sometimes driven by personal bias rather than by what makes the best use of the data.

Additionally, the term "tabular data" is commonly misinterpreted to mean a single table, yet real-world applications more often involve multiple tables. A central challenge in data science is collapsing these multiple tables into one, which requires ML practitioners to do feature engineering, a resource-intensive and time-consuming process. Traditional feature engineering can lose information (for example, when many related rows are summarized into a single aggregate) and introduce errors, whereas GNNs provide an end-to-end solution that makes predictions directly without discarding data. GNNs can pick up signals even from data that sits several tables away, offering a breakthrough in representation learning on multi-table data.
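To make the "several tables away" idea concrete, here is a minimal sketch in Python. It assumes a hypothetical schema in which an orders table references customers and products by foreign key. Each key reference becomes an edge, and two rounds of plain mean aggregation (no learned weights, so this is not a full GNN) let a customer node absorb product prices that sit two tables away, without a hand-written join.

```python
from collections import defaultdict

# Hypothetical toy schema: each order row references a customer and a product.
orders = [
    {"order_id": 10, "customer_id": 1, "product_id": 100},
    {"order_id": 11, "customer_id": 1, "product_id": 101},
    {"order_id": 12, "customer_id": 2, "product_id": 101},
]
products = {100: {"price": 5.0}, 101: {"price": 20.0}}


def build_graph(order_rows):
    """Turn each foreign-key reference into an undirected edge between typed nodes."""
    adj = defaultdict(set)
    for o in order_rows:
        order = ("order", o["order_id"])
        for node in (("customer", o["customer_id"]), ("product", o["product_id"])):
            adj[order].add(node)
            adj[node].add(order)
    return adj


def message_pass(adj, features, steps):
    """Each round, a node takes the mean of its neighbors' current values.

    Nodes whose neighbors carry no value yet are left out of that round,
    so information spreads outward one hop (one table) per round.
    """
    h = dict(features)
    for _ in range(steps):
        new_h = {}
        for node, nbrs in adj.items():
            vals = [h[n] for n in nbrs if n in h]
            if vals:
                new_h[node] = sum(vals) / len(vals)
        h = new_h
    return h


# Only products start with a feature; customers are two hops (two tables) away.
features = {("product", pid): row["price"] for pid, row in products.items()}
h = message_pass(build_graph(orders), features, steps=2)
print(h[("customer", 1)])  # 12.5: mean price of products 100 and 101
```

After two rounds, customer 1's value is the mean price of the products it ordered. A real GNN would replace the plain mean with learned transformations, but the way information flows along key relationships is the same.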

GNNs are versatile

GNNs offer versatile applications across a broad spectrum of industries, owing to their ability to handle diverse graph structures. They can make predictions about individual entities (like forecasting sales volume), about links between entities (such as brand affinity and recommendations), and about graphs as a whole, notably determining molecular properties or detecting fraud. GNNs can also learn from an entity's own time series while drawing on information from correlated time series. This adaptability makes GNNs a powerful tool across many domains.
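The three kinds of predictions can share the same node embeddings and differ only in the output head. The sketch below assumes hypothetical, hand-picked embeddings and weights rather than ones produced by a trained GNN, and uses common choices for each head: a linear layer for node-level scores, a sigmoid of a dot product for link-level scores, and mean pooling followed by a linear layer for graph-level scores.

```python
import math

# Hypothetical embeddings standing in for the output of a trained GNN.
emb = {
    "a": [1.0, 0.0],
    "b": [0.5, 0.5],
    "c": [0.0, 1.0],
    "d": [1.0, 1.0],
}


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def node_score(node, w):
    """Node-level: a linear head on one node's embedding (e.g. a sales forecast)."""
    return dot(emb[node], w)


def link_score(u, v):
    """Link-level: dot product of two embeddings squashed to (0, 1) (e.g. a recommendation)."""
    return sigmoid(dot(emb[u], emb[v]))


def graph_score(nodes, w):
    """Graph-level: mean-pool the node embeddings, then apply a linear head."""
    dim = len(emb[nodes[0]])
    pooled = [sum(emb[n][i] for n in nodes) / len(nodes) for i in range(dim)]
    return dot(pooled, w)


w = [2.0, -1.0]  # hypothetical head weights
print(node_score("a", w))
print(link_score("a", "c"), link_score("a", "d"))
print(graph_score(["a", "b", "c", "d"], w))
```

Because the heads only read the embeddings, one GNN backbone can in principle serve several tasks, with a different head trained for each.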

The depth and design of a GNN largely depend on the specific domain or use case. "Depth" can refer either to the number of neural network layers or to how far information travels within the graph itself, and it's essential to distinguish between the two. For example, going too deep on a social network can cause over-smoothing, where node representations become nearly indistinguishable from one another, so depth must be balanced against the expressiveness of individual layers. Effective GNN architectures often combine pre-processing, message passing, and post-processing layers. On the other hand, the inherent structure of the graph, such as long molecules in biology, can call for deeper networks so that information propagates across the whole structure.
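Over-smoothing is easy to see in a stripped-down setting. This sketch assumes plain mean aggregation over each node's neighbors plus itself, on a hypothetical four-node graph, with no learned weights or nonlinearities. On a connected graph, each extra layer mixes more of the graph into every node, so the spread between node values shrinks toward zero as layers are added.

```python
# Hypothetical small connected graph: a path 0-1-2-3 plus an extra edge 1-3.
edges = [(0, 1), (1, 2), (2, 3), (1, 3)]
n = 4
neighbors = {i: {i} for i in range(n)}  # self-loops keep each node's own value
for u, v in edges:
    neighbors[u].add(v)
    neighbors[v].add(u)


def propagate(x, layers):
    """Apply `layers` rounds of mean aggregation over each node's neighborhood."""
    for _ in range(layers):
        x = [sum(x[j] for j in neighbors[i]) / len(neighbors[i]) for i in range(n)]
    return x


def spread(x):
    """Gap between the largest and smallest node value."""
    return max(x) - min(x)


x0 = [1.0, 0.0, 0.0, 5.0]  # hypothetical initial node features
for k in (0, 1, 2, 8, 32):
    print(k, round(spread(propagate(x0, k)), 4))
```

The printed spread shrinks as k grows. Real GNN layers add learned weights and nonlinearities, and architectures often use techniques such as residual connections to counter this effect, which is part of why depth has to be tuned to the graph and the task.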

GNNs for retrieval augmented generation in generative AI

Foundation models, such as large language models, are pre-trained on extensive datasets, which gives them broad general knowledge. However, there's a growing emphasis on creating domain-specific foundation models for areas like biology and medicine. These models are evolving to be multimodal, encompassing data types such as images, text, and structured data. GNNs play a pivotal role in generative AI as companies use their knowledge bases, or private data, to deliver more effective domain-specific models. This private data typically lives in relational tables, and some of its text can be used for retrieval augmented generation (RAG), a popular way to deploy LLMs in domain-specific AI systems, to improve the real-time accuracy and relevance of those systems.
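Here is a minimal sketch of the retrieval step, assuming a hypothetical support-ticket table with a free-text column. Rows are ranked by simple word overlap with the user's question (a stand-in for the embedding similarity that production RAG systems usually use), and the top rows are pasted into the prompt as context. The call to an actual LLM is left out.

```python
import re

# Hypothetical rows from a support-ticket table with a free-text column.
tickets = [
    {"id": 1, "product": "router X1", "text": "Router X1 drops Wi-Fi after firmware update 2.3"},
    {"id": 2, "product": "modem M5", "text": "Modem M5 billing question about monthly fee"},
    {"id": 3, "product": "router X1", "text": "Factory reset fixes Wi-Fi drops on router X1"},
]


def tokens(s):
    """Lowercase alphanumeric words, as a set."""
    return set(re.findall(r"[a-z0-9]+", s.lower()))


def retrieve(query, rows, k=2):
    """Rank rows by word overlap with the query; ties keep table order."""
    q = tokens(query)
    ranked = sorted(rows, key=lambda r: len(q & tokens(r["text"])), reverse=True)
    return ranked[:k]


def build_prompt(query, rows):
    """Place the retrieved rows ahead of the question as grounding context."""
    context = "\n".join(f"[ticket {r['id']}] {r['text']}" for r in rows)
    return f"Answer using only the context below.\n\nContext:\n{context}\n\nQuestion: {query}"


question = "Why does my router X1 keep dropping Wi-Fi?"
top = retrieve(question, tickets)
prompt = build_prompt(question, top)
print(prompt)
```

Only the two router tickets reach the prompt, so the model answers from the company's own records rather than from its general pre-training alone.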